|
from argparse import ArgumentParser |
|
from typing import * |
|
import json |
|
import logging |
|
|
|
from sftp import SpanPredictor |
|
|
|
# Module-level logger; the name is a fixed string rather than ``__name__``.
logger = logging.getLogger("ConcretePredictor")
|
|
|
|
|
def read_kairos(ontology_mapping_path: Optional[str] = None) -> Optional[Dict[str, str]]:
    """Build a FrameNet-label -> KAIROS-label lookup from an ontology JSON file.

    The JSON file is expected to be an object keyed by KAIROS label, where each
    value holds a ``"framenet"`` list whose items carry a ``"label"`` entry.

    Args:
        ontology_mapping_path: Path to the ontology mapping JSON file. When
            ``None``, no mapping is loaded.

    Returns:
        A dict mapping each FrameNet label to its KAIROS label, or ``None``
        when no path was given. If a FrameNet label occurs more than once, a
        warning is logged and the last occurrence wins.
    """
    if ontology_mapping_path is None:
        return None

    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 (the standard JSON encoding) instead of relying on
    # the platform's locale default.
    with open(ontology_mapping_path, encoding='utf-8') as mapping_file:
        raw = json.load(mapping_file)

    fn2kairos = dict()
    for kairos_label, entry in raw.items():
        for fn in entry['framenet']:
            fn_label = fn['label']
            if fn_label in fn2kairos:
                logger.warning(f'"{fn_label}" is repeated in the ontology file.')
            # Later occurrences overwrite earlier ones.
            fn2kairos[fn_label] = kairos_label
    return fn2kairos
|
|
|
|
|
def run(src, dst, model_path, ontology_mapping_path, device):
    """Load a span predictor and run its Concrete prediction from src to dst.

    Args:
        src: Input location handed to ``predict_concrete``.
        dst: Output location handed to ``predict_concrete``.
        model_path: Path passed to ``SpanPredictor.from_path``.
        ontology_mapping_path: Path passed to
            ``SpanPredictor.read_ontology_mapping``; may be ``None``.
        device: Forwarded as ``cuda_device`` when loading the model
            (presumably -1 selects CPU -- confirm against SpanPredictor).
    """
    # The ontology mapping is read first, before the model is loaded.
    ontology_mapping = SpanPredictor.read_ontology_mapping(ontology_mapping_path)
    span_predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    span_predictor.predict_concrete(src, dst, ontology_mapping=ontology_mapping)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: three required positionals followed by the
    # optional ontology mapping path and device index.
    cli = ArgumentParser()
    for positional in ('src', 'dst', 'model'):
        cli.add_argument(positional, type=str)
    cli.add_argument('--map', type=str, default=None)
    cli.add_argument('--device', type=int, default=-1)
    options = cli.parse_args()
    run(options.src, options.dst, options.model, options.map, options.device)
|
|