|
from typing import * |
|
import torch |
|
import json |
|
import argparse |
|
import os |
|
from tqdm import tqdm |
|
|
|
from sftp.predictor import SpanPredictor |
|
from sftp.models import SpanModel |
|
from sftp.data_reader import BetterDatasetReader |
|
|
|
|
|
def predict_doc(predictor, json_path: str):
    """Annotate every entry of a JSON document with predicted trigger spans.

    The file at ``json_path`` is parsed as JSON and each value of its
    top-level ``'entries'`` mapping is passed to ``predictor.predict_json``.
    Every item of the returned ``'prediction'`` list is reduced to its own
    ``[start_idx, end_idx]`` pair plus the pairs of its ``'children'``; the
    resulting list is stored on the entry under the key ``'trigger span'``.

    Args:
        predictor: Object exposing ``predict_json(entry)`` which returns a
            mapping with a ``'prediction'`` list. Each item of that list must
            provide ``'start_idx'``, ``'end_idx'`` and ``'children'`` (whose
            items in turn provide ``'start_idx'`` and ``'end_idx'``).
        json_path: Path of the JSON file to load.

    Returns:
        The parsed document, with every entry updated in place.
    """
    # Close the handle deterministically instead of leaving it to the
    # garbage collector; JSON interchange text is UTF-8.
    with open(json_path, encoding='utf-8') as fp:
        src = json.load(fp)

    # Only the entry values are needed; the view still has a length, so the
    # progress bar keeps its total without copying into a list.
    for entry in tqdm(src['entries'].values()):
        pred = predictor.predict_json(entry)
        entry['trigger span'] = [
            {
                "span": [trigger['start_idx'], trigger['end_idx']],
                "argument": [
                    [child['start_idx'], child['end_idx']]
                    for child in trigger['children']
                ],
            }
            for trigger in pred['prediction']
        ]
    return src
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: load the predictor from an archive directory,
    # run it over every '*json' / '*valid' file found under the source tree,
    # and write the annotated documents to <destination>/<archive name>/.
    parser = argparse.ArgumentParser()
    # The three paths are mandatory; without them os.path.join() below would
    # fail with an opaque TypeError on None.
    parser.add_argument('-a', type=str, required=True, help='archive path')
    parser.add_argument('-s', type=str, required=True, help='source path')
    parser.add_argument('-d', type=str, required=True, help='destination path')
    parser.add_argument('-c', type=int, default=0, help='cuda device')
    args = parser.parse_args()

    predictor_ = SpanPredictor.from_path(
        os.path.join(args.a, 'model.tar.gz'), 'span', cuda_device=args.c
    )

    # Normalise first so a trailing separator ("models/foo/") still yields
    # "foo" rather than an empty name that would dump output into args.d.
    model_name = os.path.basename(os.path.normpath(args.a))
    tgt_path = os.path.join(args.d, model_name)
    os.makedirs(tgt_path, exist_ok=True)

    for root, _, files in os.walk(args.s):
        for fn in files:
            if not fn.endswith(('json', 'valid')):
                continue
            processed_json = predict_doc(predictor_, os.path.join(root, fn))
            # NOTE(review): outputs are flattened into tgt_path, so files
            # sharing a name in different sub-directories overwrite each
            # other. Kept as-is to preserve the existing output layout.
            with open(os.path.join(tgt_path, fn), 'w', encoding='utf-8') as fp:
                json.dump(processed_json, fp)
|
|