|
|
|
import pytest |
|
import torch |
|
|
|
from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder, |
|
NRTREncoder, SAREncoder, |
|
SatrnEncoder, TransformerEncoder) |
|
|
|
|
|
def test_sar_encoder():
    """Check SAREncoder argument validation and its forward output shape."""
    # Each entry holds one keyword whose value the constructor is expected
    # to reject with an AssertionError.
    invalid_kwargs = (
        dict(enc_bi_rnn='bi'),
        dict(enc_do_rnn=2),
        dict(enc_gru='gru'),
        dict(d_model=512.5),
        dict(d_enc=200.5),
        dict(mask='mask'),
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            SAREncoder(**kwargs)

    sar_encoder = SAREncoder()
    sar_encoder.init_weights()
    sar_encoder.train()

    features = torch.randn(1, 512, 4, 40)
    metas = [{'valid_ratio': 1.0}]

    # Passing two meta entries alongside a single-sample feature batch is
    # expected to trip an assertion inside the encoder.
    with pytest.raises(AssertionError):
        sar_encoder(features, metas * 2)

    encoded = sar_encoder(features, metas)
    expected_shape = torch.Size([1, 512])
    assert encoded.shape == expected_shape
|
|
|
|
|
def test_nrtr_encoder():
    """Check that a default NRTREncoder yields the expected output shape.

    A ``(1, 512, 1, 25)`` feature map is fed through the encoder in
    training mode and the result is expected to have shape
    ``(1, 25, 512)``.
    """
    tf_encoder = NRTREncoder()
    tf_encoder.init_weights()
    tf_encoder.train()

    feat = torch.randn(1, 512, 1, 25)
    out_enc = tf_encoder(feat)

    # The shape assertion is the only check; the former debug print of the
    # output size was removed because it only added noise to test output.
    assert out_enc.shape == torch.Size([1, 25, 512])
|
|
|
|
|
def test_satrn_encoder():
    """Check that a default SatrnEncoder yields the expected output shape."""
    encoder = SatrnEncoder()
    encoder.init_weights()
    encoder.train()

    features = torch.randn(1, 512, 8, 25)
    encoded = encoder(features)

    # NOTE(review): 200 equals 8 * 25, so the encoder presumably flattens
    # the two spatial axes into one sequence axis -- confirm against the
    # SatrnEncoder implementation.
    expected_shape = torch.Size([1, 200, 512])
    assert encoded.shape == expected_shape
|
|
|
|
|
def test_base_encoder():
    """Check that BaseEncoder returns an output shaped like its input."""
    base_encoder = BaseEncoder()
    base_encoder.init_weights()
    base_encoder.train()

    features = torch.randn(1, 256, 4, 40)
    encoded = base_encoder(features)

    # The expected shape (1, 256, 4, 40) is exactly the input's shape.
    assert encoded.shape == features.shape
|
|
|
|
|
def test_transformer_encoder():
    """Check that a default TransformerEncoder preserves the input shape."""
    encoder = TransformerEncoder()
    features = torch.randn(10, 512, 8, 32)

    encoded = encoder(features)

    # The expected shape (10, 512, 8, 32) is exactly the input's shape.
    assert encoded.shape == features.shape
|
|
|
|
|
def test_abi_vision_model():
    """Check the shapes of every entry ABIVisionModel returns."""
    decoder_cfg = dict(
        type='ABIVisionDecoder', max_seq_len=10, use_result=None)
    vision_model = ABIVisionModel(decoder=decoder_cfg)

    features = torch.randn(1, 512, 8, 32)
    outputs = vision_model(features)

    # Result key -> expected shape, checked in the listed order.
    expected_shapes = (
        ('feature', torch.Size([1, 10, 512])),
        ('logits', torch.Size([1, 10, 90])),
        ('attn_scores', torch.Size([1, 10, 8, 32])),
    )
    for key, shape in expected_shapes:
        assert outputs[key].shape == shape
|
|