import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset

import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean


@DATASETS.register_module()
class IcdarDataset(CocoDataset):
    """Dataset for text detection whose annotation file is in COCO format.

    Args:
        select_first_k (int): If positive, only the first ``select_first_k``
            images are loaded from the annotation file. Defaults to -1,
            which loads all images.
        ann_file_backend (str): Storage backend type of the annotation file,
            which should be one of ['disk', 'petrel', 'http'].
            Defaults to 'disk'.
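
    Example:
        A minimal instantiation sketch. The annotation path and the empty
        pipeline below are placeholders for illustration, not files or
        configs shipped with this module.

        >>> dataset = IcdarDataset(
        ...     ann_file='data/icdar2015/instances_training.json',
        ...     pipeline=[])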
    """
    CLASSES = ('text', )

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 select_first_k=-1,
                 ann_file_backend='disk'):
        self.select_first_k = select_first_k
        assert ann_file_backend in ['disk', 'petrel', 'http']
        self.ann_file_backend = ann_file_backend

        super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
                         seg_prefix, proposal_file, test_mode, filter_empty_gt)

    def load_annotations(self, ann_file):
        """Load annotations from a COCO-style annotation file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: Annotation info from the COCO api.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
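            # For non-disk backends, ``FileClient.get_local_path`` (available
            # since mmcv 1.3.16, hence the version check above) fetches the
            # remote annotation file to a temporary local path and cleans it
            # up when the context manager exits.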
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []

        count = 0
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            count += 1
            # Stop early once the requested number of images has been loaded.
            if 0 < self.select_first_k <= count:
                break
        return data_infos

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotations.

        Args:
            img_info (dict): Image info of an image.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, masks_ignore, seg_map. "masks" and
                "masks_ignore" are represented by polygon boundary
                point sequences.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ignore = []
        gt_masks_ann = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Skip degenerate boxes and annotations of unknown categories.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
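            # Crowd annotations are treated as ignored regions; in ICDAR-style
            # text data these are typically the illegible "###" (don't care)
            # instances, which are excluded from training targets and
            # evaluation matching.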
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
                gt_masks_ignore.append(ann.get('segmentation', None))
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 logger=None,
                 score_thr=0.3,
                 rank_list=None,
                 **kwargs):
        """Evaluate the hmean metric.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            score_thr (float): Score threshold for predicted boundaries.
                Default: 0.3.
            rank_list (str): Json file used to save the evaluation result of
                each image after ranking.

        Returns:
            dict[dict[str: float]]: The evaluation results.
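
        Example:
            An illustrative call; ``det_results`` here is a placeholder for
            the list of per-image result dicts produced by a text detector.

            >>> eval_res = dataset.evaluate(det_results, metric='hmean-iou')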
        """
        assert utils.is_type_list(results, dict)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        metrics = set(metrics) & set(allowed_metrics)

        img_infos = []
        ann_infos = []
        for i in range(len(self)):
            img_info = {'filename': self.data_infos[i]['file_name']}
            img_infos.append(img_info)
            ann_infos.append(self.get_ann_info(i))

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            logger=logger,
            rank_list=rank_list)

        return eval_results