import json
import os.path as osp
import tempfile

import torch

from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor
from mmocr.utils import list_to_file


def _create_dummy_ann_file(ann_file):
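    # One dummy annotation: the raw text and, per category, each entity
    # string mapped to its [start, end] character indices (inclusive).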
    data = {
        'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
        'label': {
            'address': {
                '台湾': [[15, 16]]
            },
            'name': {
                '彭小军': [[0, 2]]
            }
        }
    }

    list_to_file(ann_file, [json.dumps(data, ensure_ascii=False)])
|
|
def _create_dummy_vocab_file(vocab_file):
    # Write one lower-case character per line as a minimal dummy vocabulary.
    # (The previous loop re-created the file on every iteration and
    # json-encoded each character, leaving only a single quoted entry.)
    list_to_file(vocab_file, list(map(chr, range(ord('a'), ord('z') + 1))))
|
|
def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineJsonParser', keys=['text', 'label']))
    return loader
|
|
def test_ner_dataset():
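    # test initialization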
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]
|
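    # create dummy data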
    tmp_dir = tempfile.TemporaryDirectory()
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt')
    _create_dummy_ann_file(ann_file)
    _create_dummy_vocab_file(vocab_file)
|
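    # label convertor config consumed by NerTransform below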
    max_len = 128
    ner_convertor = dict(
        type='NerConvertor',
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=max_len)
|
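    # build the dataset with a minimal NER test pipeline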
    test_pipeline = [
        dict(
            type='NerTransform',
            label_convertor=ner_convertor,
            max_len=max_len),
        dict(type='ToTensorNER')
    ]
    dataset = NerDataset(ann_file, loader, pipeline=test_pipeline)
|
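    # test pre_pipeline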
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
|
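    # test prepare_train_img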
    dataset.prepare_train_img(0)
|
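    # test evaluation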
    result = [[['address', 15, 16], ['name', 0, 2]]]
    dataset.evaluate(result)
|
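    # test converting predicted label indices back to entities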
    pred = [
        21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
        21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
        11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
        21, 21
    ]
    preds = [pred[:128]]
    mask = [0] * 128
    for i in range(10):
        mask[i] = 1
    assert len(preds[0]) == len(mask)
    masks = torch.tensor([mask])
    convertor = NerConvertor(
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=128)
    all_entities = convertor.convert_pred2entities(preds=preds, masks=masks)
    assert len(all_entities[0][0]) == 3
|
    tmp_dir.cleanup()