In this tutorial, we cover multi-task learning with fastai; specifically, we train a video distortion detection model.
We train our model on a small subset of the VSQuaD dataset [1], which consists of 1332 surveillance videos with 9 types of distortions and 36 reference videos without any distortion. Each video is 10 seconds long. We split each video into 1-second-long clips and sample 5 frames from each clip.
The model takes as input a set of video frames containing a type of distortion and outputs the distortion class and severity level. Since the model makes multiple types of predictions for a given data point, this is a multi-task learning problem.
The subset of the VSQuaD dataset is hosted on Kaggle. We can download it from the Kaggle website or via the Kaggle CLI.
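As a sketch, the download can also be scripted with the kaggle Python package; the dataset slug below is a hypothetical placeholder, so substitute the handle shown on the dataset's Kaggle page.

from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()  # reads credentials from ~/.kaggle/kaggle.json
# 'user/vsquad-subset' is a hypothetical slug; replace it with the real dataset handle
api.dataset_download_files('user/vsquad-subset', path='data', unzip=True)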
The dataset is organized into one folder of frames per video. Each folder has 10 video frames in JPEG format, and the size of a frame is 480x270. The folder name contains the scene, distortion class, and severity level labels. Since there are 5 frames per 1-second clip, we'll group the frames into clips and perform classification for each clip. We split the dataset into train and validation sets by stratifying on the video label. To prevent data leakage, the split is performed before videos are split into clips; in other words, all frames of a video end up in either the train set or the validation set, never both. Then, we create a dataframe where each row corresponds to a second-long video clip and its labels.
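Here is a minimal sketch of how such a dataframe could be built. It assumes a hypothetical folder naming scheme of <scene>_<distortion>_<severity> and uses scikit-learn for the stratified, video-level split; adapt the parsing to the actual folder names.

import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split

def build_clip_df(root, frames_per_clip=5):
    rows = []
    for folder in Path(root).iterdir():
        # hypothetical naming scheme: '<scene>_<distortion>_<severity>'
        scene, distortion, severity = folder.name.split('_')
        frames = sorted(folder.glob('*.jpeg'))
        # group consecutive frames into 1-second clips
        for i in range(0, len(frames), frames_per_clip):
            rows.append(dict(
                video=folder.name,
                frame_paths=frames[i:i + frames_per_clip],
                distortion=distortion,
                severity=severity,
            ))
    return pd.DataFrame(rows)

df = build_clip_df('data/vsquad-subset')
# split at the video level so that all clips of a video land in the same set,
# stratifying on the video label (distortion + severity)
videos = df.drop_duplicates('video')
_, valid_videos = train_test_split(
    videos['video'], test_size=0.2,
    stratify=videos['distortion'] + videos['severity'])
df['is_valid'] = df['video'].isin(valid_videos)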
There are 5 distortion types in this subset of the VSQuaD dataset and 4 severity levels. The reference videos, i.e. the videos with no distortion, have the D0 label for distortion and the S0 label for severity. We'll remove the reference videos from the dataset as they're almost indistinguishable from severity 1 distortions. Hence, there are 5 distortion classes and 4 severity classes to be predicted by the model. Since each video has only one type of distortion, this is a single-label multi-class classification problem for each task.
Distortion types:
Severity levels:
Since the model input is a set of video frames, we will create an ImageTuple type and an ImageTupleBlock, similar to this fastai tutorial.
from fastai.vision.all import *
from fastcore.basics import fastuple

class ImageTuple(fastuple):
    "A tuple of frames that can be displayed side by side"
    @classmethod
    def create(cls, fns):
        return cls(tuple(PILImage.create(fn) for fn in fns))

    def show(self, ctx=None, **kwargs):
        t1 = self[0]
        if isinstance(t1, PILImage):
            return self._show_pil(ctx, **kwargs)
        elif isinstance(t1, Tensor):
            return self._show_tensor(ctx, **kwargs)
        else:
            return ctx

    def _show_pil(self, ctx, **kwargs):
        t1 = np.asarray(self[0])
        # 10px-wide black separator drawn between consecutive frames
        line = np.zeros((t1.shape[0], 10, t1.shape[2]), dtype=np.uint8)
        img = PILImage.create(np.concatenate([x for img in self for x in (np.asarray(img), line)][:-1], axis=1))
        return show_image(img, ctx=ctx, **kwargs)

    def _show_tensor(self, ctx, **kwargs):
        t1 = self[0]
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([x for img in self for x in (img, line)][:-1], dim=2), ctx=ctx, **kwargs)

    @property
    def shape(self):
        t1 = self[0]
        if isinstance(t1, Tensor):
            return t1.shape
        if isinstance(t1, PILImage):
            return np.array(t1).shape
        return t1.shape

def ImageTupleBlock():
    return TransformBlock(type_tfms=ImageTuple.create, batch_tfms=IntToFloatTensor)
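As a quick sanity check (a sketch assuming the df dataframe built above), we can create an ImageTuple from the frame paths of the first clip and display it:

fns = df.iloc[0]['frame_paths']
clip = ImageTuple.create(fns)
clip.show();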
import random

BS = 16
db = DataBlock(
    # one ImageTuple input (n_inp=1) and two categorical targets
    blocks=(ImageTupleBlock, CategoryBlock, CategoryBlock),
    splitter=ColSplitter('is_valid'),
    get_x=ColReader('frame_paths'),
    get_y=[ColReader('distortion'), ColReader('severity')],
    batch_tfms=[Normalize.from_stats(*imagenet_stats)],
    n_inp=1,
)
dls = db.dataloaders(df, bs=BS)
Here are frames for a few video clips.
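They can be rendered with dls.show_batch(); if the default dispatch doesn't lay out ImageTuple inputs nicely, a custom @typedispatch show_batch can be defined, as in the fastai siamese tutorial.

dls.show_batch(max_n=4)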
We use a simple CNN to classify the distortion type and severity level of frames. The model uses a pretrained ResNet18 as its backbone and has a classification head that outputs logits for the distortion and severity classes. Since multiple frames are fed into the model, it averages the logits over all frames.
from fastai.callback.hook import model_sizes

# pick GPU if available; DEVICE is used again below when enabling fp16
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_num_features(module):
    # infer the number of output channels of the body with a dummy forward pass
    return model_sizes(module, size=(300, 300))[-1][1]

class NaiveModel(Module):
    def __init__(self, arch, n_distortion, n_severity, pretrained=True):
        store_attr()
        # TimeDistributed applies the wrapped module to each frame independently
        self.encoder = TimeDistributed(create_body(arch, pretrained=pretrained))
        self.head = TimeDistributed(create_head(get_num_features(self.encoder.module), n_distortion + n_severity))

    def forward(self, x):
        # x is a tuple of frame batches; stack them along a new time dimension
        feature_map = self.encoder(torch.stack(x, dim=1))
        # average the logits over frames, then split into distortion and severity logits
        out = self.head(feature_map).mean(dim=1)
        return [out[:, :self.n_distortion], out[:, self.n_distortion:]]

    @staticmethod
    def splitter(model):
        # parameter groups for freezing and discriminative learning rates
        return [params(model.encoder), params(model.head)]

model = NaiveModel(arch=resnet18, n_distortion=len(dls.vocab[0]), n_severity=len(dls.vocab[1])).to(DEVICE)
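To verify the output shapes, we can run a single batch through the model (a quick sketch; xb and outs are throwaway names):

xb = dls.one_batch()  # (frame tuple, distortion targets, severity targets)
with torch.no_grad():
    outs = model(xb[0])
print([o.shape for o in outs])  # expect [BS, n_distortion] and [BS, n_severity]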
We will use cross-entropy losses for both distortion and severity. Since it's more difficult to detect the severity of a distortion, we set the weight of the severity classification loss higher than that of the distortion loss. fastmtl provides a utility class, CombinedLoss, which computes a weighted sum of the losses for each type of prediction the model makes. It also preserves the decodes and activation functions of fastai loss classes so that raw logits can be decoded.
from fastmtl.loss import CombinedLoss
loss_func = CombinedLoss(CrossEntropyLossFlat(), CrossEntropyLossFlat(), weight=[1.0, 3.0])
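Conceptually, the loss above is equivalent to the following sketch (not the actual fastmtl implementation):

def combined_loss(preds, distortion_targets, severity_targets):
    ce = CrossEntropyLossFlat()
    # weighted sum: severity errors are penalized 3x more than distortion errors
    return 1.0 * ce(preds[0], distortion_targets) + 3.0 * ce(preds[1], severity_targets)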
We will evaluate the model with F1-macro and F1-micro metrics for distortion, and with accuracy for severity. fastai already provides these metrics, but they expect a single prediction tensor and a single target, whereas our model produces a tuple of prediction tensors for distortion and severity. We can work around this by defining a function for each metric we use and making it pick the right prediction, such as:
def distortion_f1_macro(preds, *targets):
    return F1Score(average='macro')(preds[0].argmax(dim=-1), targets[0])

def distortion_f1_micro(preds, *targets):
    return F1Score(average='micro')(preds[0].argmax(dim=-1), targets[0])

def severity_accuracy(preds, *targets):
    return accuracy(preds[1], targets[1])
However, this is repetitive and, more importantly, it converts class-based metrics to functional metrics, which behave differently in fastai. For instance, in distortion_f1_macro, we have to apply argmax ourselves, which wouldn't be the case if we could use F1Score directly.
fastmtl provides a few utilities for metrics to solve these problems. The mtl_metrics function routes model predictions to the provided metrics by their order. For instance, it routes the first output of the model to the metrics in the first argument, i.e. to the F1-macro and F1-micro metrics for distortion. Check out the documentation if you are interested in more details.
from fastmtl.metric import mtl_metrics
distortion_f1_macro = F1Score(average='macro')
distortion_f1_macro.name = 'distortion_f1(macro)'
distortion_f1_micro = F1Score(average='micro')
distortion_f1_micro.name = 'distortion_f1(micro)'
severity_accuracy = accuracy
metrics = mtl_metrics([distortion_f1_macro, distortion_f1_micro], [severity_accuracy])
# give the severity metric a readable name in the training log
try:
    metrics[2].name = 'severity_accuracy'
except AttributeError:
    metrics[2].func.__name__ = 'severity_accuracy'
learn = Learner(
    dls,
    model,
    loss_func=loss_func,
    metrics=metrics,
    splitter=model.splitter,
    cbs=[SaveModelCallback()],
)
if DEVICE.type != 'cpu':
    learn = learn.to_fp16()
learn.freeze()
import warnings
from sklearn.exceptions import UndefinedMetricWarning

with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning, module=r'.*')
    learn.fit_one_cycle(10, 3e-4)
The performance of the model is not great as we only train on a small subset of the original dataset. For the purpose of this tutorial, it is sufficient that the model is learning and each metric is calculated correctly.
Let's evaluate the model on the validation set and visualize the results.
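A minimal way to do that with standard fastai calls (show_results relies on the types' show methods, so it may need a custom dispatch for ImageTuple, similar to show_batch):

learn.validate()      # validation loss and the metrics defined above
learn.show_results()  # a few inputs with predicted and target labels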
Now, let's check model performance per distortion class and per severity level.
import warnings
import matplotlib.pyplot as plt
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report, ConfusionMatrixDisplay

warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning, module=r'.*')

def clf_report(learn, dl, vocabs):
    # probs are raw activations, preds are decoded (argmaxed) class indices
    probs, targets, preds = learn.get_preds(dl=dl, with_decoded=True)
    for vocab, target, pred in zip(vocabs, targets, preds):
        label_indices = list(range(len(vocab)))
        y_true, y_pred = target.cpu().numpy(), pred.cpu().numpy()
        print(classification_report(y_true, y_pred, labels=label_indices, target_names=vocab))
        fig, ax = plt.subplots(figsize=(16, 9))
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, labels=label_indices, display_labels=vocab, ax=ax)
    distortion_f1_macro_val = F1Score(average='macro')(preds[0], targets[0])
    distortion_accuracy_val = accuracy(probs[0], targets[0])
    severity_accuracy_val = accuracy(probs[1], targets[1])
    return dict(
        distortion_f1_macro_val=distortion_f1_macro_val,
        distortion_accuracy_val=distortion_accuracy_val,
        severity_accuracy_val=severity_accuracy_val,
    )

scores = clf_report(learn, dls.valid, dls.vocab)