Assume that a multi-task learning model produces two outputs:
- The logits for multi-class single-label classification, for which we want to use cross-entropy loss and softmax activation
- A single logit for binary classification, for which we want to use binary cross-entropy loss and sigmoid activation
CombinedLoss enables applying the corresponding loss function and its activation function to each model output.
from fastai.vision.all import *
ce = CrossEntropyLossFlat()
bce = BCEWithLogitsLossFlat()
comb_loss = CombinedLoss.from_one_to_one_routing(ce, bce)
bs = 8
target1, output1 = torch.randint(0, 5, (bs,)), torch.randn(bs, 5) # 5 classes
target2, output2 = torch.randint(0, 2, (bs,), dtype=float), torch.randn(bs)*10
actual = comb_loss((output1, output2), target1, target2)
loss1 = ce(output1, target1)
loss2 = bce(output2, target2)
expected = loss1 + loss2
test_close(expected, actual)
# activations
actual_acts_output1, actual_acts_output2 = comb_loss.activation([output1, output2])
expected_acts_output1, expected_acts_output2 = ce.activation(output1), bce.activation(output2)
test_close(expected_acts_output1, actual_acts_output1)
test_eq(expected_acts_output2, actual_acts_output2)
# decoding
actual_decoded_output1, actual_decoded_output2 = comb_loss.decodes([output1, output2])
expected_decoded_output1, expected_decoded_output2 = ce.decodes(output1), bce.decodes(output2)
test_close(expected_decoded_output1, actual_decoded_output1)
test_eq(expected_decoded_output2, actual_decoded_output2)
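The one-to-one routing verified above amounts to zipping losses with outputs and targets, applying each loss to its own pair, and summing the results. Here is a hypothetical minimal sketch of that idea in plain Python (the class and loss names are illustrative, not fastai's actual implementation):

```python
class SimpleCombinedLoss:
    "Illustrative one-to-one routing: loss i is applied to output i and target i."
    def __init__(self, *losses):
        self.losses = losses

    def __call__(self, outputs, *targets):
        # Pair each loss with its own (output, target) and sum the per-output losses
        return sum(loss(out, targ) for loss, out, targ in zip(self.losses, outputs, targets))

# Toy per-output losses standing in for cross-entropy and BCE
mse = lambda o, t: (o - t) ** 2
mae = lambda o, t: abs(o - t)

comb = SimpleCombinedLoss(mse, mae)
comb((3.0, 1.0), 2.0, 4.0)  # mse(3, 2) + mae(1, 4) = 1.0 + 3.0 = 4.0
```

The real CombinedLoss additionally routes activation and decodes calls per output, but the loss computation follows this summing pattern.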
Here are raw model outputs (logits):
[output1, output2]
When the underlying losses support it, CombinedLoss can also decode the raw model outputs and compute activations. For instance, let's decode the logits to class indices and binary labels.
comb_loss.decodes([output1, output2])
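Conceptually, decoding a cross-entropy output takes the argmax over the class logits, while decoding a binary cross-entropy output applies the sigmoid and thresholds at 0.5. A minimal sketch in plain Python (function names and the threshold default are illustrative assumptions, not fastai's API):

```python
import math

def ce_decode(logits):
    "Decode multi-class logits to the index of the largest logit (argmax)."
    return max(range(len(logits)), key=lambda i: logits[i])

def bce_decode(logit, thresh=0.5):
    "Decode a binary logit: sigmoid, then threshold."
    prob = 1 / (1 + math.exp(-logit))
    return prob > thresh

ce_decode([0.1, 2.0, -1.0])  # -> 1 (second class has the largest logit)
bce_decode(3.0)              # -> True (sigmoid(3.0) ≈ 0.95 > 0.5)
```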
Similarly, here are the activations for each model output.
comb_loss.activation([output1, output2])
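The activations paired with these losses are softmax for cross-entropy and sigmoid for binary cross-entropy. A self-contained sketch of both in plain Python (with the usual max-subtraction trick for numerical stability; these helpers are illustrative, not fastai internals):

```python
import math

def softmax(logits):
    "Softmax over a list of logits; subtract the max for numerical stability."
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x):
    "Map a single logit to a probability in (0, 1)."
    return 1 / (1 + math.exp(-x))

softmax([1.0, 2.0, 3.0])  # probabilities summing to 1, largest for the largest logit
sigmoid(0.0)              # -> 0.5
```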