|
|
|
|
|
import os.path as osp |
|
import shutil |
|
import time |
|
from argparse import ArgumentParser |
|
from itertools import compress |
|
|
|
import mmcv |
|
from mmcv.utils import ProgressBar |
|
|
|
from mmocr.apis import init_detector, model_inference |
|
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric |
|
from mmocr.datasets import build_dataset |
|
from mmocr.models import build_detector |
|
from mmocr.utils import get_root_logger, list_from_file, list_to_file |
|
|
|
|
|
def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Write recognition results to text files inside ``res_dir``.

    Three files are produced:

    - ``results.txt``: one ``<img> <pred> <gt>`` line per sample.
    - ``correct.txt``: only the lines whose prediction equals the label.
    - ``wrong.txt``: only the lines whose prediction differs from the label.

    Args:
        img_paths (list[str]): Image path of each sample.
        pred_labels (list[str]): Predicted text of each sample.
        gt_labels (list[str]): Ground-truth text of each sample.
        res_dir (str): Directory the three files are written into.
    """
    assert len(img_paths) == len(pred_labels) == len(gt_labels)

    # Build the output line and the correctness flag of each sample together.
    records = []
    is_correct = []
    for img, pred, gt in zip(img_paths, pred_labels, gt_labels):
        records.append(f'{img} {pred} {gt}')
        is_correct.append(pred == gt)
    is_wrong = [not flag for flag in is_correct]

    results_file = osp.join(res_dir, 'results.txt')
    correct_file = osp.join(res_dir, 'correct.txt')
    wrong_file = osp.join(res_dir, 'wrong.txt')

    list_to_file(results_file, records)
    list_to_file(correct_file, compress(records, is_correct))
    list_to_file(wrong_file, compress(records, is_wrong))
|
|
|
|
|
def main():
    """Run a recognition model over a list of images and save the results.

    Command-line arguments:
        img_root_path: Root directory the listed image paths are relative to.
        img_list: Text file with one sample per line. Each line is split on
            whitespace; the first field is the image path and the optional
            second field is the ground-truth label.
        config: Model config file.
        checkpoint: Model checkpoint file.
        --out_dir: Directory for the log, visualizations and result files.
        --show: Show each visualization instead of saving it.
        --device: Device used for inference.

    For every listed image the model prediction is visualized through
    ``model.show_result``. When a ground-truth label is given and the
    visualization is saved, it is also copied into a ``correct`` or
    ``wrong`` sub-directory. Finally ``results.txt``, ``correct.txt`` and
    ``wrong.txt`` are written, and the OCR metrics are logged if every
    image had a ground-truth label.

    Raises:
        FileNotFoundError: If a listed image does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file lives in the output directory, so the directory has to
    # exist before the logger is created.
    mmcv.mkdir_or_exist(args.out_dir)

    # Init the logger before other steps.
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # Build the model from a config file and a checkpoint file.
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Unwrap the model when it is held by a wrapper exposing `.module`.
        model = model.module

    # Directories for all visualizations and for the correct/wrong splits.
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []

    # Drop blank lines up front: they carry no image path and would
    # otherwise fail on `item_list[0]` below.
    lines = [
        line for line in list_from_file(args.img_list) if line.strip()
    ]
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image.
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten the relative path into one file name so that images from
        # different sub-directories land side by side in the output dirs.
        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With --show, `out_file` was never handed to `show_result`, so
        # there is no saved visualization to copy.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save the per-image results.
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Evaluate only when every image came with a ground-truth label.
    if num_gt_label == len(pred_labels):
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()
|
|