"""
Inspect frame-to-frame similarity learned by a trained SFTP span-typing model.

Loads a SpanPredictor archive, computes pairwise cosine similarity between frame
labels from (a) the label embedding matrix and (b) the last span-typing MLP
layer, joins the result with FrameNet 1.7 definitions and frame relations, and
writes JSON/TSV sheets, including one aligned to a KAIROS event-to-FrameNet
gold mapping.
"""
import json
import os
from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy

import nltk
import torch
from torch import nn

from sftp import SpanPredictor


def shift_grid_cos_sim(mat: torch.Tensor):
    # Pairwise cosine similarity between all rows of `mat`, shifted from [-1, 1]
    # into [0, 1] so the scores can be read as similarities.
    mat1 = mat.unsqueeze(0).expand(mat.shape[0], -1, -1)
    mat2 = mat.unsqueeze(1).expand(-1, mat.shape[0], -1)
    cos = nn.CosineSimilarity(2)
    sim = (cos(mat1, mat2) + 1) / 2
    return sim


def all_frames():
    # Make sure the FrameNet 1.7 corpus is available locally, then return all frames.
    nltk.download('framenet_v17')
    fn = nltk.corpus.framenet
    return fn.frames()


def extract_relations(fr):
    # Collect (related frame name, relation type) pairs for `fr`, skipping the
    # frame itself and de-duplicating frames that were already collected.
    ret = list()
    added = {fr.name}
    for rel in fr.frameRelations:
        for key in ['subFrameName', 'superFrameName']:
            rel_fr_name = rel[key]
            if rel_fr_name in added:
                continue
            added.add(rel_fr_name)
            ret.append((rel_fr_name, key[:-4]))
    return ret


def run():
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    with open(args.kairos) as fp:
        kairos_gold_mapping = json.load(fp)

    # Frame-label embeddings and the index -> label mapping from the model vocabulary.
    label_emb = predictor._model._span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

    emb_sim = shift_grid_cos_sim(label_emb)
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in all_frames()}

    # Rows of the last span-typing MLP layer serve as a second label representation.
    last_mlp = predictor._model._span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)

    def rank_frame(sim):
        # For every frame label, rank all other labels by similarity and attach
        # FrameNet metadata (relations, URL, definition).
        rank = sim.argsort(1, True)
        scores = sim.gather(1, rank)
        mapping = {
            fr.name: {
                'similarity': list(),
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition
            } for fr in all_frames()
        }
        for left_idx, (right_indices, match_scores) in enumerate(zip(rank, scores)):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                continue
            for right_idx, s in zip(right_indices, match_scores):
                right_label = idx2label[int(right_idx)]
                if right_label not in mapping or right_idx == left_idx:
                    continue
                mapping[left_label]['similarity'].append((right_label, float(s)))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)

    def dump(mapping, folder_path):
        os.makedirs(folder_path, exist_ok=True)
        with open(os.path.join(folder_path, 'raw.json'), 'w') as fp:
            json.dump(mapping, fp)
        sim_lines, onto_lines = list(), list()

        for fr, values in mapping.items():
            sim_line = [
                fr,
                values['definition'],
                values['URL'],
            ]
            onto_line = deepcopy(sim_line)
            for rel_fr_name, rel_type in values['ontology']:
                onto_line.append(f'{rel_fr_name} ({rel_type})')
            onto_lines.append('\t'.join(onto_line))
            if len(values['similarity']) > 0:
                for sim_fr_name, score in values['similarity'][:args.topk]:
                    sim_line.append(f'{sim_fr_name} ({score:.3f})')
                sim_lines.append('\t'.join(sim_line))

        with open(os.path.join(folder_path, 'similarity.tsv'), 'w') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w') as fp:
            fp.write('\n'.join(onto_lines))

        # One block per KAIROS event: its gold FrameNet frame(s), each followed by
        # the top-k most similar frames under the current mapping.
        kairos_dump = list()
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            for gold_fr in kairos_content['framenet']:
                gold_fr = gold_fr['label']
                if gold_fr not in fr2definition:
                    continue
                kairos_dump.append([
                    'GOLD',
                    gold_fr,
                    kairos_event,
                    fr2definition[gold_fr][0],
                    fr2definition[gold_fr][1],
                    str(kairos_content['description']),
                    '1.00'
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    kairos_dump.append([
                        '',
                        ass_fr,
                        kairos_event,
                        fr2definition[ass_fr][0],
                        fr2definition[ass_fr][1],
                        str(kairos_content['description']),
                        f'{sim_score:.2f}'
                    ])
        kairos_dump = list(map(lambda line: '\t'.join(line), kairos_dump))
        with open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w') as fp:
            fp.write('\n'.join(kairos_dump))

    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))


if __name__ == '__main__':
    run()
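# Example invocation (illustrative only; argument names follow the parser's
# metavars above, and the script filename and paths are placeholders):
#   python this_script.py ARCHIVE_PATH DESTINATION KAIROS --topk 10
# Two sub-folders, mlp/ and emb/, are written under DESTINATION, each containing
# raw.json, similarity.tsv, ontology.tsv and kairos_sheet.tsv.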