#!/usr/bin/env python3
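"""Compute chrF scores for translation outputs, broken down by language pair.

Reads two JSON files given on the command line: the inputs (each record
carrying "src_lang", "tgt_lang" and the reference segment "tgt_segm") and
the corresponding model outputs, one translated string per input. Scores
are reported at three granularities: detailed (full language tags), general
(base language only) and resource classes (high/mid/low/xlow).
"""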

import sys
from collections import defaultdict

from evaluate import load as load_metric

from aux import log
from data import read_input

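# Languages grouped by (approximate) resource level, from high-resource
# down to extremely low-resource ("xlow").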
SMUGRI_RES = {
    'high': set("Estonian,English,Russian,Finnish,Hungarian,Latvian,German,Swedish,Norwegian,French".split(",")),
    'mid': set("Komi,Komi-Zyrian,Northern Sami,Meadow Mari".split(",")),
    'low': set("Udmurt,Proper Karelian,Southern Sami,Livvi,Veps,Moksha,Erzya,Lule Sami,Võro,Hill Mari,"
               "Komi-Permyak,Inari Sami".split(",")),
    'xlow': set("Ludian,Livonian,Izhorian,Votic,Shur Khanty,Skolt Sami,Meänkieli,"
                "Sred Khanty,Surgut Khanty,Priur Khanty,Vakh Khanty,Unk Khanty,"
                "Pite Sami,Mansi,Kazym Khanty,Kven,Ume Sami,Kildin Sami".split(","))
}


def _gen_lang(lang):
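    """Keep only the general language name; tags appear to be comma-separated
    with the base language first (an assumption based on this split)."""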
    return lang.split(",")[0]


def _hi_or_lo_lang(lang):
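    """Map a language tag to its resource-class key in SMUGRI_RES ('?' if unknown)."""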
    gen_lang = _gen_lang(lang)

    for k, v in SMUGRI_RES.items():
        if gen_lang in v:
            return k

    log(f"Unrecognized language: {lang} / {gen_lang}")
    return '?'


def _collect_lp_pairs(json_inputs, str_outputs):
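    """Group (hypothesis, reference) pairs by language pair at three
    granularities: detailed tags, general languages and resource classes."""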
    sets_by_lp = defaultdict(list)

    for i, o in zip(json_inputs, str_outputs):
        ref = i["tgt_segm"]
        hyp = o
        det_lp = 'detailed: ' + i["src_lang"] + " -> " + i["tgt_lang"]
        gen_lp = 'general: ' + _gen_lang(i["src_lang"]) + " -> " + _gen_lang(i["tgt_lang"])
        hilo_lp = 'classes: ' + _hi_or_lo_lang(i["src_lang"]) + " -> " + _hi_or_lo_lang(i["tgt_lang"])

        sets_by_lp[det_lp].append((hyp, ref))
        sets_by_lp[gen_lp].append((hyp, ref))
        sets_by_lp[hilo_lp].append((hyp, ref))

    return sets_by_lp


def compute_metrics(json_inputs, str_outputs):
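    """Return (language-pair label, chrF result dict, number of pairs) tuples."""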
    sets_by_lp = _collect_lp_pairs(json_inputs, str_outputs)

    metric = load_metric("chrf")

    result = []

    for lp, pairs in sets_by_lp.items():
        hyps, refs = zip(*pairs)
        metric_value = metric.compute(predictions=hyps, references=refs)

        result.append((lp, metric_value, len(hyps)))

    return result


def avoid_global_scope():
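    """Entry point: read inputs and outputs from the CLI arguments and print
    the chrF score and pair count for every language-pair grouping."""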
    json_inputs = read_input(sys.argv[1], "json")
    str_outputs = read_input(sys.argv[2], "json")

    lp_metrics = compute_metrics(json_inputs, str_outputs)

    for lp, metric, size in lp_metrics:
        print(f"{lp}: {metric['score']:.2f} ({size})")


if __name__ == "__main__":
    avoid_global_scope()