Spaces:
Build error
Build error
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import numpy as np | |
| import os | |
| import sys | |
| __dir__ = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(__dir__) | |
| sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) | |
| os.environ["FLAGS_allocator_strategy"] = 'auto_growth' | |
| import cv2 | |
| import json | |
| import paddle | |
| import paddle.distributed as dist | |
| from ppocr.data import create_operators, transform | |
| from ppocr.modeling.architectures import build_model | |
| from ppocr.postprocess import build_post_process | |
| from ppocr.utils.save_load import load_model | |
| from ppocr.utils.visual import draw_re_results | |
| from ppocr.utils.logging import get_logger | |
| from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict | |
| from tools.program import ArgsParser, load_config, merge_config | |
| from tools.infer_kie_token_ser import SerPredictor | |
class ReArgsParser(ArgsParser):
    """Argument parser for the SER+RE inference script.

    Registers two options in addition to whatever the base ``ArgsParser``
    defines:

    * ``-c_ser`` / ``--config_ser``: path to the SER configuration file.
    * ``-o_ser`` / ``--opt_ser``: one or more SER configuration overrides.
    """

    def __init__(self):
        super().__init__()
        self.add_argument(
            "-c_ser",
            "--config_ser",
            help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options ")

    def parse_args(self, argv=None):
        """Parse ``argv`` and post-process the SER-specific options.

        Args:
            argv: Argument list forwarded to the base parser; ``None`` is
                passed through unchanged.

        Returns:
            The namespace produced by the base parser, with ``opt_ser``
            replaced by the result of ``self._parse_opt``.

        Raises:
            AssertionError: If ``--config_ser`` was not supplied.
        """
        parsed = super().parse_args(argv)
        assert parsed.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        # Hand the raw option list to the inherited _parse_opt helper.
        parsed.opt_ser = self._parse_opt(parsed.opt_ser)
        return parsed
def make_input(ser_inputs, ser_results):
    """Assemble the RE model input from the SER inputs and SER predictions.

    Only the first sample is inspected (``ser_inputs[8][0]`` and
    ``ser_results[0]``); the entity and relation arrays built from it are
    then repeated ``batch_size`` times along a new leading axis.

    Args:
        ser_inputs: List of SER inputs. ``ser_inputs[0].shape[:2]`` supplies
            ``(batch_size, max_seq_len)`` and ``ser_inputs[8][0]`` is a
            sequence of dicts carrying ``'start'`` and ``'end'``.
        ser_results: SER predictions; ``ser_results[0]`` is a sequence of
            dicts with a ``'pred'`` key, aligned with the entity sequence.

    Returns:
        A tuple ``(re_inputs, entity_idx_dict_batch)``:

        * ``re_inputs`` is ``ser_inputs[:5] + [entities, relations]``.
          ``entities`` has shape ``(batch_size, max_seq_len + 1, 3)`` with
          columns start / end / label id; ``relations`` has shape
          ``(batch_size, num_pairs + 1, 2)`` with columns head / tail.
          In both arrays row 0 stores the element count and unused cells
          are ``-1``.
        * ``entity_idx_dict_batch`` is a list of length ``batch_size`` whose
          items all reference the same dict mapping the index of a kept
          entity to its index in the original SER result list.

    Raises:
        AssertionError: If the entity and prediction sequences differ in
            length.
        KeyError: If a prediction other than ``'O'`` is not one of
            ``HEADER`` / ``QUESTION`` / ``ANSWER``.
    """
    label_to_id = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    sample_entities = ser_inputs[8][0]
    sample_results = ser_results[0]
    assert len(sample_entities) == len(sample_results)

    # Keep every entity whose prediction is not the 'O' tag.
    starts, ends, labels = [], [], []
    entity_idx_dict = {}
    for src_idx, (res, ent) in enumerate(zip(sample_results, sample_entities)):
        pred = res['pred']
        if pred == 'O':
            continue
        entity_idx_dict[len(starts)] = src_idx
        starts.append(ent['start'])
        ends.append(ent['end'])
        labels.append(label_to_id[pred])

    # Row 0 holds the entity count in every column; rows 1..n hold the data.
    num_entities = len(labels)
    entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    entities[0, :] = num_entities
    for col, values in enumerate((starts, ends, labels)):
        entities[1:num_entities + 1, col] = values

    # Candidate relations: every (QUESTION, ANSWER) pair, question-major.
    question_ids = [i for i, lab in enumerate(labels) if lab == 1]
    answer_ids = [j for j, lab in enumerate(labels) if lab == 2]
    head = [q for q in question_ids for _ in answer_ids]
    tail = [a for _ in question_ids for a in answer_ids]

    num_relations = len(head)
    relations = np.full([num_relations + 1, 2], fill_value=-1, dtype=np.int64)
    relations[0, :] = num_relations
    relations[1:, 0] = head
    relations[1:, 1] = tail

    # Replicate the single-sample arrays across the batch dimension.
    entities = np.repeat(entities[np.newaxis], batch_size, axis=0)
    relations = np.repeat(relations[np.newaxis], batch_size, axis=0)

    # Match the container type of the SER inputs.
    if isinstance(ser_inputs[0], paddle.Tensor):
        entities = paddle.to_tensor(entities)
        relations = paddle.to_tensor(relations)

    # remove ocr_info segment_offset_id and label in ser input
    re_inputs = ser_inputs[:5] + [entities, relations]

    # Every batch slot refers to the same mapping object.
    entity_idx_dict_batch = [entity_idx_dict for _ in range(batch_size)]
    return re_inputs, entity_idx_dict_batch
class SerRePredictor(object):
    """Runs SER followed by RE on a single input.

    The SER stage is delegated to ``SerPredictor``; its outputs are turned
    into RE inputs by ``make_input`` and fed to the RE model built from
    ``config``.
    """

    def __init__(self, config, ser_config):
        """Build the SER engine, the RE post-processor and the RE model.

        Args:
            config: RE configuration; the ``'Global'``, ``'PostProcess'``
                and ``'Architecture'`` sections are read here.
            ser_config: SER configuration. If the RE ``'Global'`` section
                contains ``"infer_mode"``, that value is written into
                ``ser_config["Global"]`` before the SER engine is built.
        """
        global_config = config['Global']
        # Keep the SER stage in the same infer_mode as the RE stage.
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
        self.ser_engine = SerPredictor(ser_config)

        # RE post-processing.
        self.post_process_class = build_post_process(
            config['PostProcess'], global_config)

        # RE model: build, load weights, switch to eval mode.
        arch_config = config['Architecture']
        self.model = build_model(arch_config)
        load_model(config, self.model, model_type=arch_config["model_type"])
        self.model.eval()

    def __call__(self, data):
        """Run SER then RE on ``data`` and return the post-processed result.

        Args:
            data: Input passed unchanged to the SER engine.

        Returns:
            Whatever the RE post-process object returns for the model
            predictions.
        """
        ser_results, ser_inputs = self.ser_engine(data)
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        # Without a visual backbone the element at index 4 is dropped
        # (presumably the image tensor -- confirm against SerPredictor).
        if self.model.backbone.use_visual_backbone is False:
            re_input.pop(4)
        preds = self.model(re_input)
        return self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
def preprocess():
    """Parse CLI arguments, load both configs and select the device.

    Returns:
        A tuple ``(config, ser_config, device, logger)`` where ``config`` is
        the merged RE configuration, ``ser_config`` the merged SER
        configuration, ``device`` the value returned by
        ``paddle.set_device`` and ``logger`` the logger from
        ``get_logger``.
    """
    args = ReArgsParser().parse_args()
    config = merge_config(load_config(args.config), args.opt)
    ser_config = merge_config(load_config(args.config_ser), args.opt_ser)

    logger = get_logger()

    # Pick the GPU of the current parallel env when use_gpu is set.
    if config['Global']['use_gpu']:
        device_name = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
    else:
        device_name = 'cpu'
    device = paddle.set_device(device_name)

    banner = '*' * 10
    logger.info('{} re config {}'.format(banner, banner))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format(banner, banner))
    print_dict(ser_config, logger)
    logger.info('train with paddle {} and device {}'.format(
        paddle.__version__, device))
    return config, ser_config, device, logger
