|
|
|
import math |
|
|
|
import pytest |
|
import torch |
|
|
|
from mmocr.models.textrecog.decoders import (ABILanguageDecoder, |
|
ABIVisionDecoder, BaseDecoder, |
|
NRTRDecoder, ParallelSARDecoder, |
|
ParallelSARDecoderWithBS, |
|
SequentialSARDecoder) |
|
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode |
|
|
|
|
|
def _create_dummy_input():
    """Build one set of dummy inputs shared by the decoder tests.

    Returns:
        tuple: ``(feat, out_enc, tgt_dict, img_metas)`` where ``feat`` is a
        random tensor of shape (1, 512, 4, 40), ``out_enc`` is a random
        tensor of shape (1, 512), ``tgt_dict`` maps ``'padded_targets'`` to
        a (1, 5) long tensor, and ``img_metas`` is a one-element list whose
        dict sets ``'valid_ratio'`` to 1.0.
    """
    # Draw the random tensors first, in the same order as before, so the
    # global RNG stream is consumed identically.
    feature_map = torch.rand((1, 512, 4, 40))
    encoder_out = torch.rand((1, 512))

    padded = torch.tensor([[1, 1, 1, 1, 36]], dtype=torch.long)
    targets = dict(padded_targets=padded)
    metas = [dict(valid_ratio=1.0)]

    return feature_map, encoder_out, targets, metas
|
|
|
|
|
def test_base_decoder():
    """Check that BaseDecoder leaves both forward passes unimplemented.

    ``forward_train`` (called with four ``None`` arguments) and
    ``forward_test`` (called with three) are each expected to raise
    ``NotImplementedError``.
    """
    base = BaseDecoder()

    # (method name, number of None placeholders passed to it)
    unimplemented = (('forward_train', 4), ('forward_test', 3))
    for method_name, num_args in unimplemented:
        with pytest.raises(NotImplementedError):
            getattr(base, method_name)(*([None] * num_args))
|
|
|
|
|
def test_parallel_sar_decoder():
    """Exercise ParallelSARDecoder on the shared single-sample dummy input.

    Verifies that an empty or a doubled ``img_metas`` list triggers an
    ``AssertionError``, and that the decoder output has shape (1, 5, 36)
    with the trailing flag set to ``True`` and then to ``False``.
    """
    sar_decoder = ParallelSARDecoder(
        num_classes=37, padding_idx=36, max_seq_len=5)
    sar_decoder.init_weights()
    sar_decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()

    # Presumably rejected because the list length does not match the
    # one-sample batch -- confirm against the decoder's assertions.
    for bad_metas in ([], img_metas * 2):
        with pytest.raises(AssertionError):
            sar_decoder(feat, out_enc, tgt_dict, bad_metas, True)

    expected_shape = torch.Size([1, 5, 36])
    # NOTE(review): the last positional argument looks like the train-mode
    # switch -- confirm against the decoder signature.
    for flag in (True, False):
        output = sar_decoder(feat, out_enc, tgt_dict, img_metas, flag)
        assert output.shape == expected_shape
|
|
|
|
|
def test_sequential_sar_decoder():
    """Exercise SequentialSARDecoder on the shared single-sample dummy input.

    Verifies that an empty or a doubled ``img_metas`` list triggers an
    ``AssertionError`` (calls made without the trailing flag), and that the
    output has shape (1, 5, 36) with the trailing flag set to ``True`` and
    then to ``False``.
    """
    decoder_kwargs = dict(num_classes=37, padding_idx=36, max_seq_len=5)
    seq_decoder = SequentialSARDecoder(**decoder_kwargs)
    seq_decoder.init_weights()
    seq_decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()

    # Presumably rejected because the list length does not match the
    # one-sample batch -- confirm against the decoder's assertions.
    for bad_metas in ([], img_metas * 2):
        with pytest.raises(AssertionError):
            seq_decoder(feat, out_enc, tgt_dict, bad_metas)

    expected_shape = torch.Size([1, 5, 36])
    # NOTE(review): the last positional argument looks like the train-mode
    # switch -- confirm against the decoder signature.
    for flag in (True, False):
        output = seq_decoder(feat, out_enc, tgt_dict, img_metas, flag)
        assert output.shape == expected_shape
