|
|
|
import pytest |
|
import torch |
|
|
|
from mmocr.models.textdet.necks import FPNC, FPN_UNet |
|
|
|
|
|
def test_fpnc():
    """Check FPNC output shape with all optional flags off and then on.

    Four pyramid levels are built from ``in_channels`` (channel counts) and
    ``size`` (square spatial sizes, each level half the previous one). For
    both values of ``flag``, every lateral/smooth/concat option is toggled
    together, and the neck output is expected to be a single tensor with 256
    channels at the spatial size of the first (largest) level.
    """
    in_channels = [64, 128, 256, 512]
    size = [112, 56, 28, 14]

    for flag in [False, True]:
        fpnc = FPNC(
            in_channels=in_channels,
            bias_on_lateral=flag,
            bn_re_on_lateral=flag,
            bias_on_smooth=flag,
            bn_re_on_smooth=flag,
            conv_after_concat=flag)
        fpnc.init_weights()

        # One random (1, C, H, W) feature map per pyramid level; zipping the
        # two config lists keeps the level count tied to them instead of a
        # hard-coded 4.
        inputs = [
            torch.rand(1, channels, side, side)
            for channels, side in zip(in_channels, size)
        ]

        # Invoke the module itself (as test_fpn_unet_neck does) rather than
        # calling .forward() directly, matching normal PyTorch usage.
        outputs = fpnc(inputs)
        assert list(outputs.size()) == [1, 256, size[0], size[0]]
|
|
|
|
|
def test_fpn_unet_neck():
    """Check FPN_UNet constructor validation and forward output shape.

    Two invalid configurations must raise ``AssertionError``. A valid neck is
    then fed one random feature map per entry of ``in_channels`` (spatial
    size halving per level, starting at ``s``). The output must have
    ``out_channels`` channels at ``4 * s`` spatial size.
    """
    s = 64
    in_channels = [8, 16, 32, 64]
    out_channels = 4
    # Level i is s // 2**i pixels square: [64, 32, 16, 8]. Derived from
    # len(in_channels) so sizes and channels always have matching length.
    feat_sizes = [s // 2**i for i in range(len(in_channels))]

    # A fifth input level is rejected (presumably FPN_UNet expects exactly
    # four levels — confirm against its constructor).
    with pytest.raises(AssertionError):
        FPN_UNet(in_channels + [128], out_channels)

    # A list for out_channels is rejected (presumably it must be a single
    # int — confirm against its constructor).
    with pytest.raises(AssertionError):
        FPN_UNet(in_channels, [2, 4])

    # One random (1, C, H, W) feature map per level.
    feats = [
        torch.rand(1, channels, side, side)
        for channels, side in zip(in_channels, feat_sizes)
    ]

    fpn_unet_neck = FPN_UNet(in_channels, out_channels)
    fpn_unet_neck.init_weights()

    out_neck = fpn_unet_neck(feats)
    assert out_neck.shape == torch.Size([1, out_channels, s * 4, s * 4])
|
|