|
|
|
import argparse |
|
import math |
|
import os.path as osp |
|
|
|
import mmcv |
|
|
|
from mmocr.utils import convert_annotations |
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the TextOCR converter.

    Returns:
        argparse.Namespace: Parsed arguments. The single positional
        argument ``root_path`` is the TextOCR root directory.
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate training and validation set of TextOCR ')
    arg_parser.add_argument('root_path', help='Root dir path of TextOCR')
    return arg_parser.parse_args()
|
|
|
|
|
def collect_textocr_info(root_path, annotation_filename, print_every=1000):
    """Collect per-image annotation info from a TextOCR annotation file.

    Args:
        root_path (str): Directory that contains the annotation file.
        annotation_filename (str): Name of the annotation file, relative to
            ``root_path``.
        print_every (int): Print a progress line every ``print_every``
            images. A non-positive value disables progress output.

    Returns:
        list[dict]: The entries of ``annotation['imgs']``, each updated in
        place with ``segm_file`` (the annotation path) and ``anno_info``
        (a list of dicts with the keys ``iscrowd``, ``category_id``,
        ``bbox``, ``area`` and ``segmentation``).

    Raises:
        FileNotFoundError: If the annotation file does not exist.
    """
    annotation_path = osp.join(root_path, annotation_filename)
    if not osp.exists(annotation_path):
        # FileNotFoundError is a subclass of Exception, so callers that
        # catch the broad type keep working.
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')

    annotation = mmcv.load(annotation_path)
    # Hoisted: the image count does not change while iterating.
    total = len(annotation['imgs'])

    img_infos = []
    for i, img_info in enumerate(annotation['imgs'].values()):
        # Guard against print_every <= 0, which would otherwise raise
        # ZeroDivisionError in the modulo below.
        if print_every > 0 and i > 0 and i % print_every == 0:
            print(f'{i}/{total}')

        img_info['segm_file'] = annotation_path
        ann_ids = annotation['imgToAnns'][img_info['id']]
        anno_info = [
            _textocr_ann_to_anno(annotation['anns'][ann_id])
            for ann_id in ann_ids
        ]
        img_info.update(anno_info=anno_info)
        img_infos.append(img_info)
    return img_infos


def _textocr_ann_to_anno(ann):
    """Convert a single TextOCR annotation dict to the converter format."""
    # A label of '.' is flagged as crowd.
    # NOTE(review): presumably '.' marks illegible text in TextOCR - confirm.
    text_label = ann['utf8_string']
    iscrowd = 1 if text_label == '.' else 0

    # Snap the box to integers: the origin is floored and clamped at 0,
    # the size is rounded up.
    # NOTE(review): origin and size are rounded independently, so the
    # integer box is not guaranteed to cover the original float box.
    x, y, w, h = ann['bbox']
    bbox = [
        max(0, math.floor(x)),
        max(0, math.floor(y)),
        math.ceil(w),
        math.ceil(h),
    ]
    # Polygon coordinates are truncated to int and clamped at 0.
    segmentation = [max(0, int(coord)) for coord in ann['points']]
    return dict(
        iscrowd=iscrowd,
        category_id=1,
        bbox=bbox,
        area=ann['area'],
        segmentation=[segmentation])
|
|
|
|
|
def main():
    """Convert the TextOCR training and validation annotations.

    For each split, the TextOCR annotation file under the root path given
    on the command line is collected and handed to ``convert_annotations``
    together with an output path inside the same root directory.
    """
    root_path = parse_args().root_path
    # One row per split: progress message, TextOCR file, output file.
    jobs = (
        ('Processing training set...', 'TextOCR_0.1_train.json',
         'instances_training.json'),
        ('Processing validation set...', 'TextOCR_0.1_val.json',
         'instances_val.json'),
    )
    for message, source_name, target_name in jobs:
        print(message)
        split_infos = collect_textocr_info(root_path, source_name)
        convert_annotations(split_infos, osp.join(root_path, target_name))
    print('Finish')
|
|
|
|
|
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|