|
|
|
import pytest |
|
import torch |
|
from mmcv.cnn.bricks import ConvModule |
|
|
|
from mmocr.utils import revert_sync_batchnorm |
|
|
|
|
|
def test_revert_sync_batchnorm():
    """Check that ``revert_sync_batchnorm`` makes a SyncBN model usable on CPU.

    A ``ConvModule`` built with ``SyncBN`` is expected to raise ``ValueError``
    when run on a CPU tensor. After reverting, the model must run, must no
    longer contain any ``torch.nn.SyncBatchNorm`` layer, and every submodule
    must keep the train/eval mode the source module was in.
    """

    def _make_syncbn_conv():
        # Fresh module per scenario: the revert may rewrite the children of
        # the module it is given, so a module cannot be reused across cases.
        return ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')

    def _has_syncbn(module):
        # True if any (sub)module is still a SyncBatchNorm layer.
        return any(
            isinstance(m, torch.nn.SyncBatchNorm) for m in module.modules())

    x = torch.randn(1, 3, 10, 10)

    # --- training mode ---
    conv_syncbn = _make_syncbn_conv()
    conv_syncbn.train()
    assert _has_syncbn(conv_syncbn)

    # SyncBN refuses to run on a CPU tensor.
    with pytest.raises(ValueError):
        conv_syncbn(x)

    conv_bn = revert_sync_batchnorm(conv_syncbn)
    assert not _has_syncbn(conv_bn)
    y = conv_bn(x)
    # 10x10 input with a 2x2 kernel -> 9x9 output map.
    assert y.shape == (1, 8, 9, 9)
    # The replacement norm layer must also be in training mode.
    assert all(m.training for m in conv_bn.modules())

    # --- eval mode ---
    # Use a new SyncBN module; the one above has already been reverted, so
    # re-reverting it would not exercise the SyncBN -> BN conversion.
    conv_syncbn = _make_syncbn_conv()
    conv_syncbn.eval()
    conv_bn = revert_sync_batchnorm(conv_syncbn)
    assert not _has_syncbn(conv_bn)
    # Every submodule, including the newly created norm layer, stays in eval.
    assert not any(m.training for m in conv_bn.modules())
|
|