|
from sftp import SpanPredictor |
|
|
|
|
|
def print_children(sentence, boundary, labels, _):
    """Print a tokenized sentence followed by each labelled span in it.

    Args:
        sentence: Sequence of token strings.
        boundary: Iterable of ``(first, last)`` token-index pairs; both ends
            are inclusive (the slice taken is ``first .. last + 1``).
        labels: Labels aligned position-by-position with ``boundary``; pairing
            stops at the shorter of the two (``zip`` semantics).
        _: Extra value accepted so a ``force_decode`` result can be splatted
            in directly; it is not used here.
    """
    print('Sentence:', ' '.join(sentence))
    for span, label in zip(boundary, labels):
        first, last = span
        # End index is inclusive, hence the +1 on the slice.
        words = sentence[first:last + 1]
        print(' '.join(words), ':', label)
    print('=' * 20)
|
|
|
|
|
def example(
    model_path="/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
    cuda_device=-1,
):
    """Demonstrate ``SpanPredictor.force_decode`` on one sentence, three ways.

    Runs an unconstrained decode, then one given a parent span and label,
    then one that also fixes the child spans. Each result is printed with
    ``print_children``.

    Args:
        model_path: Model archive passed to ``SpanPredictor.from_path``.
            The default is the original author's hard-coded location;
            override it when running elsewhere.
        cuda_device: Forwarded to ``SpanPredictor.from_path``. The default
            ``-1`` matches the original demo (presumably CPU; confirm against
            the sftp documentation).
    """
    print("Loading predictor...")
    predictor = SpanPredictor.from_path(model_path, cuda_device=cuda_device)

    print("Predicting for sentence..")
    sentence = ['Tom', 'eats', 'an', 'apple', 'and', 'he', 'wakes', 'up', '.']

    # Keyword arguments for each force_decode call, in the original order.
    # Span tuples are inclusive token indices, e.g. (2, 3) -> "an apple".
    decode_requests = (
        {},
        {'parent_span': (1, 1), 'parent_label': 'Ingestion'},
        {
            'child_spans': [(0, 0), (2, 3)],
            'parent_span': (1, 1),
            'parent_label': 'Ingestion',
        },
    )
    for kwargs in decode_requests:
        # force_decode's result is splatted into print_children's
        # (boundary, labels, _) parameters, as in the original calls.
        print_children(sentence, *predictor.force_decode(sentence, **kwargs))
|
|
|
|
|
if __name__ == '__main__':
    # Run the demo only when this file is executed directly, not on import.
    example()
|
|