In this tutorial, we cover multi-task learning with fastai; specifically, we train a video distortion detection model.

We train our model on a small subset of the VSQuaD dataset [1], which consists of 1332 surveillance videos with 9 types of distortion and 36 reference videos without any distortion. Each video is 10 seconds long. We split each video into 1-second clips and sample 5 frames from each clip.

The model takes as input a set of video frames containing one type of distortion and outputs the distortion class and the severity level. Since the model makes multiple types of predictions for a given data point, this is a multi-task learning problem.

 

Data

The subset of the VSQuaD dataset is hosted on Kaggle. We can download it from the Kaggle website or via the Kaggle CLI.

Dataset

The dataset is organized into one folder of frames per video. Each folder has 10 video frames in JPEG format, and each frame is 480x270 pixels. The folder name contains the scene, distortion class and severity level labels. A 1-second video clip has 5 frames, so we split each video's frames per second and classify each clip.

We split the dataset into train and validation sets, stratifying on the video label. To prevent data leakage, the split is performed before the videos are split into clips. In other words, all frames of a video end up in either the train set or the validation set, never in both. Then, we create a dataframe where each row corresponds to a 1-second video clip and its labels.

     video_name                  label     distortion            severity  is_valid  frames
1    airport_D1_1                (D1, S1)  defocus-blur          S1        False     [0, 1, 2, 3, 4]
2    airport_D1_2                (D1, S2)  defocus-blur          S2        False     [0, 1, 2, 3, 4]
3    airport_D1_3                (D1, S3)  defocus-blur          S3        False     [0, 1, 2, 3, 4]
4    airport_D1_4                (D1, S4)  defocus-blur          S4        False     [0, 1, 2, 3, 4]
5    airport_D5_1                (D5, S1)  gaussian-white-noise  S1        False     [0, 1, 2, 3, 4]
...  ...                         ...       ...                   ...       ...       ...
460  vatican_square_night_D7_4   (D7, S4)  smoke                 S4        False     [20, 21, 22, 23, 24]
461  vatican_square_night_D8_1   (D8, S1)  uneven-illumination   S1        False     [20, 21, 22, 23, 24]
462  vatican_square_night_D8_2   (D8, S2)  uneven-illumination   S2        False     [20, 21, 22, 23, 24]
463  vatican_square_night_D8_3   (D8, S3)  uneven-illumination   S3        False     [20, 21, 22, 23, 24]
464  vatican_square_night_D8_4   (D8, S4)  uneven-illumination   S4        False     [20, 21, 22, 23, 24]

2200 rows × 6 columns
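The notebook's own splitting code is not shown, so here is a hypothetical plain-Python sketch of the procedure described above: parse the labels from the folder name, pick validation videos per (distortion, severity) label, and only then expand each video into 5-frame clips, so every clip inherits its video's is_valid flag. The helper names, the 20% validation share (1760/440 matches it) and the n_frames argument are assumptions; the usage example passes 25 frames only because the table shows frame indices up to 24.

```python
import random
from collections import defaultdict

FRAMES_PER_CLIP = 5  # 5 frames per 1-second clip, as in the tutorial

def parse_folder_name(name):
    """'vatican_square_night_D7_4' -> ('vatican_square_night', 'D7', 'S4')."""
    scene, distortion, severity = name.rsplit('_', 2)
    return scene, distortion, 'S' + severity

def split_videos(video_names, valid_pct=0.2, seed=42):
    """Stratified split at video level: pick validation videos separately for each label."""
    by_label = defaultdict(list)
    for name in video_names:
        _, distortion, severity = parse_folder_name(name)
        by_label[(distortion, severity)].append(name)
    rng = random.Random(seed)
    valid = set()
    for names in by_label.values():
        names = sorted(names)
        rng.shuffle(names)
        valid.update(names[:round(len(names) * valid_pct)])
    return valid

def clip_rows(video_names, n_frames, valid_videos):
    """One row per clip; is_valid comes from the video, so no video spans both sets."""
    rows = []
    for name in video_names:
        _, distortion, severity = parse_folder_name(name)
        for start in range(0, n_frames - FRAMES_PER_CLIP + 1, FRAMES_PER_CLIP):
            rows.append({'video_name': name, 'label': (distortion, severity),
                         'is_valid': name in valid_videos,
                         'frames': list(range(start, start + FRAMES_PER_CLIP))})
    return rows

# Hypothetical video names: 10 scenes x 2 severity levels of distortion D1
videos = [f'scene{i}_D1_{s}' for i in range(10) for s in (1, 2)]
valid = split_videos(videos)
rows = clip_rows(videos, n_frames=25, valid_videos=valid)
print(len(valid), len(rows), rows[-1]['frames'])  # 4 100 [20, 21, 22, 23, 24]
```

Splitting the video names first and expanding to clips second is what keeps near-identical neighbouring clips of one video from appearing in both sets.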

There are 6 distortion types and 4 severity levels in this subset of the VSQuaD dataset. The reference videos, i.e. the videos without distortion, have the label D0 for distortion and S0 for severity. We remove the reference videos from the dataset, as they are almost indistinguishable from severity 1 distortions. Hence, the model has to predict 6 distortion classes and 4 severity classes. Since each video has only one type of distortion, each task is a single-label multi-class classification problem.

Size of train set:	 1760
Size of validation set:	 440

Distortion types:

(#6) ['defocus-blur','gaussian-white-noise','smoke','uneven-illumination','haze','rain']

Severity levels:

(#4) ['S1','S2','S3','S4']

Dataloaders

Since the model input is a set of video frames, we create an ImageTuple type and an ImageTupleBlock, similar to the ImageTuple used in fastai's Siamese tutorial.

from fastai.vision.all import *
from fastcore.basics import fastuple

class ImageTuple(fastuple):
    @classmethod
    def create(cls, fns):
        return cls(tuple(PILImage.create(fn) for fn in fns))

    def show(self, ctx=None, **kwargs):
        t1 = self[0]
        if isinstance(t1, PILImage):
            return self._show_pil(ctx, **kwargs)
        elif isinstance(t1, Tensor):
            return self._show_tensor(ctx, **kwargs)
        else:
            return ctx

    def _show_pil(self, ctx, **kwargs):
        t1 = np.asarray(self[0])
        line = np.zeros((t1.shape[0], 10, t1.shape[2]), dtype=np.uint8)
        img = PILImage.create(np.concatenate([x for img in self for x in (np.asarray(img), line)][:-1], axis=1))
        return show_image(img, ctx=ctx, **kwargs)

    def _show_tensor(self, ctx, **kwargs):
        t1 = self[0]
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([x for img in self for x in (img, line)][:-1], dim=2), ctx=ctx, **kwargs)

    @property
    def shape(self):
        t1 = self[0]
        if isinstance(t1, Tensor):
            return t1.shape
        if isinstance(t1, PILImage):
            return np.array(t1).shape
        return t1.shape

def ImageTupleBlock():
    return TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)
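
BS = 16

# Assumes `df` also holds a `frame_paths` column with the file paths of each clip's frames
db = DataBlock(
    blocks=(ImageTupleBlock, CategoryBlock, CategoryBlock),  # input: frames; targets: distortion, severity
    splitter=ColSplitter('is_valid'),                        # video-level split computed earlier
    get_x=ColReader('frame_paths'),
    get_y=[ColReader('distortion'), ColReader('severity')],  # one getter per target
    batch_tfms=[
        Normalize.from_stats(*imagenet_stats)                 # match the pretrained backbone's statistics
    ],
    n_inp=1                                                   # first block is the input, the rest are targets
)
dls = db.dataloaders(df, bs=BS)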
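Conceptually, the DataBlock turns each dataframe row into one sample made of n_inp=1 input followed by two targets, each encoded against its own vocabulary. The plain-Python sketch below mimics that mapping; it is not fastai code, the function names are hypothetical, and the frame paths are made-up placeholders. CategoryBlock sorts its vocabulary by default, which is why the classification report in the Evaluation section lists the distortion classes alphabetically rather than in the order printed earlier.

```python
def make_vocab(labels):
    """Sorted, de-duplicated label list (CategoryBlock sorts its vocab by default)."""
    return sorted(set(labels))

def row_to_sample(row, distortion_vocab, severity_vocab):
    """One dataframe row -> (input, distortion target, severity target): n_inp=1 plus two targets."""
    x = tuple(row['frame_paths'])                             # input: the clip's frame files
    y_distortion = distortion_vocab.index(row['distortion'])  # target 1 as a class index
    y_severity = severity_vocab.index(row['severity'])        # target 2 as a class index
    return x, y_distortion, y_severity

distortion_vocab = make_vocab(['defocus-blur', 'gaussian-white-noise', 'smoke',
                               'uneven-illumination', 'haze', 'rain'])
severity_vocab = make_vocab(['S1', 'S2', 'S3', 'S4'])

# Hypothetical row; the frame paths are placeholders
row = {'frame_paths': [f'airport_D7_2/{i}.jpg' for i in range(5)],
       'distortion': 'smoke', 'severity': 'S2'}
x, y_d, y_s = row_to_sample(row, distortion_vocab, severity_vocab)
print(distortion_vocab)  # ['defocus-blur', 'gaussian-white-noise', 'haze', 'rain', 'smoke', 'uneven-illumination']
print(y_d, y_s)          # 4 1
```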

Here are frames for a few video clips.

Training

We use a simple CNN to classify the distortion and severity level of the frames. The model uses a pretrained ResNet18 as its backbone and a classification head that outputs logits for every distortion class and severity level. Since each clip consists of multiple frames, the model applies the backbone and head to every frame and averages the logits over all frames.

from fastai.callback.hook import model_sizes

def get_num_features(module):
    return model_sizes(module, size=(300, 300))[-1][1]

class NaiveModel(Module):
    def __init__(self, arch, n_distortion, n_severity, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(create_body(arch, pretrained=pretrained))
        self.head = TimeDistributed(create_head(get_num_features(self.encoder.module), n_distortion + n_severity))
    
    def forward(self, x):
        feature_map = self.encoder(torch.stack(x, dim=1))
        out = self.head(feature_map).mean(dim=1)
        return [out[:, :self.n_distortion], out[:, self.n_distortion:]]
   
    @staticmethod
    def splitter(model): 
        return [params(model.encoder), params(model.head)]

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # also used below to enable mixed precision
model = NaiveModel(arch=resnet18, n_distortion=len(dls.vocab[0]), n_severity=len(dls.vocab[1])).to(DEVICE)
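The frame handling in NaiveModel can be illustrated without fastai. In this NumPy sketch (hypothetical function names, random inputs), a single linear map stands in for the backbone plus head: the time axis is folded into the batch axis, the per-frame function is applied, the result is unfolded, the logits are averaged over frames, and the output vector is split between the two tasks at n_distortion, as forward does.

```python
import numpy as np

def time_distributed(fn, x):
    """Apply a per-frame function to a (batch, time, ...) array, in the spirit of fastai's TimeDistributed."""
    bs, t = x.shape[:2]
    y = fn(x.reshape(bs * t, *x.shape[2:]))   # fold time into the batch axis
    return y.reshape(bs, t, *y.shape[1:])     # unfold back to (batch, time, ...)

def naive_forward(frame_features, head_weight, n_distortion):
    """frame_features: (bs, T, n_features); head_weight: (n_features, n_distortion + n_severity)."""
    logits = time_distributed(lambda f: f @ head_weight, frame_features)  # (bs, T, n_out)
    clip_logits = logits.mean(axis=1)                                     # average over frames
    return clip_logits[:, :n_distortion], clip_logits[:, n_distortion:]

# Batch of 16 clips, 5 frames each, 512 features per frame (ResNet18's final channel count)
features = np.random.rand(16, 5, 512)
weight = np.random.rand(512, 6 + 4)
distortion_logits, severity_logits = naive_forward(features, weight, n_distortion=6)
print(distortion_logits.shape, severity_logits.shape)  # (16, 6) (16, 4)
```

Producing one concatenated logit vector and slicing it keeps a single head; separate heads per task would work equally well.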

We use cross-entropy loss for both distortion and severity. Since detecting the severity of a distortion is harder, we give the severity loss a higher weight than the distortion loss. fastmtl provides a utility class, CombinedLoss, which computes a weighted sum of the losses for each type of prediction the model makes. It also preserves the decodes and activation functions of the fastai loss classes, so that raw logits can be decoded.

from fastmtl.loss import CombinedLoss
loss_func = CombinedLoss(CrossEntropyLossFlat(), CrossEntropyLossFlat(), weight=[1.0, 3.0])
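To make the weighting concrete, here is a framework-free sketch of a weighted sum of two mean cross-entropy losses; the function names are hypothetical and this is not fastmtl's implementation. With all-zero logits, each cross-entropy equals the log of its number of classes, so the combined loss of an untrained model is roughly log 6 + 3 log 4 ≈ 5.95, a useful reference point for the first epoch of the training log.

```python
import math

def cross_entropy(logits, targets):
    """Mean cross-entropy over a batch. logits: list of score rows; targets: list of class indices."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[t]  # -log softmax(row)[t]
    return total / len(targets)

def weighted_multitask_loss(preds, targets, weights=(1.0, 3.0)):
    """Weighted sum of per-task losses: sum_i weights[i] * CE(preds[i], targets[i])."""
    return sum(w * cross_entropy(p, t) for w, p, t in zip(weights, preds, targets))

# Untrained-model reference point: all-zero logits for a batch of two clips
preds = ([[0.0] * 6, [0.0] * 6], [[0.0] * 4, [0.0] * 4])  # 6 distortion, 4 severity classes
targets = ([0, 3], [1, 2])
print(round(weighted_multitask_loss(preds, targets), 2))  # 5.95
```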

We evaluate the model with F1-macro and F1-micro for distortion, and accuracy for severity. fastai already provides these metrics, but they expect a single prediction tensor and a single target, whereas our model produces a tuple of prediction tensors, one for distortion and one for severity. We can work around this by defining a function for each metric that picks the right prediction, for example:

def distortion_f1_macro(preds, *targets):
    return F1Score(average='macro')(preds[0].argmax(dim=-1), targets[0])

def distortion_f1_micro(preds, *targets):
    return F1Score(average='micro')(preds[0].argmax(dim=-1), targets[0])

def severity_accuracy(preds, *targets):
    return accuracy(preds[1], targets[1])

However, this is repetitive and, more importantly, it turns class-based metrics into functional metrics, which behave differently in fastai. For instance, in distortion_f1_macro we have to apply argmax ourselves, which would not be necessary if we could use F1Score directly.
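One concrete difference: fastai wraps a plain function in AvgMetric, which averages per-batch values weighted by batch size, whereas a class-based metric such as F1Score accumulates all predictions and computes the score once. The framework-free sketch below (hypothetical helper names, toy data) shows that the two can disagree for macro F1, while for accuracy they coincide.

```python
def f1_macro(y_true, y_pred):
    """Macro-averaged F1 over the classes present in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        scores.append(2 * tp / (2 * tp + fp + fn))
    return sum(scores) / len(scores)

def batch_averaged(metric, batches):
    """Per-batch values averaged with batch-size weights (how fastai treats plain functions)."""
    total = sum(metric(t, p) * len(t) for t, p in batches)
    return total / sum(len(t) for t, _ in batches)

def accumulated(metric, batches):
    """Metric computed once on all predictions (how class-based metrics like F1Score work)."""
    y_true = [y for t, _ in batches for y in t]
    y_pred = [y for _, p in batches for y in p]
    return metric(y_true, y_pred)

batches = [([0, 0, 1, 1], [0, 0, 1, 0]),  # (targets, predictions) for batch 1
           ([2, 2, 0, 1], [2, 1, 0, 1])]  # batch 2
print(f'{batch_averaged(f1_macro, batches):.3f} {accumulated(f1_macro, batches):.3f}')  # 0.756 0.730
```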

fastmtl provides a few metric utilities to solve these problems. The mtl_metrics function routes model predictions to metrics by their order: the first output of the model (distortion) goes to the metrics in the first argument, i.e. the F1-macro and F1-micro metrics, and the second output (severity) goes to the metrics in the second argument. Check out the documentation if you are interested in more details.

from fastmtl.metric import mtl_metrics

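# Class-based fastai metrics, renamed so the training log shows which task they belong to
distortion_f1_macro = F1Score(average='macro')
distortion_f1_macro.name = 'distortion_f1(macro)'
distortion_f1_micro = F1Score(average='micro')
distortion_f1_micro.name = 'distortion_f1(micro)'

severity_accuracy = accuracy

# First list -> first model output (distortion), second list -> second output (severity)
metrics = mtl_metrics([distortion_f1_macro, distortion_f1_micro], [severity_accuracy])
# Give the third metric (severity accuracy) a readable name in the training log
try:
    metrics[2].name = 'severity_accuracy'
except AttributeError:
    metrics[2].func.__name__ = 'severity_accuracy'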
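The routing idea behind mtl_metrics can be pictured with a small hypothetical helper: each metric is wrapped so that it only sees the prediction and target of its own task, selected by the position of its group. This is an illustrative sketch, not fastmtl's code; unlike fastmtl, it only handles plain functions, not fastai's class-based metrics such as F1Score.

```python
def route(metric, task_index):
    """Hypothetical helper: wrap a single-task metric so it only sees one task's preds/targets."""
    def routed(preds, *targets):
        return metric(preds[task_index], targets[task_index])
    routed.__name__ = getattr(metric, '__name__', 'metric') + f'_task{task_index}'
    return routed

def route_metrics(*metric_groups):
    """Hypothetical analogue of mtl_metrics: the i-th group of metrics receives the i-th model output."""
    return [route(m, i) for i, group in enumerate(metric_groups) for m in group]

def top1_accuracy(preds, target):
    """preds: list of score rows; target: list of class indices."""
    hits = [max(range(len(row)), key=row.__getitem__) == t for row, t in zip(preds, target)]
    return sum(hits) / len(hits)

preds = ([[0.9, 0.1], [0.2, 0.8]],                   # task 0 (e.g. distortion) scores
         [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])         # task 1 (e.g. severity) scores
targets = ([0, 1], [1, 2])
routed_metrics = route_metrics([top1_accuracy], [top1_accuracy])
print([m(preds, *targets) for m in routed_metrics])  # [1.0, 0.5]
```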

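learn = Learner(
    dls,
    model,
    loss_func=loss_func,
    metrics=metrics,
    splitter=model.splitter,    # two parameter groups: encoder and head
    cbs=[SaveModelCallback()],  # saves the weights with the lowest valid_loss and reloads them after training
)
if DEVICE.type != 'cpu':
    learn = learn.to_fp16()     # mixed precision only when training on a GPU
learn.freeze()                  # train only the last parameter group, i.e. the head

import warnings
from sklearn.exceptions import UndefinedMetricWarning

# Silence sklearn warnings about ill-defined F1 when a class receives no predictions
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning, module=r'.*')
    learn.fit_one_cycle(10, 3e-4)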
epoch train_loss valid_loss distortion_f1(macro) distortion_f1(micro) severity_accuracy time
0 5.480896 5.333536 0.544070 0.588636 0.281818 00:50
1 4.033375 4.581513 0.802966 0.829545 0.370455 00:47
2 3.128713 4.275926 0.851705 0.868182 0.377273 00:47
3 2.537734 4.253480 0.856044 0.868182 0.370455 00:47
4 2.266960 3.817466 0.886277 0.897727 0.413636 00:48
5 2.019434 3.729045 0.901953 0.915909 0.406818 00:48
6 1.798262 3.675821 0.899470 0.911364 0.431818 00:47
7 1.673086 3.864920 0.904571 0.920455 0.406818 00:48
8 1.546667 3.702503 0.903532 0.920455 0.438636 00:47
9 1.550261 3.732250 0.905340 0.920455 0.447727 00:48
Better model found at epoch 0 with valid_loss value: 5.333535671234131.
Better model found at epoch 1 with valid_loss value: 4.581512928009033.
Better model found at epoch 2 with valid_loss value: 4.275925636291504.
Better model found at epoch 3 with valid_loss value: 4.253479957580566.
Better model found at epoch 4 with valid_loss value: 3.8174662590026855.
Better model found at epoch 5 with valid_loss value: 3.7290451526641846.
Better model found at epoch 6 with valid_loss value: 3.6758205890655518.

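The model's performance is not great, as we only train on a small subset of the original dataset: distortion F1-macro reaches about 0.90, but severity accuracy stays between roughly 0.41 and 0.45 over the last epochs. For the purpose of this tutorial, it is sufficient that the model is learning and that each metric is calculated correctly.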

Evaluation

Let's evaluate the model on the validation set and visualize the results.

Now, let's check model performance per distortion class and per severity level.

import warnings
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report, ConfusionMatrixDisplay

warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning, module=r'.*')

def clf_report(learn, dl, vocabs):
    # probs: activated outputs, preds: decoded class indices; each holds one tensor per task
    probs, targets, preds = learn.get_preds(dl=dl, with_decoded=True)

    for vocab, target, pred in zip(vocabs, targets, preds):
        label_indices = list(range(len(vocab)))
        y_true, y_pred = target.cpu().numpy(), pred.cpu().numpy()
        print(classification_report(y_true, y_pred, labels=label_indices, target_names=vocab))

        # One confusion matrix per task
        fig, ax = plt.subplots(figsize=(16, 9))
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, labels=label_indices, display_labels=vocab, ax=ax)

    # F1Score receives decoded indices here; accuracy applies argmax to the probabilities itself
    distortion_f1_macro_val = F1Score(average='macro')(preds[0], targets[0])
    distortion_accuracy_val = accuracy(probs[0], targets[0])
    severity_accuracy_val = accuracy(probs[1], targets[1])
    return dict(
        distortion_f1_macro_val=distortion_f1_macro_val,
        distortion_accuracy_val=distortion_accuracy_val,
        severity_accuracy_val=severity_accuracy_val,
    )

scores = clf_report(learn, dls.valid, dls.vocab)
                      precision    recall  f1-score   support

        defocus-blur       0.83      0.85      0.84       100
gaussian-white-noise       1.00      0.98      0.99       100
                haze       1.00      1.00      1.00        50
                rain       0.82      0.82      0.82        50
               smoke       1.00      0.93      0.96       100
 uneven-illumination       0.72      0.85      0.78        40

            accuracy                           0.91       440
           macro avg       0.90      0.90      0.90       440
        weighted avg       0.92      0.91      0.91       440

              precision    recall  f1-score   support

          S1       0.46      0.59      0.52       110
          S2       0.21      0.18      0.20       105
          S3       0.41      0.28      0.33       115
          S4       0.56      0.67      0.61       110

    accuracy                           0.43       440
   macro avg       0.41      0.43      0.41       440
weighted avg       0.41      0.43      0.42       440
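The summary rows of the distortion report follow from the per-class rows: the macro average is the unweighted mean of the per-class F1 scores, and the weighted average weights each class by its support. The short check below recomputes both from the rounded values printed above, so it only matches up to rounding. Micro-averaged F1, used as distortion_f1(micro) during training, is not listed separately: in a single-label multi-class problem every misclassification counts as exactly one false positive and one false negative, so micro F1 equals accuracy, which is why the 0.91 accuracy row matches the micro F1 of the best epoch.

```python
# Per-class F1 scores and supports, copied from the distortion report above
f1 = {'defocus-blur': 0.84, 'gaussian-white-noise': 0.99, 'haze': 1.00,
      'rain': 0.82, 'smoke': 0.96, 'uneven-illumination': 0.78}
support = {'defocus-blur': 100, 'gaussian-white-noise': 100, 'haze': 50,
           'rain': 50, 'smoke': 100, 'uneven-illumination': 40}

macro_f1 = sum(f1.values()) / len(f1)                                      # every class counts equally
weighted_f1 = sum(f1[c] * support[c] for c in f1) / sum(support.values())  # classes weighted by support
print(f'{macro_f1:.2f} {weighted_f1:.2f}')  # 0.90 0.91
```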

Conclusion

Multi-task learning has become increasingly common in ML research and production, since it can lead to more capable and better-generalizing models. In this tutorial, we introduced the fastmtl library, which offers utilities that make multi-task learning with fastai easier.

References

  1. Zohaib Amjad Khan et al., “Video distortion detection and classification in the context of video surveillance,” in International Conference on Image Processing (ICIP) Grand Challenge Session, 2022.