#!/usr/bin/env python3
from data import read_input
from aux import log
import sys
from collections import defaultdict
from evaluate import load as load_metric
# Languages bucketed by (presumably training-data) resource availability.
# Keys are the resource-class labels used in the "classes:" language-pair
# grouping (see _hi_or_lo_lang); values are sets of primary language names,
# i.e. the form produced by _gen_lang (no comma-separated qualifiers).
SMUGRI_RES = {
    'high': set("Estonian,English,Russian,Finnish,Hungarian,Latvian,German,Swedish,Norwegian,French".split(",")),
    'mid': set("Komi,Komi-Zyrian,Northern Sami,Meadow Mari".split(",")),
    'low': set("Udmurt,Proper Karelian,Southern Sami,Livvi,Veps,Moksha,Erzya,Lule Sami,Võro,Hill Mari,"
               "Komi-Permyak,Inari Sami".split(",")),
    'xlow': set("Ludian,Livonian,Izhorian,Votic,Shur Khanty,Skolt Sami,Meänkieli,"
                "Sred Khanty,Surgut Khanty,Priur Khanty,Vakh Khanty,Unk Khanty,"
                "Pite Sami,Mansi,Kazym Khanty,Kven,Ume Sami,Kildin Sami".split(","))
}
def _gen_lang(lang):
return lang.split(",")[0]
def _hi_or_lo_lang(lang):
    """Map *lang* to its resource class ('high'/'mid'/'low'/'xlow').

    Falls back to '?' (and logs the miss) when the primary language name
    appears in none of the SMUGRI_RES sets.
    """
    base = _gen_lang(lang)
    for cls_label, members in SMUGRI_RES.items():
        if base in members:
            return cls_label
    log(f"Unrecognized language: {lang} / {base}")
    return '?'
def _collect_lp_pairs(json_inputs, str_outputs):
    """Group (hypothesis, reference) pairs by language pair at three granularities.

    Each input/output pair is filed under a detailed key (raw src/tgt lang),
    a general key (primary language names only) and a resource-class key.
    Returns a defaultdict mapping each key to its list of (hyp, ref) tuples.
    """
    sets_by_lp = defaultdict(list)
    for inp, hyp in zip(json_inputs, str_outputs):
        ref = inp["tgt_segm"]
        src, tgt = inp["src_lang"], inp["tgt_lang"]
        group_keys = (
            'detailed: ' + src + " -> " + tgt,
            'general: ' + _gen_lang(src) + " -> " + _gen_lang(tgt),
            'classes: ' + _hi_or_lo_lang(src) + " -> " + _hi_or_lo_lang(tgt),
        )
        for key in group_keys:
            sets_by_lp[key].append((hyp, ref))
    return sets_by_lp
def compute_metrics(json_inputs, str_outputs):
    """Compute the chrF score for every language-pair grouping.

    Args:
        json_inputs: iterable of input dicts (with src/tgt language and
            reference segment fields, as consumed by _collect_lp_pairs).
        str_outputs: iterable of hypothesis strings, aligned with json_inputs.

    Returns:
        List of (lp_label, metric_result, num_pairs) tuples, where
        metric_result is the dict returned by evaluate's chrf metric
        (contains at least a 'score' entry).
    """
    sets_by_lp = _collect_lp_pairs(json_inputs, str_outputs)
    metric = load_metric("chrf")
    result = []
    # Iterate items() directly instead of re-indexing the dict per key,
    # and call the second halves of the pairs what they are: references.
    for lp, pairs in sets_by_lp.items():
        preds, refs = zip(*pairs)
        metric_value = metric.compute(predictions=preds, references=refs)
        result.append((lp, metric_value, len(preds)))
    return result
def avoid_global_scope():
    """Script entry point: score hypotheses against references and print results.

    Usage: <script> <inputs.json> <outputs.json>
    argv[1] holds the JSON inputs (with references), argv[2] the JSON hypotheses.
    Prints one "lp: score (size)" line per language-pair grouping.
    """
    # Fail with a usage message rather than a bare IndexError on missing args.
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <inputs.json> <outputs.json>")
    json_inputs = read_input(sys.argv[1], "json")
    str_outputs = read_input(sys.argv[2], "json")
    # compute_metrics returns a list of (lp, metric, size) tuples — the old
    # name ("..._dict") misstated the type.
    scored_lps = compute_metrics(json_inputs, str_outputs)
    for lp, metric, size in scored_lps:
        print(f"{lp}: {metric['score']:.2f} ({size})")
# Removed the trailing " |" scrape artifact after the call — it was a syntax error.
if __name__ == "__main__":
    avoid_global_scope()