# sociolome/scripts/fn_eval/frame_id.py
# Gosse Minnema
# Initial commit
# 05922fb
import json
from argparse import ArgumentParser
from collections import defaultdict
import numpy as np
from tqdm import tqdm
from nltk.corpus import framenet as fn
from sftp import SpanPredictor
def _accuracy_pct(n_positive: int, n_total: int) -> float:
    """Return accuracy as a percentage, or 0.0 when nothing has been scored yet."""
    return n_positive / n_total * 100 if n_total else 0.0


def run(model_path: str, data_path: str, device: int, use_ontology: bool = False) -> None:
    """Measure frame-identification accuracy of a span predictor and print it.

    The data file is read as JSON Lines: each non-blank line is one sentence
    object with a ``'tokens'`` entry and an ``'annotations'`` list. Every
    annotation provides a ``'span'`` (its first and last elements are used as
    the span boundaries), a gold ``'label'`` and, when ``use_ontology`` is
    set, an ``'lu'`` key naming its lexical unit.

    For each annotation the model is force-decoded on the annotated span and
    the highest-scoring candidate frame is compared against the gold label.

    Args:
        model_path: Location handed to ``SpanPredictor.from_path``.
        data_path: Path of the JSON Lines evaluation file.
        device: Passed through as ``cuda_device`` when loading the predictor.
        use_ontology: If true, restrict the candidates to the frames FrameNet
            lists for the annotation's lexical unit; otherwise every FrameNet
            frame is a candidate.

    The running accuracy is shown on the progress bar and the final value is
    printed as ``acc=<percentage>``; ``acc=0.000`` is reported when no
    annotation was scored.
    """
    # Close the file deterministically and skip blank lines (e.g. a trailing
    # newline), which would otherwise make json.loads raise.
    with open(data_path, encoding='utf-8') as data_file:
        data = [json.loads(line) for line in data_file if line.strip()]

    # Lexical-unit name -> names of all frames that FrameNet associates with it.
    lu2frame = defaultdict(list)
    for lu in fn.lus():
        lu2frame[lu.name].append(lu.frame.name)

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    all_frames = [fr.name for fr in fn.frames()]

    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                span = point['span']
                # Scores for the forced span, indexed by 'span_label' vocabulary id.
                model_output = predictor.force_decode(
                    sent['tokens'], child_spans=[(span[0], span[-1])]
                ).distribution[0]

                if use_ontology:
                    # .get() keeps unknown lexical units from being inserted
                    # into the defaultdict as a side effect of the lookup.
                    candidate_frames = lu2frame.get(point['lu'], [])
                else:
                    candidate_frames = all_frames

                # Frames missing from the model vocabulary get the sentinel
                # score -1.0 in place of a model score.
                candidate_prob = [
                    model_output[frame2idx[fr]] if fr in frame2idx else -1.0
                    for fr in candidate_frames
                ]

                if candidate_prob:
                    pred_frame = candidate_frames[int(np.argmax(candidate_prob))]
                    if pred_frame == point['label']:
                        n_positive += 1
                # Every annotation counts toward the total; one without any
                # candidate frame is therefore scored as a miss.
                n_total += 1
            # Safe even while n_total is still 0 (sentences without annotations).
            bar.set_description(f'acc={_accuracy_pct(n_positive, n_total):.3f}')
    print(f'acc={_accuracy_pct(n_positive, n_total):.3f}')
if __name__ == '__main__':
    # Command-line entry point: two positional arguments (MODEL and DATA),
    # an integer device option (-d, default -1) and an ontology switch (-o),
    # all forwarded to run().
    cli = ArgumentParser()
    cli.add_argument('model', metavar="MODEL")
    cli.add_argument('data', metavar="DATA")
    cli.add_argument('-d', default=-1, type=int, help='Device')
    cli.add_argument('-o', action='store_true', help='Flag to use ontology.')
    opts = cli.parse_args()
    run(
        model_path=opts.model,
        data_path=opts.data,
        device=opts.d,
        use_ontology=opts.o,
    )