|
|
|
|
|
import copy |
|
import os |
|
import warnings |
|
from argparse import ArgumentParser, Namespace |
|
from pathlib import Path |
|
|
|
import mmcv |
|
import numpy as np |
|
import torch |
|
from mmcv.image.misc import tensor2imgs |
|
from mmcv.runner import load_checkpoint |
|
from mmcv.utils.config import Config |
|
|
|
from mmocr.apis import init_detector |
|
from mmocr.apis.inference import model_inference |
|
from mmocr.core.visualize import det_recog_show_result |
|
from mmocr.datasets.kie_dataset import KIEDataset |
|
from mmocr.datasets.pipelines.crop import crop_img |
|
from mmocr.models import build_detector |
|
from mmocr.utils.box_util import stitch_boxes_into_lines |
|
from mmocr.utils.fileio import list_from_file |
|
from mmocr.utils.model import revert_sync_batchnorm |
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line arguments for the end-to-end OCR demo.

    Returns:
        argparse.Namespace: Parsed arguments. The literal string 'None' for
        ``--det``/``--recog`` is mapped to ``None`` so either stage can be
        disabled from the command line.
    """
    parser = ArgumentParser()
    parser.add_argument(
        'img', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--output',
        type=str,
        default='',
        help='Output file/folder name for visualization')
    parser.add_argument(
        '--det',
        type=str,
        default='PANet_IC15',
        help='Pretrained text detection algorithm')
    parser.add_argument(
        '--det-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected det model. It '
        'overrides the settings in det')
    parser.add_argument(
        '--det-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected det model. '
        'It overrides the settings in det')
    parser.add_argument(
        '--recog',
        type=str,
        default='SEG',
        help='Pretrained text recognition algorithm')
    parser.add_argument(
        '--recog-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected recog model. It'
        'overrides the settings in recog')
    parser.add_argument(
        '--recog-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected recog model. '
        'It overrides the settings in recog')
    parser.add_argument(
        '--kie',
        type=str,
        default='',
        help='Pretrained key information extraction algorithm')
    parser.add_argument(
        '--kie-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected kie model. It'
        'overrides the settings in kie')
    parser.add_argument(
        '--kie-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected kie model. '
        'It overrides the settings in kie')
    parser.add_argument(
        '--config-dir',
        type=str,
        default=os.path.join(str(Path.cwd()), 'configs/'),
        help='Path to the config directory where all the config files '
        'are located. Defaults to "configs/"')
    parser.add_argument(
        '--batch-mode',
        action='store_true',
        help='Whether use batch mode for inference')
    parser.add_argument(
        '--recog-batch-size',
        type=int,
        default=0,
        help='Batch size for text recognition')
    parser.add_argument(
        '--det-batch-size',
        type=int,
        default=0,
        help='Batch size for text detection')
    parser.add_argument(
        '--single-batch-size',
        type=int,
        default=0,
        help='Batch size for separate det/recog inference')
    parser.add_argument(
        '--device', default=None, help='Device used for inference.')
    parser.add_argument(
        '--export',
        type=str,
        default='',
        help='Folder where the results of each image are exported')
    parser.add_argument(
        '--export-format',
        type=str,
        default='json',
        help='Format of the exported result file(s)')
    parser.add_argument(
        '--details',
        action='store_true',
        help='Whether include the text boxes coordinates and confidence values'
    )
    parser.add_argument(
        '--imshow',
        action='store_true',
        help='Whether show image with OpenCV.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Prints the recognised text')
    parser.add_argument(
        '--merge', action='store_true', help='Merge neighboring boxes')
    parser.add_argument(
        '--merge-xdist',
        type=float,
        default=20,
        help='The maximum x-axis distance to merge boxes')
    args = parser.parse_args()
    # The CLI uses the literal string 'None' to disable a stage entirely.
    if args.det == 'None':
        args.det = None
    if args.recog == 'None':
        args.recog = None

    if args.merge and not (args.det and args.recog):
        warnings.warn(
            'Box merging will not work if the script is not'
            ' running in detection + recognition mode.', UserWarning)
    # Warn when a custom config_dir is combined with explicit config paths,
    # since the explicit paths take precedence. BUGFIX: the original compared
    # config_dir against the *current working directory* via os.path.samefile,
    # which mis-fired for the default value and raised FileNotFoundError when
    # the path did not exist on disk; compare normalized paths against the
    # actual default instead.
    default_config_dir = os.path.join(str(Path.cwd()), 'configs/')
    if (os.path.abspath(args.config_dir) !=
            os.path.abspath(default_config_dir)
            and (args.det_config != '' or args.recog_config != '')):
        warnings.warn(
            'config_dir will be overridden by det-config or recog-config.',
            UserWarning)
    return args
|
|
|
|
|
class MMOCR:
    """High-level wrapper that wires together text detection, text
    recognition and (optionally) key information extraction models.

    Model aliases are resolved to bundled config files under ``config_dir``
    and to released checkpoints on the openmmlab download server; explicit
    ``*_config``/``*_ckpt`` paths override the alias lookup.
    """

    def __init__(self,
                 det='PANet_IC15',
                 det_config='',
                 det_ckpt='',
                 recog='SEG',
                 recog_config='',
                 recog_ckpt='',
                 kie='',
                 kie_config='',
                 kie_ckpt='',
                 config_dir=os.path.join(str(Path.cwd()), 'configs/'),
                 device=None,
                 **kwargs):
        """Build the requested models.

        Args:
            det (str): Alias of a pretrained text detection algorithm.
                Falsy value disables detection.
            det_config (str): Custom det config path; overrides ``det``.
            det_ckpt (str): Custom det checkpoint; overrides ``det``.
            recog (str): Alias of a pretrained recognition algorithm.
                Falsy value disables recognition.
            recog_config (str): Custom recog config path.
            recog_ckpt (str): Custom recog checkpoint.
            kie (str): Alias of a pretrained KIE algorithm; requires both
                det and recog to be enabled.
            kie_config (str): Custom kie config path.
            kie_ckpt (str): Custom kie checkpoint.
            config_dir (str): Root directory of the bundled config files.
            device: Inference device; auto-selects CUDA when available
                if None.
            **kwargs: Ignored; accepted so CLI args can be splatted in.

        Raises:
            ValueError: If an alias is not recognized.
            NotImplementedError: If ``kie`` is requested without both
                detection and recognition.
        """
        # Alias -> {config file relative to config_dir, checkpoint file
        # relative to the openmmlab download root}.
        textdet_models = {
            'DB_r18': {
                'config':
                'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
            },
            'DB_r50': {
                'config':
                'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
            },
            'DRRG': {
                'config':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
            },
            'FCE_IC15': {
                'config':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
                'ckpt':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
            },
            'FCE_CTW_DCNv2': {
                'config':
                'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
                'ckpt':
                'fcenet/'
                'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
            },
            'MaskRCNN_CTW': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
            },
            'MaskRCNN_IC15': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
            },
            'MaskRCNN_IC17': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
            },
            'PANet_CTW': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
            },
            'PANet_IC15': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
            },
            'PS_CTW': {
                'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
            },
            'PS_IC15': {
                'config':
                'psenet/psenet_r50_fpnf_600e_icdar2015.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
            },
            'TextSnake': {
                'config':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
            }
        }

        textrecog_models = {
            'CRNN': {
                'config': 'crnn/crnn_academic_dataset.py',
                'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
            },
            'SAR': {
                'config': 'sar/sar_r31_parallel_decoder_academic.py',
                'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
            },
            'SAR_CN': {
                'config':
                'sar/sar_r31_parallel_decoder_chinese.py',
                'ckpt':
                'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
            },
            'NRTR_1/16-1/8': {
                'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
            },
            'NRTR_1/8-1/4': {
                'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
            },
            'RobustScanner': {
                'config': 'robust_scanner/robustscanner_r31_academic.py',
                'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
            },
            'SATRN': {
                'config': 'satrn/satrn_academic.py',
                'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
            },
            'SATRN_sm': {
                'config': 'satrn/satrn_small.py',
                'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
            },
            'ABINet': {
                'config': 'abinet/abinet_academic.py',
                'ckpt': 'abinet/abinet_academic-f718abf6.pth'
            },
            'SEG': {
                'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
                'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
            },
            'CRNN_TPS': {
                'config': 'tps/crnn_tps_academic_dataset.py',
                'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
            }
        }

        kie_models = {
            'SDMGR': {
                'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
                'ckpt':
                'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
            }
        }

        self.td = det
        self.tr = recog
        self.kie = kie
        self.device = device
        if self.device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        # Validate requested aliases up front so the user gets a clear error
        # instead of a failed config lookup later. (Fixed the original
        # two-argument ValueError form and the "algorthm" typo.)
        if self.td and self.td not in textdet_models:
            raise ValueError(
                f'{self.td} is not a supported text detection algorithm')
        elif self.tr and self.tr not in textrecog_models:
            raise ValueError(
                f'{self.tr} is not a supported text recognition algorithm')
        elif self.kie:
            if self.kie not in kie_models:
                raise ValueError(
                    f'{self.kie} is not a supported key information '
                    'extraction algorithm')
            elif not (self.td and self.tr):
                raise NotImplementedError(
                    f'{self.kie} has to run together with text detection '
                    'and recognition algorithms.')

        self.detect_model = None
        if self.td:
            # Resolve config/checkpoint from the alias unless the caller
            # supplied explicit paths.
            if not det_config:
                det_config = os.path.join(config_dir, 'textdet/',
                                          textdet_models[self.td]['config'])
            if not det_ckpt:
                det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
                    textdet_models[self.td]['ckpt']

            self.detect_model = init_detector(
                det_config, det_ckpt, device=self.device)
            self.detect_model = revert_sync_batchnorm(self.detect_model)

        self.recog_model = None
        if self.tr:
            if not recog_config:
                recog_config = os.path.join(
                    config_dir, 'textrecog/',
                    textrecog_models[self.tr]['config'])
            if not recog_ckpt:
                recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'textrecog/' + textrecog_models[self.tr]['ckpt']

            self.recog_model = init_detector(
                recog_config, recog_ckpt, device=self.device)
            self.recog_model = revert_sync_batchnorm(self.recog_model)

        self.kie_model = None
        if self.kie:
            if not kie_config:
                kie_config = os.path.join(config_dir, 'kie/',
                                          kie_models[self.kie]['config'])
            if not kie_ckpt:
                kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'kie/' + kie_models[self.kie]['ckpt']

            kie_cfg = Config.fromfile(kie_config)
            self.kie_model = build_detector(
                kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
            self.kie_model = revert_sync_batchnorm(self.kie_model)
            self.kie_model.cfg = kie_cfg
            load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)

        # Unwrap (Distributed)DataParallel wrappers so downstream calls hit
        # the underlying model attributes directly. BUGFIX: the original
        # loop only rebound a local variable (`model = model.module`), which
        # had no effect; assign back to the instance attributes instead.
        if self.detect_model is not None and hasattr(self.detect_model,
                                                     'module'):
            self.detect_model = self.detect_model.module
        if self.recog_model is not None and hasattr(self.recog_model,
                                                    'module'):
            self.recog_model = self.recog_model.module
|
|
|
    def readtext(self,
                 img,
                 output=None,
                 details=False,
                 export=None,
                 export_format='json',
                 batch_mode=False,
                 recog_batch_size=0,
                 det_batch_size=0,
                 single_batch_size=0,
                 imshow=False,
                 print_result=False,
                 merge=False,
                 merge_xdist=20,
                 **kwargs):
        """Run OCR inference on one or more images and post-process results.

        Args:
            img: Image file path, folder path, numpy array, or a list/tuple
                of those (normalized by ``_args_processing``).
            output: File or folder for visualization output; falsy disables.
            details (bool): If False, det+recog results keep only filename
                and recognized texts per image.
            export: Folder where per-image result files are exported.
            export_format (str): Extension for exported result files.
            batch_mode (bool): Run model inference in batches.
            recog_batch_size (int): Batch size for recognition (0 = single
                batch).
            det_batch_size (int): Batch size for detection (0 = single
                batch).
            single_batch_size (int): Batch size for det-only / recog-only
                runs (0 = single batch).
            imshow (bool): Show visualizations with OpenCV.
            print_result (bool): Print each image's result to stdout.
            merge (bool): Merge neighboring boxes (det+recog mode).
            merge_xdist (float): Maximum x-distance for merging boxes.
            **kwargs: Ignored; accepted so CLI args can be splatted in.

        Returns:
            Post-processed results: a list with one entry per image in
            det+recog mode, otherwise the post-processed result of the
            single configured model (None if no model is configured).
        """
        # Snapshot the call arguments into a Namespace mirroring the CLI
        # path. NOTE: this relies on locals() containing exactly the
        # parameters above; 'self' and 'kwargs' are dropped as they are not
        # real options.
        args = locals().copy()
        [args.pop(x, None) for x in ['kwargs', 'self']]
        args = Namespace(**args)

        # Expands `img` into arrays/filenames and turns output/export into
        # per-image lists (mutates `args` in place).
        self._args_processing(args)
        self.args = args

        pp_result = None

        if self.detect_model and self.recog_model:
            # End-to-end mode: detect, recognize per box, optionally KIE.
            det_recog_result = self.det_recog_kie_inference(
                self.detect_model, self.recog_model, kie_model=self.kie_model)
            pp_result = self.det_recog_pp(det_recog_result)
        else:
            # Single-stage mode: run whichever model is configured. With
            # neither configured, the loop is empty and None is returned.
            for model in list(
                    filter(None, [self.recog_model, self.detect_model])):
                result = self.single_inference(model, args.arrays,
                                               args.batch_mode,
                                               args.single_batch_size)
                pp_result = self.single_pp(result, model)

        return pp_result
|
|
|
|
|
def det_recog_pp(self, result): |
|
final_results = [] |
|
args = self.args |
|
for arr, output, export, det_recog_result in zip( |
|
args.arrays, args.output, args.export, result): |
|
if output or args.imshow: |
|
if self.kie_model: |
|
res_img = det_recog_show_result(arr, det_recog_result) |
|
else: |
|
res_img = det_recog_show_result( |
|
arr, det_recog_result, out_file=output) |
|
if args.imshow and not self.kie_model: |
|
mmcv.imshow(res_img, 'inference results') |
|
if not args.details: |
|
simple_res = {} |
|
simple_res['filename'] = det_recog_result['filename'] |
|
simple_res['text'] = [ |
|
x['text'] for x in det_recog_result['result'] |
|
] |
|
final_result = simple_res |
|
else: |
|
final_result = det_recog_result |
|
if export: |
|
mmcv.dump(final_result, export, indent=4) |
|
if args.print_result: |
|
print(final_result, end='\n\n') |
|
final_results.append(final_result) |
|
return final_results |
|
|
|
|
|
def single_pp(self, result, model): |
|
for arr, output, export, res in zip(self.args.arrays, self.args.output, |
|
self.args.export, result): |
|
if export: |
|
mmcv.dump(res, export, indent=4) |
|
if output or self.args.imshow: |
|
res_img = model.show_result(arr, res, out_file=output) |
|
if self.args.imshow: |
|
mmcv.imshow(res_img, 'inference results') |
|
if self.args.print_result: |
|
print(res, end='\n\n') |
|
return result |
|
|
|
def generate_kie_labels(self, result, boxes, class_list): |
|
idx_to_cls = {} |
|
if class_list is not None: |
|
for line in list_from_file(class_list): |
|
class_idx, class_label = line.strip().split() |
|
idx_to_cls[class_idx] = class_label |
|
|
|
max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1) |
|
node_pred_label = max_idx.numpy().tolist() |
|
node_pred_score = max_value.numpy().tolist() |
|
labels = [] |
|
for i in range(len(boxes)): |
|
pred_label = str(node_pred_label[i]) |
|
if pred_label in idx_to_cls: |
|
pred_label = idx_to_cls[pred_label] |
|
pred_score = node_pred_score[i] |
|
labels.append((pred_label, pred_score)) |
|
return labels |
|
|
|
    def visualize_kie_output(self,
                             model,
                             data,
                             result,
                             out_file=None,
                             show=False):
        """Render the KIE prediction on top of the input image.

        Args:
            model: KIE model providing ``show_result``.
            data (dict): Pipeline output for one image, holding 'img',
                'img_metas' and 'gt_bboxes' containers with a ``.data``
                attribute.
            result: Prediction passed through to ``model.show_result``.
            out_file (str, optional): Path to save the visualization to.
            show (bool): Whether to display the visualization.
        """
        img_tensor = data['img'].data
        img_meta = data['img_metas'].data
        gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
        if img_tensor.dtype == torch.uint8:
            # uint8 tensors carry raw pixel values; just move channels last.
            img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        else:
            # Float tensors were normalized by the pipeline; undo the
            # normalization recorded in img_norm_cfg (if any).
            img = tensor2imgs(
                img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0]
        # Crop to the actual image shape to drop any padding.
        h, w, _ = img_meta.get('img_shape', img.shape)
        img_show = img[:h, :w, :]
        model.show_result(
            img_show, result, gt_bboxes, show=show, out_file=out_file)
|
|
|
|
|
    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
        """Detect text boxes, recognize text in each box and optionally run
        key information extraction on the recognized boxes.

        Args:
            det_model: Text detection model.
            recog_model: Text recognition model.
            kie_model: Optional KIE model; expected to carry ``cfg`` and
                ``class_list`` attributes.

        Returns:
            list[dict]: One dict per image with 'filename' and 'result'
            (a list of per-box dicts with box/score/text fields and, when
            KIE runs, label fields).
        """
        end2end_res = []
        # Stage 1: detection over all input images.
        det_result = self.single_inference(det_model, self.args.arrays,
                                           self.args.batch_mode,
                                           self.args.det_batch_size)
        bboxes_list = [res['boundary_result'] for res in det_result]

        if kie_model:
            kie_dataset = KIEDataset(
                dict_file=kie_model.cfg.data.test.dict_file)

        # Stage 2: per-image recognition (and optional KIE).
        for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                                   self.args.arrays,
                                                   bboxes_list,
                                                   self.args.output):
            img_e2e_res = {}
            img_e2e_res['filename'] = filename
            img_e2e_res['result'] = []
            box_imgs = []
            for bbox in bboxes:
                # Each bbox is a flat coordinate list with the confidence
                # score appended as the last element.
                box_res = {}
                box_res['box'] = [round(x) for x in bbox[:-1]]
                box_res['box_score'] = float(bbox[-1])
                box = bbox[:8]
                if len(bbox) > 9:
                    # Polygon with more than four corners: crop using its
                    # axis-aligned bounding rectangle instead.
                    min_x = min(bbox[0:-1:2])
                    min_y = min(bbox[1:-1:2])
                    max_x = max(bbox[0:-1:2])
                    max_y = max(bbox[1:-1:2])
                    box = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                box_img = crop_img(arr, box)
                if self.args.batch_mode:
                    # Defer recognition so all crops go through in batches.
                    box_imgs.append(box_img)
                else:
                    recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, list):
                        # Per-character scores: average over the text length.
                        text_score = sum(text_score) / max(1, len(text))
                    box_res['text'] = text
                    box_res['text_score'] = text_score
                img_e2e_res['result'].append(box_res)

            if self.args.batch_mode:
                # Batched recognition over the collected crops; results are
                # written back into the per-box dicts by position.
                recog_results = self.single_inference(
                    recog_model, box_imgs, True, self.args.recog_batch_size)
                for i, recog_result in enumerate(recog_results):
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, (list, tuple)):
                        text_score = sum(text_score) / max(1, len(text))
                    img_e2e_res['result'][i]['text'] = text
                    img_e2e_res['result'][i]['text_score'] = text_score

            if self.args.merge:
                # Merge neighboring boxes into text lines (0.5 = min y-IoU).
                img_e2e_res['result'] = stitch_boxes_into_lines(
                    img_e2e_res['result'], self.args.merge_xdist, 0.5)

            if kie_model:
                # Build KIE annotations from the recognized boxes; boxes are
                # normalized to 8-point axis-aligned rectangles first.
                annotations = copy.deepcopy(img_e2e_res['result'])
                for i, ann in enumerate(annotations):
                    min_x = min(ann['box'][::2])
                    min_y = min(ann['box'][1::2])
                    max_x = max(ann['box'][::2])
                    max_y = max(ann['box'][1::2])
                    annotations[i]['box'] = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                ann_info = kie_dataset._parse_anno_info(annotations)
                # Fall back to 'bboxes' when the parsed annotations do not
                # provide ori/gt boxes explicitly.
                ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
                                                      ann_info['bboxes'])
                ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
                                                     ann_info['bboxes'])
                kie_result, data = model_inference(
                    kie_model,
                    arr,
                    ann=ann_info,
                    return_data=True,
                    batch_mode=self.args.batch_mode)

                self.visualize_kie_output(
                    kie_model,
                    data,
                    kie_result,
                    out_file=out_file,
                    show=self.args.imshow)
                gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
                labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                                  kie_model.class_list)
                for i in range(len(gt_bboxes)):
                    img_e2e_res['result'][i]['label'] = labels[i][0]
                    img_e2e_res['result'][i]['label_score'] = labels[i][1]

            end2end_res.append(img_e2e_res)
        return end2end_res
|
|
|
|
|
def single_inference(self, model, arrays, batch_mode, batch_size=0): |
|
result = [] |
|
if batch_mode: |
|
if batch_size == 0: |
|
result = model_inference(model, arrays, batch_mode=True) |
|
else: |
|
n = batch_size |
|
arr_chunks = [ |
|
arrays[i:i + n] for i in range(0, len(arrays), n) |
|
] |
|
for chunk in arr_chunks: |
|
result.extend( |
|
model_inference(model, chunk, batch_mode=True)) |
|
else: |
|
for arr in arrays: |
|
result.append(model_inference(model, arr, batch_mode=False)) |
|
return result |
|
|
|
|
|
def _args_processing(self, args): |
|
|
|
|
|
if isinstance(args.img, (list, tuple)): |
|
img_list = args.img |
|
if not all([isinstance(x, (np.ndarray, str)) for x in args.img]): |
|
raise AssertionError('Images must be strings or numpy arrays') |
|
|
|
|
|
if isinstance(args.img, str): |
|
img_path = Path(args.img) |
|
if img_path.is_dir(): |
|
img_list = [str(x) for x in img_path.glob('*')] |
|
else: |
|
img_list = [str(img_path)] |
|
elif isinstance(args.img, np.ndarray): |
|
img_list = [args.img] |
|
|
|
|
|
|
|
args.arrays = [mmcv.imread(x) for x in img_list] |
|
|
|
|
|
if isinstance(img_list[0], str): |
|
args.filenames = [str(Path(x).stem) for x in img_list] |
|
else: |
|
args.filenames = [str(x) for x in range(len(img_list))] |
|
|
|
|
|
num_res = len(img_list) |
|
if args.output: |
|
output_path = Path(args.output) |
|
if output_path.is_dir(): |
|
args.output = [ |
|
str(output_path / f'out_{x}.png') for x in args.filenames |
|
] |
|
else: |
|
args.output = [str(args.output)] |
|
if args.batch_mode: |
|
raise AssertionError('Output of multiple images inference' |
|
' must be a directory') |
|
else: |
|
args.output = [None] * num_res |
|
|
|
|
|
|
|
if args.export: |
|
export_path = Path(args.export) |
|
args.export = [ |
|
str(export_path / f'out_{x}.{args.export_format}') |
|
for x in args.filenames |
|
] |
|
else: |
|
args.export = [None] * num_res |
|
|
|
return args |
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, build the pipeline, run it."""
    cli_args = vars(parse_args())
    ocr = MMOCR(**cli_args)
    ocr.readtext(**cli_args)


if __name__ == '__main__':
    main()
|
|