|
import json |
|
from argparse import ArgumentParser |
|
from collections import defaultdict |
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
from nltk.corpus import framenet as fn |
|
|
|
from sftp import SpanPredictor |
|
|
|
|
|
def run(model_path, data_path, device, use_ontology=False):
    """Evaluate frame-classification accuracy of a span model on gold spans.

    For every annotated span in *data_path* (JSON-lines, one sentence per
    line with ``tokens`` and ``annotations`` fields), force-decode the span
    with the model and pick the highest-scoring candidate frame. Prints a
    running and final accuracy percentage.

    Args:
        model_path: Path to a serialized SFTP model archive.
        data_path: Path to a JSON-lines evaluation file.
        device: CUDA device index (``-1`` for CPU).
        use_ontology: If True, restrict candidate frames to those the
            annotation's lexical unit can evoke per FrameNet; otherwise
            consider every FrameNet frame.
    """
    # Use a context manager so the data file is closed deterministically
    # (the original left the handle open for the process lifetime).
    with open(data_path) as data_file:
        data = [json.loads(line) for line in data_file]

    # Map each lexical-unit name to the frames it can evoke.
    lu2frame = defaultdict(list)
    for lu in fn.lus():
        lu2frame[lu.name].append(lu.frame.name)

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    all_frames = [fr.name for fr in fn.frames()]

    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                # Force-decode the gold span; take the label distribution
                # for the single child span we supplied.
                model_output = predictor.force_decode(
                    sent['tokens'], child_spans=[(point['span'][0], point['span'][-1])]
                ).distribution[0]
                if use_ontology:
                    candidate_frames = lu2frame[point['lu']]
                else:
                    candidate_frames = all_frames
                # -1.0 sentinel: frames absent from the model vocabulary
                # can never out-score a real probability.
                candidate_prob = [-1.0 for _ in candidate_frames]
                for idx_can, fr in enumerate(candidate_frames):
                    if fr in frame2idx:
                        candidate_prob[idx_can] = model_output[frame2idx[fr]]
                if candidate_prob:
                    pred_frame = candidate_frames[int(np.argmax(candidate_prob))]
                    if pred_frame == point['label']:
                        n_positive += 1
                # Unknown LUs yield no candidates and count as errors.
                n_total += 1
                bar.set_description(f'acc={n_positive/n_total*100:.3f}')
    # Guard against ZeroDivisionError when the file has no annotations.
    if n_total:
        print(f'acc={n_positive/n_total*100:.3f}')
    else:
        print('acc=N/A (no annotations found)')
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: MODEL and DATA are positional paths;
    # -d selects the CUDA device (-1 = CPU), -o enables the LU ontology filter.
    cli = ArgumentParser()
    cli.add_argument('model', metavar="MODEL")
    cli.add_argument('data', metavar="DATA")
    cli.add_argument('-d', default=-1, type=int, help='Device')
    cli.add_argument('-o', action='store_true', help='Flag to use ontology.')
    parsed = cli.parse_args()
    run(parsed.model, parsed.data, parsed.d, parsed.o)
|
|