|
|
|
import copy |
|
import os.path as osp |
|
import tempfile |
|
|
|
import pytest |
|
import torch |
|
|
|
from mmocr.models import build_detector |
|
|
|
|
|
def _create_dummy_vocab_file(vocab_file): |
|
with open(vocab_file, 'w') as fw: |
|
for char in list(map(chr, range(ord('a'), ord('z') + 1))): |
|
fw.write(char + '\n') |
|
|
|
|
|
def _get_config_module(fname): |
|
"""Load a configuration as a python module.""" |
|
from mmcv import Config |
|
config_mod = Config.fromfile(fname) |
|
return config_mod |
|
|
|
|
|
def _get_detector_cfg(fname): |
|
"""Grab configs necessary to create a detector. |
|
|
|
These are deep copied to allow for safe modification of parameters without |
|
influencing other tests. |
|
""" |
|
config = _get_config_module(fname) |
|
model = copy.deepcopy(config.model) |
|
return model |
|
|
|
|
|
@pytest.mark.parametrize( |
|
'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py']) |
|
def test_bert_softmax(cfg_file): |
|
|
|
texts = ['中'] * 47 |
|
img = [31] * 47 |
|
labels = [31] * 128 |
|
input_ids = [0] * 128 |
|
attention_mask = [0] * 128 |
|
token_type_ids = [0] * 128 |
|
img_metas = { |
|
'texts': texts, |
|
'labels': torch.tensor(labels).unsqueeze(0), |
|
'img': img, |
|
'input_ids': torch.tensor(input_ids).unsqueeze(0), |
|
'attention_masks': torch.tensor(attention_mask).unsqueeze(0), |
|
'token_type_ids': torch.tensor(token_type_ids).unsqueeze(0) |
|
} |
|
|
|
|
|
tmp_dir = tempfile.TemporaryDirectory() |
|
vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt') |
|
_create_dummy_vocab_file(vocab_file) |
|
|
|
model = _get_detector_cfg(cfg_file) |
|
model['label_convertor']['vocab_file'] = vocab_file |
|
|
|
detector = build_detector(model) |
|
losses = detector.forward(img, img_metas) |
|
assert isinstance(losses, dict) |
|
|
|
model['loss']['type'] = 'MaskedFocalLoss' |
|
detector = build_detector(model) |
|
losses = detector.forward(img, img_metas) |
|
assert isinstance(losses, dict) |
|
|
|
tmp_dir.cleanup() |
|
|
|
|
|
with torch.no_grad(): |
|
batch_results = [] |
|
result = detector.forward(None, img_metas, return_loss=False) |
|
batch_results.append(result) |
|
|