import copy
import warnings
from os import path as osp

import numpy as np
import torch
from mmdet.datasets.builder import DATASETS

from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines import sort_vertex8
from mmocr.utils import is_type_list, list_from_file


@DATASETS.register_module()
class KIEDataset(BaseDataset):
    """Dataset for key information extraction (KIE) tasks.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct a loader for loading
            annotation infos.
        dict_file (str): Character dict file path.
        img_prefix (str, optional): Image prefix used to generate the full
            image path.
        pipeline (list[dict]): Processing pipeline.
        norm (float): Normalization factor applied to the relative box
            offsets computed in compute_relation.
        directed (bool): Whether the graph edges between boxes are directed.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
    """

    def __init__(self,
                 ann_file=None,
                 loader=None,
                 dict_file=None,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        if ann_file is None and loader is None:
            warnings.warn(
                'KIEDataset is only initialized as a downstream demo task '
                'of text detection and recognition '
                'without an annotation file.', UserWarning)
        else:
            super().__init__(
                ann_file,
                loader,
                pipeline,
                img_prefix=img_prefix,
                test_mode=test_mode)
            assert osp.exists(dict_file)

        self.norm = norm
        self.directed = directed
        # Character-to-index map built from dict_file; index 0 is reserved for
        # the empty string and the characters in dict_file are numbered from
        # 1. E.g. if dict_file contains the lines 'a', 'b' and 'c', self.dict
        # is {'': 0, 'a': 1, 'b': 2, 'c': 3}.
        self.dict = {
            '': 0,
            **{
                line.rstrip('\r\n'): ind
                for ind, line in enumerate(list_from_file(dict_file), 1)
            }
        }

    def pre_pipeline(self, results):
        results['img_prefix'] = self.img_prefix
        results['bbox_fields'] = []
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['filename'] = osp.join(self.img_prefix,
                                       results['img_info']['filename'])
        results['ori_filename'] = results['img_info']['filename']
        # Dummy placeholder so that downstream transforms always find an
        # 'img' key.
        results['img'] = np.zeros((0, 0, 0), dtype=np.uint8)

    def _parse_anno_info(self, annotations):
        """Parse annotations of boxes, texts and labels for one image.

        Args:
            annotations (list[dict]): Annotations of one image, where
                each dict is for one text box.

        Returns:
            dict: A dict containing the following keys:

                - bboxes (np.ndarray): Bboxes in one image with shape:
                  box_num * 4. They are sorted clockwise when loading.
                - relations (np.ndarray): Relations between bboxes with shape:
                  box_num * box_num * D.
                - texts (np.ndarray): Text indices with shape:
                  box_num * text_max_len.
                - labels (np.ndarray): Box labels with shape:
                  box_num * (box_num + 1).
        """
        assert is_type_list(annotations, dict)
        assert len(annotations) > 0, 'Please remove data with empty annotation'
        assert 'box' in annotations[0]
        assert 'text' in annotations[0]

        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann['box']
            # Sort the 8 vertex coordinates clockwise before storing the box.
            sorted_box = sort_vertex8(box[:8])
            boxes.append(sorted_box)
            text = ann['text']
            texts.append(text)
            # Map characters to indices; characters missing from the dict are
            # silently dropped.
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            labels.append(ann.get('label', 0))
            edges.append(ann.get('edge', 0))

        ann_infos = dict(
            boxes=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)

        return self.list_to_numpy(ann_infos)

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
            'height': img_ann_info['height'],
            'width': img_ann_info['width']
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)

        self.pre_pipeline(results)

        return self.pipeline(results)

    def evaluate(self,
                 results,
                 metric='macro_f1',
                 metric_options=dict(macro_f1=dict(ignores=[])),
                 **kwargs):
        """Evaluate the predictions.

        Args:
            results (list[dict]): Per-image prediction results, each
                containing a 'nodes' tensor with per-box predictions.
            metric (str | list[str]): Metrics to be evaluated. Only
                'macro_f1' is supported.
            metric_options (dict): Options for the metrics, e.g. the node
                labels to ignore when computing the macro F1 score.

        Returns:
            dict[str, float]: Evaluation results.
        """
        assert set(kwargs).issubset(['logger'])

        # Protect ``metric_options`` since it uses a mutable default value.
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['macro_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        return self.compute_macro_f1(results, **metric_options['macro_f1'])

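    # Illustrative note, not from the original file: ``compute_macro_f1``
    # below expects each element of ``results`` to carry a 'nodes' entry with
    # one row of class scores per box, e.g. (values made up):
    #
    #   results = [{'nodes': torch.tensor([[0.1, 0.8, 0.1],
    #                                      [0.7, 0.2, 0.1]])}]
    #
    # Ground-truth node labels are read from ``self.data_infos`` and passed
    # to ``compute_f1_score`` together with the labels listed in ``ignores``.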
    def compute_macro_f1(self, results, ignores=[]):
        node_preds = []
        node_gts = []
        for idx, result in enumerate(results):
            node_preds.append(result['nodes'].cpu())
            box_ann_infos = self.data_infos[idx]['annotations']
            node_gt = [
                box_ann_info['label'] for box_ann_info in box_ann_infos
            ]
            node_gts.append(torch.Tensor(node_gt))

        node_preds = torch.cat(node_preds)
        node_gts = torch.cat(node_gts).int()

        node_f1s = compute_f1_score(node_preds, node_gts, ignores)

        return {
            'macro_f1': node_f1s.mean(),
        }

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray."""
        boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds']
        texts = ann_infos['texts']
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # edges[i][j] is 1 if box i and box j share the same edge id.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    edges = (edges & labels == 1).astype(np.int32)
                # Mark self-connections with -1.
                np.fill_diagonal(edges, -1)
                # Each label row becomes [node_label, edge_0, ..., edge_N].
                labels = np.concatenate([labels, edges], -1)
        padded_text_inds = self.pad_text_indices(text_inds)

        return dict(
            bboxes=bboxes,
            relations=relations,
            texts=padded_text_inds,
            ori_texts=texts,
            labels=labels)

    def pad_text_indices(self, text_inds):
        """Pad text indices to the same length with -1."""
        max_len = max(len(text_ind) for text_ind in text_inds)
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds

    def compute_relation(self, boxes):
        """Compute the relation between every two boxes."""
        # Convert each 8-coordinate quadrilateral to an axis-aligned
        # [x1, y1, x2, y2] box.
        bboxes = np.concatenate(
            [boxes[:, 0::2].min(axis=1, keepdims=True),
             boxes[:, 1::2].min(axis=1, keepdims=True),
             boxes[:, 0::2].max(axis=1, keepdims=True),
             boxes[:, 1::2].max(axis=1, keepdims=True)],
            axis=1).astype(np.float32)

        x1, y1 = bboxes[:, 0:1], bboxes[:, 1:2]
        x2, y2 = bboxes[:, 2:3], bboxes[:, 3:4]
        w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
        # Pairwise features: relation[i, j] is
        # [(x1_j - x1_i) / norm, (y1_j - y1_i) / norm,
        #  w_i / h_i, h_j / h_i, w_j / h_i].
        dx = (x1.T - x1) / self.norm
        dy = (y1.T - y1) / self.norm
        xhh, xwh = h.T / h, w.T / h
        whs = w / h + np.zeros_like(xhh)
        relation = np.stack([dx, dy, whs, xhh, xwh], -1).astype(np.float32)
        return relation, bboxes
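

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a tiny standalone run
# that builds the dataset in the "demo" mode described by the warning in
# __init__ (no ann_file / loader) and inspects the 5-D relation features from
# compute_relation(). The character dict and the two quadrilaterals below are
# made up for illustration only.
if __name__ == '__main__':
    import tempfile

    # Throwaway character dict so that self.dict can be built.
    with tempfile.NamedTemporaryFile(
            'w', suffix='.txt', delete=False) as dict_f:
        dict_f.write('\n'.join('abcdefghijklmnopqrstuvwxyz0123456789'))

    demo_dataset = KIEDataset(dict_file=dict_f.name)

    # Two made-up quadrilaterals, 8 coordinates each.
    demo_boxes = np.array(
        [[0, 0, 10, 0, 10, 4, 0, 4],
         [0, 10, 6, 10, 6, 22, 0, 22]], dtype=np.int32)
    relation, axis_aligned = demo_dataset.compute_relation(demo_boxes)
    print(relation.shape)      # (2, 2, 5): pairwise relation features
    print(axis_aligned.shape)  # (2, 4): axis-aligned boxes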