import pytest
import torch

from mmocr.models.common.losses import DiceLoss
from mmocr.models.textrecog.losses import (ABILoss, CELoss, CTCLoss, SARLoss,
                                           TFLoss)


def test_ctc_loss():
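    # Invalid constructor arguments should raise; a forward pass on (N, T, C)
    # logits with flattened targets should yield a 'loss_ctc' entry.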
    with pytest.raises(AssertionError):
        CTCLoss(flatten='flatten')
    with pytest.raises(AssertionError):
        CTCLoss(blank=None)
    with pytest.raises(AssertionError):
        CTCLoss(reduction=1)
    with pytest.raises(AssertionError):
        CTCLoss(zero_infinity='zero')

    ctc_loss = CTCLoss()
    outputs = torch.zeros(2, 40, 37)
    targets_dict = {
        'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
        'target_lengths': torch.LongTensor([2, 3])
    }

    losses = ctc_loss(outputs, targets_dict)
    assert isinstance(losses, dict)
    assert 'loss_ctc' in losses
    assert torch.allclose(losses['loss_ctc'],
                          torch.tensor(losses['loss_ctc'].item()).float())


def test_ce_loss():
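    # Argument validation, a forward pass with ignore_index=0, and a format()
    # check with ignore_first_char=True (the first target character is dropped).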
    with pytest.raises(AssertionError):
        CELoss(ignore_index='ignore')
    with pytest.raises(AssertionError):
        CELoss(reduction=1)
    with pytest.raises(AssertionError):
        CELoss(reduction='avg')

    ce_loss = CELoss(ignore_index=0)
    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    losses = ce_loss(outputs, targets_dict)
    assert isinstance(losses, dict)
    assert 'loss_ce' in losses
    assert losses['loss_ce'].size(1) == 10

    ce_loss = CELoss(ignore_first_char=True)
    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    new_output, new_target = ce_loss.format(outputs, targets_dict)
    assert new_output.shape == torch.Size([1, 37, 9])
    assert new_target.shape == torch.Size([1, 9])


def test_sar_loss():
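    # SARLoss.format trims one time step: (N, T, C) logits -> (N, C, T-1),
    # padded targets -> (N, T-1).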
    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    sar_loss = SARLoss()
    new_output, new_target = sar_loss.format(outputs, targets_dict)
    assert new_output.shape == torch.Size([1, 37, 9])
    assert new_target.shape == torch.Size([1, 9])


def test_tf_loss():
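    # A non-bool flatten argument should raise; with flatten=False, format()
    # produces the same shapes as the SAR case above.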
    with pytest.raises(AssertionError):
        TFLoss(flatten=1.0)

    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    tf_loss = TFLoss(flatten=False)
    new_output, new_target = tf_loss.format(outputs, targets_dict)
    assert new_output.shape == torch.Size([1, 37, 9])
    assert new_target.shape == torch.Size([1, 9])


def test_dice_loss():
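    # A non-float eps should be rejected; forward should return a tensor both
    # with and without an explicit mask.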
    with pytest.raises(AssertionError):
        DiceLoss(eps='1')

    dice_loss = DiceLoss()
    pred = torch.rand(1, 1, 32, 32)
    gt = torch.rand(1, 1, 32, 32)

    loss = dice_loss(pred, gt, None)
    assert isinstance(loss, torch.Tensor)

    mask = torch.rand(1, 1, 1, 1)
    loss = dice_loss(pred, gt, mask)
    assert isinstance(loss, torch.Tensor)


def test_abi_loss():
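    # ABILoss combines visual (encoder), language (decoder) and fusion branches;
    # it should still run when some branches are absent, but raise once all
    # three have been removed.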
    loss = ABILoss(num_classes=90)
    outputs = dict(
        out_enc=dict(logits=torch.randn(2, 10, 90)),
        out_decs=[
            dict(logits=torch.randn(2, 10, 90)),
            dict(logits=torch.randn(2, 10, 90))
        ],
        out_fusers=[
            dict(logits=torch.randn(2, 10, 90)),
            dict(logits=torch.randn(2, 10, 90))
        ])
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]]),
        'targets':
        [torch.LongTensor([1, 2, 3, 4]),
         torch.LongTensor([1, 2, 3])]
    }
    result = loss(outputs, targets_dict)
    assert isinstance(result, dict)
    assert isinstance(result['loss_visual'], torch.Tensor)
    assert isinstance(result['loss_lang'], torch.Tensor)
    assert isinstance(result['loss_fusion'], torch.Tensor)

    outputs.pop('out_enc')
    loss(outputs, targets_dict)
    outputs.pop('out_decs')
    loss(outputs, targets_dict)
    outputs.pop('out_fusers')
    with pytest.raises(AssertionError):
        loss(outputs, targets_dict)