import copy

import numpy as np
import torch
from mmdet.datasets.builder import DATASETS

from mmocr.datasets import KIEDataset


@DATASETS.register_module()
class OpensetKIEDataset(KIEDataset):
    """Openset KIE classifies the nodes (i.e. text boxes) into bg/key/value
    categories, and additionally learns key-value relationships among nodes.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        dict_file (str): Character dict file path.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        pipeline (list[dict]): Processing pipeline.
        norm (float): Norm to map value from one range to another.
        link_type (str): ``one-to-one`` | ``one-to-many`` |
            ``many-to-one`` | ``many-to-many``. For ``many-to-many``,
            one key box can have many values and vice versa.
        edge_thr (float): Score threshold for a valid edge.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
        key_node_idx (int): Index of key in node classes.
        value_node_idx (int): Index of value in node classes.
        node_classes (int): Number of node classes.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 dict_file,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 link_type='one-to-one',
                 edge_thr=0.5,
                 test_mode=True,
                 key_node_idx=1,
                 value_node_idx=2,
                 node_classes=4):
        super().__init__(ann_file, loader, dict_file, img_prefix, pipeline,
                         norm, False, test_mode)
        assert link_type in [
            'one-to-one', 'one-to-many', 'many-to-one', 'many-to-many', 'none'
        ]
        self.link_type = link_type
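        # Index annotation infos by file name so that ground truths can be
        # looked up quickly in decode_pred()/decode_gt().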
        self.data_dict = {x['file_name']: x for x in self.data_infos}
        self.edge_thr = edge_thr
        self.key_node_idx = key_node_idx
        self.value_node_idx = value_node_idx
        self.node_classes = node_classes

    def pre_pipeline(self, results):
        super().pre_pipeline(results)
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['ori_boxes'] = results['ann_info']['ori_boxes']

    def list_to_numpy(self, ann_infos):
        results = super().list_to_numpy(ann_infos)
        results.update(dict(ori_texts=ann_infos['texts']))
        results.update(dict(ori_boxes=ann_infos['boxes']))

        return results

    def evaluate(self,
                 results,
                 metric='openset_f1',
                 metric_options=None,
                 **kwargs):
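        # Deep-copy ``metric_options`` so the caller's dict is never mutated.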
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['openset_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        preds, gts = [], []
        for result in results:
            pred = self.decode_pred(result)
            preds.append(pred)

            gt = self.decode_gt(pred['filename'])
            gts.append(gt)

        return self.compute_openset_f1(preds, gts)

    def _decode_pairs_gt(self, labels, edge_ids):
        """Find all pairs in gt.

        The first index in the pair (n1, n2) is the key.
        """
        gt_pairs = []
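        # A gt pair links key node i to every value node j that shares the
        # same edge id.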
        for i, label in enumerate(labels):
            if label == self.key_node_idx:
                for j, edge_id in enumerate(edge_ids):
                    if (edge_id == edge_ids[i]
                            and labels[j] == self.value_node_idx):
                        gt_pairs.append((i, j))

        return gt_pairs

    @staticmethod
    def _decode_pairs_pred(nodes,
                           labels,
                           edges,
                           edge_thr=0.5,
                           link_type='one-to-one'):
        """Find all pairs in prediction.

        The first index in the pair (n1, n2) is more likely to be a key
        according to the prediction in nodes.
        """
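        # Symmetrize the edge scores so that (i, j) and (j, i) share the
        # larger of the two directed scores.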
        edges = torch.max(edges, edges.T)
        if link_type in ['none', 'many-to-many']:
            pair_inds = (edges > edge_thr).nonzero(as_tuple=True)
            pred_pairs = [(n1.item(),
                           n2.item()) if nodes[n1, 1] > nodes[n1, 2] else
                          (n2.item(), n1.item()) for n1, n2 in zip(*pair_inds)
                          if n1 < n2]
            pred_pairs = [(i, j) for i, j in pred_pairs
                          if labels[i] == 1 and labels[j] == 2]
        else:
            links = edges.clone()
            links[links <= edge_thr] = -1
            links[labels != 1, :] = -1
            links[:, labels != 2] = -1

            pred_pairs = []
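            # Greedily pick the highest-scoring remaining link, then mask out
            # the rows/columns that the chosen link type forbids from re-use.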
            while (links > -1).any():
                i, j = np.unravel_index(torch.argmax(links), links.shape)
                pred_pairs.append((i, j))
                if link_type == 'one-to-one':
                    links[i, :] = -1
                    links[:, j] = -1
                elif link_type == 'one-to-many':
                    links[:, j] = -1
                elif link_type == 'many-to-one':
                    links[i, :] = -1
                else:
                    raise ValueError(f'unsupported link type {link_type}')

        pairs_conf = [edges[i, j].item() for i, j in pred_pairs]
        return pred_pairs, pairs_conf

    def decode_pred(self, result):
        """Decode prediction.

        Assemble boxes and predicted labels into bboxes, and convert edges
        into a matrix.
        """
        filename = result['img_metas'][0]['ori_filename']
        nodes = result['nodes'].cpu()
        labels_conf, labels = torch.max(nodes, dim=-1)
        num_nodes = nodes.size(0)
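        # Keep only the score of the last edge class (the linked class) and
        # reshape the flattened pairwise scores into a num_nodes x num_nodes
        # matrix.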
        edges = result['edges'][:, -1].view(num_nodes, num_nodes).cpu()
        annos = self.data_dict[filename]['annotations']
        boxes = [x['box'] for x in annos]
        texts = [x['text'] for x in annos]
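        # Annotation boxes are 8-value quadrilaterals; indices [0, 1, 4, 5]
        # take the top-left and bottom-right corners, and the predicted label
        # is appended as a trailing column.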
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs, pairs_conf = self._decode_pairs_pred(nodes, labels, edges,
                                                    self.edge_thr,
                                                    self.link_type)
        pred = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': labels_conf.tolist(),
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': pairs_conf
        }
        return pred

    def decode_gt(self, filename):
        """Decode ground truth.

        Assemble boxes and labels into bboxes.
        """
        annos = self.data_dict[filename]['annotations']
        labels = torch.Tensor([x['label'] for x in annos])
        texts = [x['text'] for x in annos]
        edge_ids = [x['edge'] for x in annos]
        boxes = [x['box'] for x in annos]
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs = self._decode_pairs_gt(labels, edge_ids)
        gt = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': [1. for _ in labels],
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': [1. for _ in pairs]
        }
        return gt

    def compute_openset_f1(self, preds, gts):
        """Compute openset macro-f1 and micro-f1 scores.

        Args:
            preds (list[dict]): List of prediction results, including
                keys: ``filename``, ``pairs``, etc.
            gts (list[dict]): List of ground-truth infos, including
                keys: ``filename``, ``pairs``, etc.

        Returns:
            dict: Evaluation result with keys: ``node_openset_micro_f1``,
                ``node_openset_macro_f1``, ``edge_openset_f1``.
        """
        total_edge_hit_num, total_edge_gt_num, total_edge_pred_num = 0, 0, 0
        total_node_hit_num, total_node_gt_num, total_node_pred_num = {}, {}, {}
        node_inds = list(range(self.node_classes))
        for node_idx in node_inds:
            total_node_hit_num[node_idx] = 0
            total_node_gt_num[node_idx] = 0
            total_node_pred_num[node_idx] = 0

        img_level_res = {}
        for pred, gt in zip(preds, gts):
            filename = pred['filename']
            img_res = {}

            pairs_pred = pred['pairs']
            pairs_gt = gt['pairs']
            img_res['edge_hit_num'] = 0
            for pair in pairs_gt:
                if pair in pairs_pred:
                    img_res['edge_hit_num'] += 1
            img_res['edge_recall'] = 1.0 * img_res['edge_hit_num'] / max(
                1, len(pairs_gt))
            img_res['edge_precision'] = 1.0 * img_res['edge_hit_num'] / max(
                1, len(pairs_pred))
            img_res['f1'] = 2 * img_res['edge_recall'] * img_res[
                'edge_precision'] / max(
                    1, img_res['edge_recall'] + img_res['edge_precision'])
            total_edge_hit_num += img_res['edge_hit_num']
            total_edge_gt_num += len(pairs_gt)
            total_edge_pred_num += len(pairs_pred)

            nodes_pred = pred['labels']
            nodes_gt = gt['labels']
            for i, node_gt in enumerate(nodes_gt):
                node_gt = int(node_gt)
                total_node_gt_num[node_gt] += 1
                if nodes_pred[i] == node_gt:
                    total_node_hit_num[node_gt] += 1
            for node_pred in nodes_pred:
                total_node_pred_num[node_pred] += 1

            img_level_res[filename] = img_res

        total_edge_recall = 1.0 * total_edge_hit_num / max(
            1, total_edge_gt_num)
        total_edge_precision = 1.0 * total_edge_hit_num / max(
            1, total_edge_pred_num)
        edge_f1 = 2 * total_edge_recall * total_edge_precision / max(
            1, total_edge_recall + total_edge_precision)
        stats = {'edge_openset_f1': edge_f1}

        cared_node_hit_num, cared_node_gt_num, cared_node_pred_num = 0, 0, 0
        node_macro_metric = {}
        for node_idx in node_inds:
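            # Only key (index 1) and value (index 2) nodes are scored;
            # background and other classes are skipped.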
            if node_idx < 1 or node_idx > 2:
                continue
            cared_node_hit_num += total_node_hit_num[node_idx]
            cared_node_gt_num += total_node_gt_num[node_idx]
            cared_node_pred_num += total_node_pred_num[node_idx]
            node_res = {}
            node_res['recall'] = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_gt_num[node_idx])
            node_res['precision'] = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_pred_num[node_idx])
            node_res[
                'f1'] = 2 * node_res['recall'] * node_res['precision'] / max(
                    1, node_res['recall'] + node_res['precision'])
            node_macro_metric[node_idx] = node_res
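
        # Micro-f1 pools hits and counts across the cared classes, while
        # macro-f1 averages the per-class f1 scores computed above.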
        node_micro_recall = 1.0 * cared_node_hit_num / max(
            1, cared_node_gt_num)
        node_micro_precision = 1.0 * cared_node_hit_num / max(
            1, cared_node_pred_num)
        node_micro_f1 = 2 * node_micro_recall * node_micro_precision / max(
            1, node_micro_recall + node_micro_precision)

        stats['node_openset_micro_f1'] = node_micro_f1
        stats['node_openset_macro_f1'] = np.mean(
            [v['f1'] for v in node_macro_metric.values()])

        return stats
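

# A minimal usage sketch, assuming an MMOCR 0.x-style loader config. The
# file paths and ``model_results`` below are illustrative placeholders
# rather than part of this module:
#
#   dataset = OpensetKIEDataset(
#       ann_file='data/wildreceipt/openset_train.txt',
#       loader=dict(
#           type='HardDiskLoader',
#           repeat=1,
#           parser=dict(
#               type='LineJsonParser',
#               keys=['file_name', 'height', 'width', 'annotations'])),
#       dict_file='data/wildreceipt/dict.txt',
#       pipeline=[])
#   stats = dataset.evaluate(model_results)  # raw outputs of a KIE model
#   print(stats['edge_openset_f1'])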