Spaces:
Running
Running
from itertools import cycle | |
import random | |
import argparse | |
from simpletransformers.seq2seq import Seq2SeqModel | |
import pandas as pd | |
random.seed = 42 | |
def load_conllu_dataset(datafile, join=False): | |
arr = [] | |
with open(datafile, encoding='utf-8') as inp: | |
strings = inp.readlines() | |
for s in strings: | |
if (s[0] != "#" and s.strip()): | |
split_string = s.split('\t') | |
if split_string[1] == "(" or split_string[1] == ")" or split_string[1] == "[" or split_string[1] == "]": | |
form = split_string[1] | |
else: | |
form = split_string[1].replace("(", "").replace(")", "").replace("[", "").replace("]", "") | |
if split_string[3] != "PROPN": | |
form = form.lower() | |
else: | |
form = form.capitalize() | |
lemma = split_string[2] | |
if split_string[3] == "PROPN": | |
lemma = lemma.capitalize() | |
if join: | |
inpt = form + " " + split_string[3] + " " + split_string[5] | |
else: | |
inpt = form | |
pos = split_string[3] | |
arr.append([inpt, lemma, pos]) | |
return pd.DataFrame(arr, columns=["input_text", "target_text", "pos"]) | |
def predict(in_file, out_file, join=False): | |
if join: | |
model_name = "Futyn-Maker/RuthLemm-morphology" | |
else: | |
model_name = "Futyn-Maker/RuthLemm" | |
model = Seq2SeqModel( | |
encoder_decoder_type="bart", | |
encoder_decoder_name=model_name, | |
use_cuda=False | |
) | |
pred_data = load_conllu_dataset(in_file, join=join)["input_text"].tolist() | |
predictions = cycle(model.predict(pred_data)) | |
with open(in_file, encoding="utf8") as inp: | |
strings = inp.readlines() | |
predicted = [] | |
for s in strings: | |
if (s[0] != "#" and s.strip()): | |
split_string = s.split("\t") | |
split_string[2] = next(predictions) | |
joined_string = "\t".join(split_string) | |
predicted.append(joined_string) | |
continue | |
predicted.append(s) | |
with open(out_file, "w", encoding="utf8") as out: | |
out.write("".join(predicted)) | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument("input_file", type=str, help="Path to the input file") | |
parser.add_argument("output_file", type=str, help="Path to the output file") | |
parser.add_argument("--morphology", "-m", action="store_true", help="Use morphology") | |
args = parser.parse_args() | |
predict(args.input_file, args.output_file, args.morphology) | |
print("All done!") | |