#!/usr/bin/env python3
from data import read_input
from aux import log
import sys
from collections import defaultdict
from evaluate import load as load_metric
# Languages bucketed by (apparent) training-resource level; the keys are the
# class labels returned by _hi_or_lo_lang.  Explicit set literals instead of
# the original comma-split strings, so each member is visible at a glance.
SMUGRI_RES = {
    'high': {
        "Estonian", "English", "Russian", "Finnish", "Hungarian",
        "Latvian", "German", "Swedish", "Norwegian", "French",
    },
    'mid': {"Komi", "Komi-Zyrian", "Northern Sami", "Meadow Mari"},
    'low': {
        "Udmurt", "Proper Karelian", "Southern Sami", "Livvi", "Veps",
        "Moksha", "Erzya", "Lule Sami", "Võro", "Hill Mari",
        "Komi-Permyak", "Inari Sami",
    },
    'xlow': {
        "Ludian", "Livonian", "Izhorian", "Votic", "Shur Khanty",
        "Skolt Sami", "Meänkieli", "Sred Khanty", "Surgut Khanty",
        "Priur Khanty", "Vakh Khanty", "Unk Khanty", "Pite Sami",
        "Mansi", "Kazym Khanty", "Kven", "Ume Sami", "Kildin Sami",
    },
}
def _gen_lang(lang): | |
return lang.split(",")[0] | |
def _hi_or_lo_lang(lang):
    """Map a (possibly qualified) language name to its resource-class key.

    Returns one of the SMUGRI_RES keys ('high'/'mid'/'low'/'xlow'), or '?'
    after logging when the general language is in none of the sets.
    """
    gen_lang = _gen_lang(lang)
    hit = next(
        (cls for cls, members in SMUGRI_RES.items() if gen_lang in members),
        None,
    )
    if hit is not None:
        return hit
    log(f"Unrecognized language: {lang} / {gen_lang}")
    return '?'
def _collect_lp_pairs(json_inputs, str_outputs):
    """Group (hypothesis, reference) pairs under three language-pair keys.

    Every input/output pair is registered three times: under the detailed
    source/target pair, the general (comma-stripped) pair, and the
    resource-class pair.  Returns a defaultdict mapping key -> list of pairs.
    """
    grouped = defaultdict(list)
    for sample, hyp in zip(json_inputs, str_outputs):
        pair = (hyp, sample["tgt_segm"])
        src, tgt = sample["src_lang"], sample["tgt_lang"]
        keys = (
            'detailed: ' + src + " -> " + tgt,
            'general: ' + _gen_lang(src) + " -> " + _gen_lang(tgt),
            'classes: ' + _hi_or_lo_lang(src) + " -> " + _hi_or_lo_lang(tgt),
        )
        for key in keys:
            grouped[key].append(pair)
    return grouped
def compute_metrics(json_inputs, str_outputs):
    """Score outputs against references with chrF, per language-pair group.

    Returns a list of (language-pair key, metric result dict, group size)
    tuples, one per key produced by _collect_lp_pairs.
    """
    grouped = _collect_lp_pairs(json_inputs, str_outputs)
    chrf = load_metric("chrf")
    scores = []
    for lp, pairs in grouped.items():
        hyps, refs = zip(*pairs)
        scores.append((lp, chrf.compute(predictions=hyps, references=refs), len(hyps)))
    return scores
def avoid_global_scope():
    """CLI entry point: read inputs (argv[1]) and outputs (argv[2]) as JSON,
    then print the chrF score and group size for every language-pair key."""
    inputs = read_input(sys.argv[1], "json")
    outputs = read_input(sys.argv[2], "json")
    for lp, metric, size in compute_metrics(inputs, outputs):
        print(f"{lp}: {metric['score']:.2f} ({size})")


if __name__ == "__main__":
    avoid_global_scope()