|
import os |
|
import argparse |
|
from xml.etree import ElementTree |
|
import copy |
|
from operator import attrgetter |
|
import json |
|
import logging |
|
|
|
from sftp import SpanPredictor |
|
|
|
|
|
def predict_kairos(model_archive, source_folder, onto_map):
    """Run span-finder predictions over every XML file under *source_folder*.

    Walks *source_folder* recursively for ``.xml`` files, loads a
    FrameNet-to-KAIROS label map from the *onto_map* JSON, runs the span
    predictor on each segment's tokens, and maps/filters predicted frames
    to KAIROS event types.

    :param model_archive: Path to the model archive loadable by
        ``SpanPredictor.from_path``.
    :param source_folder: Folder that is recursively scanned for XML files.
    :param onto_map: Path to the ontology JSON; each entry is expected to
        have a ``'framenet'`` list of dicts with a ``'label'`` key.
    :return: List of prediction dicts, each with ``'meta'`` (document and
        segment attributes) and ``'prediction'`` (frames whose labels were
        remapped to KAIROS events; unmapped frames are dropped).
    """
    # Collect all XML files recursively.
    xml_files = []
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'{len(xml_files)} files are found:')
    for fn in xml_files:
        logging.info(' - ' + fn)

    logging.info('Loading ontology from ' + onto_map)
    # Build FrameNet label -> KAIROS event map. Later duplicates overwrite
    # earlier ones (the duplicate is logged).
    k_map = dict()
    # BUGFIX: close the ontology file deterministically instead of leaking
    # the handle from json.load(open(...)).
    with open(onto_map) as onto_fp:
        onto = json.load(onto_fp)
    for kairos_event, content in onto.items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = []

    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            # Assumes the first child of each <DOC> holds the segments —
            # TODO confirm against the LDC XML schema used here.
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # BUGFIX: XML attribute values are strings; sort numerically,
                # otherwise '10' sorts before '2' and tokens come out of order.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta

                # Keep only frames present in the KAIROS ontology, rewriting
                # their FrameNet labels to KAIROS event types.
                new_frames = []
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames

                predictions.append(one_pred)

    logging.info('Finished Prediction.')

    return predictions
|
|
|
|
|
def do_task(input_dir, model_archive, onto_map):
    """Entry point invoked by the KAIROS infrastructure for each TASK1 input.

    Simply forwards to :func:`predict_kairos`, treating *input_dir* as the
    source folder, and returns its list of predictions unchanged.
    """
    return predict_kairos(
        model_archive=model_archive,
        source_folder=input_dir,
        onto_map=onto_map,
    )
|
|
|
|
|
def run():
    """CLI entry point: parse arguments, run prediction, write JSON-lines output.

    Expects four positional arguments (model archive, XML source folder,
    ontology JSON, destination path) and writes one JSON object per line
    to the destination file.
    """
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path. (jsonl file path)')
    args = parser.parse_args()

    logging.basicConfig(level='INFO', format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")

    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)

    logging.info('Saving to ' + args.destination + ' ...')
    # BUGFIX: os.makedirs('') raises FileNotFoundError when the destination is
    # a bare filename in the current directory — only create a directory when
    # the path actually has a directory component.
    dest_dir = os.path.dirname(args.destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    with open(args.destination, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))
        # Terminate the final record so the file is well-formed JSON lines.
        fp.write('\n')

    logging.info('Done.')
|
|
|
|
|
# Allow the module to be executed directly as a script.
if __name__ == '__main__':

    run()
|
|