File size: 2,596 Bytes
bc25cf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from itertools import cycle
import random
import argparse

from simpletransformers.seq2seq import Seq2SeqModel
import pandas as pd


# Seed the module-level RNG so that any use of `random` is reproducible.
# This has to be a call: assigning to `random.seed` would replace the function
# with an int, seed nothing, and break every later `random.seed(...)` call.
random.seed(42)

def load_conllu_dataset(datafile, join=False):
    """Read the token rows of a CoNLL-U file into a DataFrame.

    Lines whose first character is ``#`` and whitespace-only lines are
    skipped. Every remaining line is split on tabs and read positionally:
    index 1 is the word form, index 2 the lemma, index 3 the POS tag and
    index 5 the field appended to the input when ``join`` is true.

    The form is normalised before use: brackets ``()[]`` are removed unless
    the form is exactly one bracket character, then the form is capitalised
    for ``PROPN`` tokens and lower-cased for everything else. Lemmas of
    ``PROPN`` tokens are capitalised as well.

    Args:
        datafile: Path to the UTF-8 encoded input file.
        join: When true, ``input_text`` is ``"<form> <pos> <field 5>"``
            instead of the bare form.

    Returns:
        A ``pandas.DataFrame`` with the columns ``input_text``,
        ``target_text`` (the lemma) and ``pos``, one row per token line.
    """
    lone_brackets = ("(", ")", "[", "]")
    drop_brackets = str.maketrans("", "", "()[]")

    with open(datafile, encoding='utf-8') as handle:
        lines = handle.readlines()

    rows = []
    for line in lines:
        # Skip comment lines and blank separator lines.
        if line[0] == "#" or not line.strip():
            continue

        fields = line.split('\t')
        form, lemma, pos = fields[1], fields[2], fields[3]

        # A token that is itself a single bracket is kept as is; any other
        # form has all of its brackets removed.
        if form not in lone_brackets:
            form = form.translate(drop_brackets)

        if pos == "PROPN":
            form = form.capitalize()
            lemma = lemma.capitalize()
        else:
            form = form.lower()

        text = f"{form} {pos} {fields[5]}" if join else form
        rows.append([text, lemma, pos])

    return pd.DataFrame(rows, columns=["input_text", "target_text", "pos"])

def predict(in_file, out_file, join=False):
    """Lemmatize a CoNLL-U file and write a copy with predicted lemmas.

    A BART ``Seq2SeqModel`` is created on CPU (``use_cuda=False``), the
    inputs are built with ``load_conllu_dataset`` and passed to the model in
    one ``predict`` call. The input file is then re-read; on every token line
    (not starting with ``#`` and not blank) the tab-separated field at index 2
    is replaced by the next prediction. All other lines are copied unchanged.

    Args:
        in_file: Path to the UTF-8 encoded input file.
        out_file: Path the result is written to (overwritten if it exists).
        join: When true, the ``Futyn-Maker/RuthLemm-morphology`` model is used
            and inputs are built with ``join=True``; otherwise
            ``Futyn-Maker/RuthLemm`` is used with bare forms.
    """
    model_name = (
        "Futyn-Maker/RuthLemm-morphology" if join else "Futyn-Maker/RuthLemm"
    )
    lemmatizer = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name=model_name,
        use_cuda=False,
    )

    inputs = load_conllu_dataset(in_file, join=join)["input_text"].tolist()
    # cycle() makes the prediction stream wrap around instead of running out.
    lemmas = cycle(lemmatizer.predict(inputs))

    with open(in_file, encoding="utf8") as src:
        lines = src.readlines()

    out_lines = []
    for line in lines:
        if line[0] == "#" or not line.strip():
            # Comment or blank line: copy through untouched.
            out_lines.append(line)
        else:
            columns = line.split("\t")
            columns[2] = next(lemmas)
            out_lines.append("\t".join(columns))

    with open(out_file, "w", encoding="utf8") as dst:
        dst.writelines(out_lines)

if __name__ == "__main__":
    # Command-line entry point: two positional paths plus an optional flag
    # that switches to the morphology-aware model.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "input_file",
        type=str,
        help="Path to the input file",
    )
    cli.add_argument(
        "output_file",
        type=str,
        help="Path to the output file",
    )
    cli.add_argument(
        "--morphology",
        "-m",
        action="store_true",
        help="Use morphology",
    )

    options = cli.parse_args()
    predict(options.input_file, options.output_file, options.morphology)
    print("All done!")