Rasmus Lellep
add loader
76b1ec5
#!/usr/bin/env python3
from data import read_input
from aux import log
import sys
from collections import defaultdict
from evaluate import load as load_metric
# Language names grouped into four tiers keyed 'high', 'mid', 'low' and 'xlow'
# (presumably resource-availability classes -- TODO confirm with the data owners).
# Each tier is written as a single comma-separated string and split into a set
# of individual language names.
SMUGRI_RES = {
    tier: set(names.split(","))
    for tier, names in (
        ('high', "Estonian,English,Russian,Finnish,Hungarian,Latvian,German,Swedish,Norwegian,French"),
        ('mid', "Komi,Komi-Zyrian,Northern Sami,Meadow Mari"),
        ('low', "Udmurt,Proper Karelian,Southern Sami,Livvi,Veps,Moksha,Erzya,Lule Sami,Võro,Hill Mari,"
                "Komi-Permyak,Inari Sami"),
        ('xlow', "Ludian,Livonian,Izhorian,Votic,Shur Khanty,Skolt Sami,Meänkieli,"
                 "Sred Khanty,Surgut Khanty,Priur Khanty,Vakh Khanty,Unk Khanty,"
                 "Pite Sami,Mansi,Kazym Khanty,Kven,Ume Sami,Kildin Sami"),
    )
}
def _gen_lang(lang):
return lang.split(",")[0]
def _hi_or_lo_lang(lang):
    """Return the ``SMUGRI_RES`` key whose set contains the language.

    The lookup uses ``_gen_lang(lang)``, i.e. the text before the first
    comma. When no set contains it, a message is logged and ``'?'`` is
    returned instead.
    """
    base = _gen_lang(lang)

    # First tier (in dict order) that lists the language; None if there is none.
    tier = next(
        (name for name, members in SMUGRI_RES.items() if base in members),
        None,
    )
    if tier is None:
        log(f"Unrecognized language: {lang} / {base}")
        return '?'
    return tier
def _collect_lp_pairs(json_inputs, str_outputs):
    """Group (hypothesis, reference) pairs under three language-pair labels.

    ``json_inputs`` and ``str_outputs`` are walked in lockstep. Each input
    item is indexed with the keys ``"tgt_segm"``, ``"src_lang"`` and
    ``"tgt_lang"``; the matching output is used as the hypothesis.

    Every pair is filed under three labels:
      * ``detailed: <src> -> <tgt>`` with the language values as given,
      * ``general: <src> -> <tgt>`` with each value reduced by ``_gen_lang``,
      * ``classes: <src> -> <tgt>`` with each value mapped by ``_hi_or_lo_lang``.

    Returns a ``defaultdict(list)`` mapping label -> list of
    ``(hypothesis, reference)`` tuples.
    """
    grouped = defaultdict(list)

    for entry, hypothesis in zip(json_inputs, str_outputs):
        pair = (hypothesis, entry["tgt_segm"])
        src = entry["src_lang"]
        tgt = entry["tgt_lang"]

        # All three labels are built before anything is appended.
        labels = (
            'detailed: ' + src + " -> " + tgt,
            'general: ' + _gen_lang(src) + " -> " + _gen_lang(tgt),
            'classes: ' + _hi_or_lo_lang(src) + " -> " + _hi_or_lo_lang(tgt),
        )
        for label in labels:
            grouped[label].append(pair)

    return grouped
def compute_metrics(json_inputs, str_outputs):
    """Score every language-pair group with the ``"chrf"`` metric.

    The pairs are grouped by ``_collect_lp_pairs`` and the metric is loaded
    once with ``load_metric("chrf")``.

    Returns a list of ``(label, metric_result, pair_count)`` tuples, one per
    group, in the order the groups were first seen. ``metric_result`` is
    whatever ``metric.compute`` returns for that group.
    """
    grouped = _collect_lp_pairs(json_inputs, str_outputs)
    chrf = load_metric("chrf")

    scores = []
    for label, pairs in grouped.items():
        # Turn the list of (hypothesis, reference) tuples into two parallel tuples.
        hypotheses, references = zip(*pairs)
        score = chrf.compute(predictions=hypotheses, references=references)
        scores.append((label, score, len(hypotheses)))
    return scores
def avoid_global_scope():
    """Script entry point: score the files named on the command line.

    ``sys.argv[1]`` is read as the inputs and ``sys.argv[2]`` as the outputs,
    both through ``read_input(path, "json")``. The results of
    ``compute_metrics`` are printed one per line as
    ``<label>: <score to 2 decimals> (<number of pairs>)``.

    Exits via ``sys.exit`` with a usage message when fewer than two paths are
    supplied, instead of failing with a bare ``IndexError`` on ``sys.argv``.
    """
    if len(sys.argv) < 3:
        script = sys.argv[0] if sys.argv else "script"
        sys.exit(f"Usage: {script} <inputs> <outputs>")

    json_inputs = read_input(sys.argv[1], "json")
    str_outputs = read_input(sys.argv[2], "json")

    # compute_metrics yields (label, metric result, pair count) tuples.
    for lp, metric, size in compute_metrics(json_inputs, str_outputs):
        print(f"{lp}: {metric['score']:.2f} ({size})")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    avoid_global_scope()