from typing import *

import argparse
import json
import os

import torch
from tqdm import tqdm

from sftp.predictor import SpanPredictor
from sftp.models import SpanModel
from sftp.data_reader import BetterDatasetReader


def predict_doc(predictor, json_path: str) -> dict:
    """Run span prediction over every entry of a JSON document and annotate it in place.

    Loads the JSON file at ``json_path`` (expected to contain an ``entries``
    mapping of document name -> entry dict — TODO confirm exact schema against
    the data producer), feeds each entry to ``predictor.predict_json``, and
    attaches the predicted trigger/argument spans to the entry under the key
    ``'trigger span'``.

    :param predictor: a loaded ``SpanPredictor`` (anything with ``predict_json``).
    :param json_path: path to the source JSON document.
    :return: the (mutated) top-level JSON object, ready to be dumped back out.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original ``json.load(open(...))`` leaked it).
    with open(json_path) as src_file:
        src = json.load(src_file)

    for doc_name, entry in tqdm(list(src['entries'].items())):
        pred = predictor.predict_json(entry)
        triggers = []
        for trigger in pred['prediction']:
            # Each argument (child) span is recorded as a [start, end] pair.
            children = [
                [child['start_idx'], child['end_idx']]
                for child in trigger['children']
            ]
            triggers.append({
                "span": [trigger['start_idx'], trigger['end_idx']],
                "argument": children,
            })
        # Annotate the entry in place; the whole document is returned below.
        entry['trigger span'] = triggers

    return src


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', type=str, help='archive path')
    parser.add_argument('-s', type=str, help='source path')
    parser.add_argument('-d', type=str, help='destination path')
    parser.add_argument('-c', type=int, default=0, help='cuda device')
    args = parser.parse_args()

    # Load the trained model from <archive>/model.tar.gz on the requested device.
    predictor_ = SpanPredictor.from_path(
        os.path.join(args.a, 'model.tar.gz'), 'span', cuda_device=args.c
    )

    # Outputs are grouped under <destination>/<archive basename>/.
    model_name = os.path.basename(args.a)
    tgt_path = os.path.join(args.d, model_name)
    os.makedirs(tgt_path, exist_ok=True)

    for root, _, files in os.walk(args.s):
        for fn in files:
            # Only process files ending in 'json' or 'valid'
            # (same filter as before, written as a single tuple-suffix check).
            if not fn.endswith(('json', 'valid')):
                continue
            processed_json = predict_doc(predictor_, os.path.join(root, fn))
            with open(os.path.join(tgt_path, fn), 'w') as fp:
                json.dump(processed_json, fp)