|
|
|
|
|
def test_parallel_sar_decoder_with_beam_search():
    """Exercise ParallelSARDecoderWithBS and its DecodeNode helper.

    Covers, in order:
      * construction with ``beam_width='beam'`` or ``beam_width=0`` raising
        ``AssertionError``;
      * an empty or doubled ``img_metas`` list raising ``AssertionError``;
      * the output shape (1, 5, 36) when called with ``train_mode=False``;
      * ``DecodeNode`` rejecting three invalid argument pairs and
        ``eval()`` returning a value close to 1.5 for scores 0.7 and 0.8.
    """
    for invalid_width in ('beam', 0):
        with pytest.raises(AssertionError):
            ParallelSARDecoderWithBS(beam_width=invalid_width)

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    bs_decoder = ParallelSARDecoderWithBS(
        beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
    bs_decoder.init_weights()
    bs_decoder.train()

    # Presumably rejected because the list length does not match the
    # one-sample batch -- confirm against the decoder's assertions.
    for bad_metas in ([], img_metas * 2):
        with pytest.raises(AssertionError):
            bs_decoder(feat, out_enc, tgt_dict, bad_metas)

    prediction = bs_decoder(
        feat, out_enc, tgt_dict, img_metas, train_mode=False)
    assert prediction.shape == torch.Size([1, 5, 36])

    # Invalid DecodeNode argument pairs: non-list values, string entries in
    # the second list, and lists of different lengths.
    invalid_node_args = (
        (1, 1),
        ([1, 2], ['4', '3']),
        ([1, 2], [0.5]),
    )
    for first_arg, second_arg in invalid_node_args:
        with pytest.raises(AssertionError):
            DecodeNode(first_arg, second_arg)

    # 1.5 == 0.7 + 0.8; presumably eval() sums the second argument's
    # entries -- confirm against DecodeNode.
    node = DecodeNode([1, 2], [0.7, 0.8])
    node_value = node.eval()
    assert math.isclose(node_value, 1.5)
|
|
|
|
|
def test_transformer_decoder():
    """Check NRTRDecoder output shapes for a single dummy sample.

    The decoder is called with ``None`` as the feature argument and a
    random (1, 25, 512) encoder output. The result is expected to have
    shape (1, 5, 36) with the trailing flag set to ``True`` and then to
    ``False``.
    """
    decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    out_enc = torch.rand(1, 25, 512)
    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
    img_metas = [{'valid_ratio': 1.0}]
    # The former ``tgt_dict['padded_targets'] = tgt_dict['padded_targets']``
    # self-assignment was a no-op and has been dropped.

    # NOTE(review): the last positional argument looks like the train-mode
    # switch -- confirm against the decoder signature.
    out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])
|
|
|
|
|
def test_abi_language_decoder():
    """Check the output shapes of ABILanguageDecoder on random logits.

    A (2, 25, 90) random tensor is passed as ``out_enc``; the returned
    mapping is expected to hold a (2, 25, 512) ``'feature'`` tensor and a
    (2, 25, 90) ``'logits'`` tensor.
    """
    lang_decoder = ABILanguageDecoder(max_seq_len=25)
    dummy_logits = torch.randn(2, 25, 90)

    outputs = lang_decoder(
        feat=None, out_enc=dummy_logits, targets_dict=None, img_metas=None)

    expected_shapes = (
        ('feature', torch.Size([2, 25, 512])),
        ('logits', torch.Size([2, 25, 90])),
    )
    for key, shape in expected_shapes:
        assert outputs[key].shape == shape
|
|
|
|
|
def test_abi_vision_decoder():
    """Check the output shapes of ABIVisionDecoder on a random feature map.

    A (2, 128, 8, 32) random tensor is fed to a decoder built with
    ``in_channels=128``, ``num_channels=16`` and ``max_seq_len=10``; the
    returned mapping is expected to hold ``'feature'`` of shape
    (2, 10, 128), ``'logits'`` of shape (2, 10, 90) and ``'attn_scores'``
    of shape (2, 10, 8, 32).
    """
    vision_decoder = ABIVisionDecoder(
        in_channels=128, num_channels=16, max_seq_len=10, use_result=None)
    feature_map = torch.randn(2, 128, 8, 32)

    outputs = vision_decoder(feature_map, None)

    expected_shapes = (
        ('feature', torch.Size([2, 10, 128])),
        ('logits', torch.Size([2, 10, 90])),
        ('attn_scores', torch.Size([2, 10, 8, 32])),
    )
    for key, shape in expected_shapes:
        assert outputs[key].shape == shape
|
|