class LossRouting[source]

LossRouting(loss_func: Callable, pred_idx: int, target_idx: int, weight: float = 1.0)

class CombinedLoss[source]

CombinedLoss(*loss_routings)

Applies a loss function to each routed model output and sums the results. When the underlying losses support it, it can also decode predictions and compute activations for each output.
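Based on the signatures above, each LossRouting presumably pairs one model output (`pred_idx`) with one target (`target_idx`) and a loss function, scaled by `weight`; a CombinedLoss is then built from such routings. A minimal sketch, assuming those semantics:

# hypothetical explicit construction, inferred from the signatures above
routings = [
    LossRouting(CrossEntropyLossFlat(), pred_idx=0, target_idx=0),
    LossRouting(BCEWithLogitsLossFlat(), pred_idx=1, target_idx=1, weight=0.5),
]
manual_loss = CombinedLoss(*routings)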

Assume that a multi-task learning model produces two outputs:

  1. The logits for multi-class, single-label classification, for which we want to use cross-entropy loss and softmax activation
  2. A single logit for binary classification, for which we want to use binary cross-entropy (with logits) and sigmoid activation

CombinedLoss lets each model output use its corresponding loss function and activation.

from fastai.vision.all import *
from fastcore.test import test_close, test_eq

ce = CrossEntropyLossFlat()    # multi-class loss, softmax activation
bce = BCEWithLogitsLossFlat()  # binary loss, sigmoid activation
comb_loss = CombinedLoss.from_one_to_one_routing(ce, bce)

bs = 8
target1, output1 = torch.randint(0, 5, (bs,)), torch.randn(bs, 5)  # 5 classes
target2, output2 = torch.randint(0, 2, (bs,), dtype=torch.float), torch.randn(bs)*10  # binary target
actual = comb_loss((output1, output2), target1, target2)

# the combined loss equals the sum of the individual losses
loss1 = ce(output1, target1)
loss2 = bce(output2, target2)
expected = loss1 + loss2
test_close(expected, actual)

# activations: each output is passed through its loss's activation
actual_acts_output1, actual_acts_output2 = comb_loss.activation([output1, output2])
expected_acts_output1, expected_acts_output2 = ce.activation(output1), bce.activation(output2)
test_close(expected_acts_output1, actual_acts_output1)
test_eq(expected_acts_output2, actual_acts_output2)

# decoding: each output is decoded by its loss's decodes
actual_decoded_output1, actual_decoded_output2 = comb_loss.decodes([output1, output2])
expected_decoded_output1, expected_decoded_output2 = ce.decodes(output1), bce.decodes(output2)
test_close(expected_decoded_output1, actual_decoded_output1)
test_eq(expected_decoded_output2, actual_decoded_output2)
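`from_one_to_one_routing` presumably pairs each output with the target at the same position. The `weight` argument of LossRouting presumably scales each routed loss before the sum; under that assumption (a sketch, not verified against the implementation), a down-weighted second task could be checked in the same style:

# assumption: `weight` multiplies each routed loss before summing
weighted_loss = CombinedLoss(
    LossRouting(ce,  pred_idx=0, target_idx=0, weight=1.0),
    LossRouting(bce, pred_idx=1, target_idx=1, weight=0.5),
)
test_close(loss1 + 0.5*loss2, weighted_loss((output1, output2), target1, target2))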

Here are the raw model outputs (logits):

[output1, output2]
[tensor([[ 9.6559e-01,  1.2909e+00, -2.0418e-01, -9.8991e-02,  5.9807e-01],
         [-6.1768e-01, -8.8121e-01, -9.5900e-03, -1.4741e+00, -5.2530e-01],
         [ 9.5259e-01,  1.2350e+00, -5.7586e-01, -6.4723e-02, -8.5460e-01],
         [ 1.3948e+00,  6.7017e-01,  2.4812e+00, -2.3243e+00,  4.6702e-01],
         [ 3.4889e-02, -2.5438e-01, -1.0769e+00, -9.6301e-02,  1.1432e+00],
         [-9.2353e-01, -4.6509e-01,  1.2955e+00,  3.1447e-01, -2.5700e+00],
         [ 8.2171e-01, -2.3441e-01, -4.7117e-01,  5.1372e-01,  7.5967e-01],
         [-5.2264e-01,  3.5434e-01,  2.9362e-01,  8.5736e-04, -1.8668e-01]]),
 tensor([-2.1232, 11.6096, -5.7914, -7.5502,  7.1219, -5.3170,  4.3356, 13.7366])]

When the underlying losses support it, CombinedLoss can decode the raw model outputs and compute activations. For instance, let's decode the logits to class label indices and binary predictions.

comb_loss.decodes([output1, output2])
[tensor([1, 2, 1, 2, 4, 2, 0, 1]),
 tensor([False,  True, False, False,  True, False,  True,  True])]
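These match each loss's own decoding: CrossEntropyLossFlat.decodes takes an argmax over the class axis, and BCEWithLogitsLossFlat.decodes thresholds the raw logit (0.5 by default):

decoded1, decoded2 = comb_loss.decodes([output1, output2])
test_eq(decoded1, output1.argmax(dim=-1))  # predicted class index per sample
test_eq(decoded2, output2 > 0.5)           # boolean prediction per sample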

Similarly, here are the activations for each model output.

comb_loss.activation([output1, output2])
[tensor([[0.2679, 0.3710, 0.0832, 0.0924, 0.1855],
         [0.1951, 0.1499, 0.3583, 0.0828, 0.2139],
         [0.3259, 0.4322, 0.0707, 0.1178, 0.0535],
         [0.2054, 0.0995, 0.6088, 0.0050, 0.0812],
         [0.1671, 0.1251, 0.0550, 0.1466, 0.5062],
         [0.0648, 0.1026, 0.5965, 0.2236, 0.0125],
         [0.3033, 0.1055, 0.0832, 0.2229, 0.2851],
         [0.1142, 0.2746, 0.2584, 0.1928, 0.1599]]),
 tensor([1.0686e-01, 9.9999e-01, 3.0443e-03, 5.2573e-04, 9.9919e-01, 4.8834e-03,
         9.8707e-01, 1.0000e+00])]
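These are a softmax over the class axis for the first output and an element-wise sigmoid for the second, matching each loss's activation:

acts1, acts2 = comb_loss.activation([output1, output2])
test_close(acts1, F.softmax(output1, dim=-1))  # rows sum to 1
test_close(acts2, torch.sigmoid(output2))      # independent probabilities

In training, the combined loss can be passed as a Learner's `loss_func`. A minimal sketch, assuming a hypothetical two-head model whose forward returns the tuple of outputs, and a DataLoaders that yields both targets per batch:

class MultiTaskModel(Module):
    "Hypothetical two-head model matching the routing above"
    def __init__(self, n_in=64, n_classes=5):
        self.body  = nn.Linear(n_in, 32)
        self.head1 = nn.Linear(32, n_classes)  # multi-class logits
        self.head2 = nn.Linear(32, 1)          # single binary logit
    def forward(self, x):
        feats = F.relu(self.body(x))
        return self.head1(feats), self.head2(feats).squeeze(-1)

# `dls` is assumed to yield (x, target1, target2) batches
# learn = Learner(dls, MultiTaskModel(), loss_func=comb_loss)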