How to use a custom model with fastai cnn_learner?

fastai
Published

February 11, 2022

def pprint_model(model, truncate=64):
    print("="*80)
    print("Model modules:")
    print("="*80)
    print()
    for i, (name, module) in enumerate(model.named_children()):
        desc = repr(module).replace('\n', '').replace('  ', ' ')[:truncate] + '...'
        print(f"{i+1} - {desc}\n")

The fastai library offers many pre-trained models for vision tasks. However, we sometimes need to use a custom model available in another library or created from scratch. In this post, we’ll see how to use fastai’s cnn_learner with a custom model.

cnn_learner is a utility function that creates a Learner from a given pretrained CNN architecture such as resnet18.

def cnn_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,
                # learner args
                loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
                model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),
                # other model args
                **kwargs):
    "Build a convnet style learner from `dls` and `arch`"
    ...
    meta = model_meta.get(arch, _default_meta)
    if normalize: _add_norm(dls, meta, pretrained)

    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    model = create_cnn_model(arch, n_out, pretrained=pretrained, **kwargs)

    splitter=ifnone(splitter, meta['split'])
    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,
                   moms=moms)
    if pretrained: learn.freeze()
    ...
    return learn

To do that, it uses the model metadata from the model_meta registry, which is simply a mapping (dictionary) from an architecture to its metadata.

def _xresnet_split(m): return L(m[0][:3], m[0][3:], m[1:]).map(params)

model_meta = {
    models.xresnet.xresnet18 :{
        'cut':-4, 
        'split':_xresnet_split, 
        'stats':imagenet_stats
    },
    ...
}
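The lookup that cnn_learner performs can be pictured with a plain dictionary. The following is a minimal sketch, not fastai's code: `toy_arch`, `unknown_arch`, `toy_meta` and `default_meta` are hypothetical stand-ins, and the fallback mirrors the `model_meta.get(arch, _default_meta)` call shown in the cnn_learner source.

```python
# Hypothetical stand-ins for architecture constructors (not real models).
def toy_arch(): return "toy model"
def unknown_arch(): return "unregistered model"

# Fallback used when an architecture has no entry (illustrative values).
default_meta = {"cut": None, "split": None}

# The registry maps the constructor function itself to its metadata.
toy_meta = {
    toy_arch: {"cut": -2, "split": None, "stats": ([0.5] * 3, [0.5] * 3)},
}

def lookup(arch):
    """Return the registered metadata, or the fallback if `arch` is unknown."""
    return toy_meta.get(arch, default_meta)

print(lookup(toy_arch)["cut"])      # registered entry
print(lookup(unknown_arch)["cut"])  # falls back to the default
```

Note that the key is the function object itself, not its name as a string.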

The cut value is used for stripping off the existing classification head of the network so that we can add a custom head and fine-tune it for our task.

The split function divides the model’s parameters into groups so that discriminative learning rates can be applied, i.e. different groups of layers are trained with different learning rates.
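As a rough illustration of discriminative learning rates, the sketch below spreads learning rates geometrically across parameter groups, from a small value for the earliest group to a larger one for the head. The helper name `spread_lrs`, the group labels and the example values are assumptions for illustration; fastai has its own utilities for this.

```python
def spread_lrs(lr_min, lr_max, n_groups):
    """Geometrically space `n_groups` learning rates from lr_min to lr_max."""
    if n_groups == 1:
        return [lr_max]
    step = (lr_max / lr_min) ** (1 / (n_groups - 1))
    return [lr_min * step ** i for i in range(n_groups)]

# Three groups, e.g. early body layers, late body layers, head.
lrs = spread_lrs(1e-5, 1e-3, 3)
for group, lr in zip(["early body", "late body", "head"], lrs):
    print(f"{group}: {lr:.0e}")
```

The earliest layers, which hold the most generic pretrained features, get the smallest updates, while the freshly initialized head gets the largest.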

The stats are the per-channel means and standard deviations of the images in the ImageNet dataset, on which the model is pretrained. They are used to normalize the input images.
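To make the role of the stats concrete, here is a small sketch of per-channel normalization using the ImageNet means and standard deviations. The `normalize_pixel` helper is hypothetical and works on a single pixel for readability; in practice the normalization is applied to whole batches of images.

```python
# ImageNet channel statistics (RGB), for inputs scaled to [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize one RGB pixel channel by channel: (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]

# A pixel equal to the dataset mean maps to zero in every channel.
print(normalize_pixel([0.485, 0.456, 0.406]))
# A pure white pixel ends up a couple of standard deviations above zero.
print(normalize_pixel([1.0, 1.0, 1.0]))
```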

There are two alternative ways to use a custom model that is not present in the model registry:

1. Create a new helper function similar to cnn_learner that splits the network into backbone and head. Check out Zachary Mueller’s awesome blog post to see how it’s done.
2. Register the architecture in model_meta and use cnn_learner.

We will cover the second option in this post.

Let’s first inspect an architecture registered already, e.g. resnet18.

Here is its metadata from the registry:

from fastai.vision.all import *
model_meta[resnet18]
{'cut': -2,
 'split': <function fastai.vision.learner._resnet_split>,
 'stats': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}

And the model layers:

m = resnet18()
pprint_model(m)
================================================================================
Model modules:
================================================================================

1 - Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3),...

2 - BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn...

3 - ReLU(inplace=True)...

4 - MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_m...

5 - Sequential( (0): BasicBlock(  (conv1): Conv2d(64, 64, kernel_siz...

6 - Sequential( (0): BasicBlock(  (conv1): Conv2d(64, 128, kernel_si...

7 - Sequential( (0): BasicBlock(  (conv1): Conv2d(128, 256, kernel_s...

8 - Sequential( (0): BasicBlock(  (conv1): Conv2d(256, 512, kernel_s...

9 - AdaptiveAvgPool2d(output_size=(1, 1))...

10 - Linear(in_features=512, out_features=1000, bias=True)...

The create_body function, which is called by create_cnn_model (itself called in cnn_learner), strips off the head at the cut index as follows:

...
if   isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
...

In our case, it’ll remove the last two layers of the resnet18 network: the AdaptiveAvgPool2d layer and the fully connected Linear layer.

body = create_body(resnet18, pretrained=False, cut=model_meta[resnet18]['cut'])
pprint_model(body)
================================================================================
Model modules:
================================================================================

1 - Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3),...

2 - BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn...

3 - ReLU(inplace=True)...

4 - MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_m...

5 - Sequential( (0): BasicBlock(  (conv1): Conv2d(64, 64, kernel_siz...

6 - Sequential( (0): BasicBlock(  (conv1): Conv2d(64, 128, kernel_si...

7 - Sequential( (0): BasicBlock(  (conv1): Conv2d(128, 256, kernel_s...

8 - Sequential( (0): BasicBlock(  (conv1): Conv2d(256, 512, kernel_s...
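The slicing semantics are those of an ordinary Python list. As a sketch, the list below uses short labels for the ten top-level resnet18 modules printed earlier (the labels are placeholders, not real modules), and `cut_children` is a hypothetical helper that mimics only the integer branch of create_body.

```python
# Placeholder labels for the ten top-level children of resnet18.
children = ["conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
            "avgpool", "fc"]

def cut_children(children, cut):
    """Keep everything before `cut`, like `list(model.children())[:cut]`."""
    if not isinstance(cut, int):
        raise TypeError("this sketch only handles integer cut values")
    return children[:cut]

body = cut_children(children, -2)
print(body)            # the eight modules kept as the body
print(children[-2:])   # the two modules that were stripped off
```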

Similarly, we need to determine the cut index for the custom model we use. Let’s try the EfficientNet-B0 architecture available in the torchvision library. First, we inspect the network layers to find out where to split it into backbone and head.

from torchvision.models import efficientnet_b0

m = efficientnet_b0()
pprint_model(m)
================================================================================
Model modules:
================================================================================

1 - Sequential( (0): ConvNormActivation(  (0): Conv2d(3, 32, kernel_...

2 - AdaptiveAvgPool2d(output_size=1)...

3 - Sequential( (0): Dropout(p=0.2, inplace=True) (1): Linear(in_fea...

As can be seen, the pooling layer is at index -2, which is therefore our cut value. We’ll use default_split as the split function and the ImageNet stats, since the model is pre-trained on ImageNet.

from fastai.vision.learner import default_split
model_meta[efficientnet_b0] = {'cut': -2, 'split': default_split, 'stats': imagenet_stats}
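The cut index can also be located programmatically instead of by eye. The sketch below scans a list of module type names from the end and returns the index of the last pooling layer. The `find_cut` helper and the string-based check are assumptions made to keep the example dependency-free; on a real model you would inspect `model.children()` instead.

```python
# Type names of the three top-level children of efficientnet_b0, as printed above.
child_types = ["Sequential", "AdaptiveAvgPool2d", "Sequential"]

def find_cut(type_names):
    """Return the negative index of the last child whose type name contains 'Pool'."""
    for i in range(len(type_names) - 1, -1, -1):
        if "Pool" in type_names[i]:
            return i - len(type_names)
    raise ValueError("no pooling layer found among the top-level children")

cut = find_cut(child_types)
print(cut)
```

Scanning from the end matters for networks such as resnet18, which also contain an early MaxPool2d that should stay in the body.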

Train and test the model

Now that our custom architecture is registered, we can create a cnn_learner. Let’s first create toy dataloaders to train and test our model.

def label_func(f): 
    return f[0].isupper()

path = untar_data(URLs.PETS)
files = get_image_files(path / "images")
pattern = r'^(.*)_\d+\.jpg$'  # escape the dot so it only matches a literal ".jpg"
dls = ImageDataLoaders.from_name_re(path, files, pattern, item_tfms=Resize(224))

dls.show_batch()
100.00% [811712512/811706944 00:20<00:00]

learn = cnn_learner(
    dls, 
    arch=efficientnet_b0,
    pretrained=True,
    metrics=accuracy, 
).to_fp16()
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-3dd342df.pth

Let’s verify that the body and head are created correctly.

pprint_model(learn.model)
================================================================================
Model modules:
================================================================================

1 - Sequential( (0): Sequential(  (0): ConvNormActivation(   (0): Co...

2 - Sequential( (0): AdaptiveConcatPool2d(  (ap): AdaptiveAvgPool2d(...

Let’s inspect the custom head of the model:

pprint_model(learn.model[-1])
================================================================================
Model modules:
================================================================================

1 - AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp...

2 - Flatten(full=False)...

3 - BatchNorm1d(2560, eps=1e-05, momentum=0.1, affine=True, track_ru...

4 - Dropout(p=0.25, inplace=False)...

5 - Linear(in_features=2560, out_features=512, bias=False)...

6 - ReLU(inplace=True)...

7 - BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_run...

8 - Dropout(p=0.5, inplace=False)...

9 - Linear(in_features=512, out_features=37, bias=False)...

As can be seen, cnn_learner created a new classification head starting with a pooling layer, while keeping the original body from the pre-trained model.
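The 2560 input features of the first BatchNorm1d and Linear layers follow from the pooling layer: AdaptiveConcatPool2d concatenates a max-pooled and an average-pooled copy of the feature map, which doubles the channel count. The EfficientNet-B0 body is assumed here to emit 1280 channels, based on the torchvision implementation whose original classifier takes 1280 features. Below is a dependency-free sketch of the idea, with a hypothetical `concat_pool` helper operating on nested lists rather than tensors.

```python
def concat_pool(feature_map):
    """Global max pool and global average pool per channel, concatenated.

    `feature_map` is a list of channels, each a 2D list of floats.
    """
    flat = [[v for row in channel for v in row] for channel in feature_map]
    max_pooled = [max(vals) for vals in flat]
    avg_pooled = [sum(vals) / len(vals) for vals in flat]
    return max_pooled + avg_pooled

# Two channels of size 2x2 give 2 * 2 = 4 pooled features.
toy_map = [[[1.0, 2.0], [3.0, 4.0]],
           [[0.0, 0.0], [8.0, 0.0]]]
pooled = concat_pool(toy_map)
print(pooled)

# Same arithmetic for the model in this post (1280 body channels assumed).
print(2 * 1280)
```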

Let’s train and test our model on this multi-class classification task (37 pet breeds).

learn.lr_find(start_lr=1e-05, end_lr=1e-1)
SuggestedLRs(valley=0.0014454397605732083)

learn.fine_tune(3, base_lr=2e-3, freeze_epochs=3)
epoch train_loss valid_loss accuracy time
0 2.515583 0.687396 0.801759 01:12
1 1.083641 0.469568 0.852503 01:08
2 0.687156 0.430938 0.866035 01:07
epoch train_loss valid_loss accuracy time
0 0.414339 0.401623 0.876861 01:27
1 0.351286 0.369699 0.872801 01:27
2 0.272294 0.349316 0.881597 01:27
learn.show_results()
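The two result tables correspond to the two phases of fine_tune: first the epochs with the body frozen, then the epochs with the whole model unfrozen. The sketch below lays out that schedule as plain data. It reflects our reading of fastai's fine_tune (halving base_lr for the second phase and using a default lr_mult of 100 for the lowest rate); treat those details as assumptions and check the source of your installed version. `fine_tune_plan` is a hypothetical helper, not a fastai function.

```python
def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """Describe the two training phases as a list of dicts."""
    frozen = {"phase": "frozen body", "epochs": freeze_epochs, "lr": base_lr}
    base_lr = base_lr / 2  # assumed: the rate is halved before unfreezing
    unfrozen = {"phase": "unfrozen", "epochs": epochs,
                "lr": (base_lr / lr_mult, base_lr)}  # (lowest, highest) rate
    return [frozen, unfrozen]

# Same arguments as the fine_tune call in this post.
for phase in fine_tune_plan(3, base_lr=2e-3, freeze_epochs=3):
    print(phase)
```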