if __name__ == '__main__':
    # Script entry point: run SER+RE over every input, write one JSON line
    # per image to infer_results.txt and save a visualisation next to it.
    config, ser_config, device, logger = preprocess()
    global_cfg = config['Global']
    os.makedirs(global_cfg['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    # infer_mode explicitly False: infer_img is a tab-separated label file
    # whose lines are "<relative image path>\t<label>". Otherwise infer_img
    # is handed to get_image_file_list.
    use_label_file = global_cfg.get("infer_mode", None) is False
    if use_label_file:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(global_cfg['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(global_cfg['infer_img'])

    save_dir = global_cfg['save_res_path']
    results_path = os.path.join(save_dir, "infer_results.txt")
    total = len(infer_imgs)

    with open(results_path, "w", encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if use_label_file:
                fields = info.decode('utf-8').strip("\n").split("\t")
                img_path = os.path.join(data_dir, fields[0])
                data = {'img_path': img_path, 'label': fields[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            stem = os.path.splitext(os.path.basename(img_path))[0]
            save_img_path = os.path.join(save_dir, stem + "_ser_re.jpg")

            # The engine returns a batch; only the first item is used.
            result = ser_re_engine(data)[0]
            serialized = json.dumps(result, ensure_ascii=False)
            fout.write(img_path + "\t" + serialized + "\n")

            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, total, save_img_path))