# Language-code conversion utilities (see module docstring below).
"""
Convert language codes between different schemas.

NLLB uses codes like "eng_Latn": ISO-639-3 plus script.
MADLAD uses codes like "<2en>": ISO-639-1 where it's available, ISO-639-3 elsewhere
(and some codes include the script, but we'll ignore them here).

Functions at the end of the file (any_to_nllb, any_to_madlad) should
cope with a lang code in any style ('en', 'eng', 'eng_Latn', '<2en>', '<2eng>', etc.)
and convert them to corresponding representations (NLLB/MADLAD).
"""
from collections import defaultdict | |
# Comma-separated ISO-639-3 codes of the low-resource (mostly Uralic) languages
# covered by the SMUGRI setting.
SMUGRI_LOW = "fkv,izh,kca,koi,kpv,krl,liv,lud,mdf,mhr,mns,mrj,myv,olo,sjd,sje,sju,sma,sme,smj,smn,sms,udm,vep,vot,vro"
# High-resource support languages included in the SMUGRI setting.
SMUGRI_HIGH = "deu,eng,est,fin,hun,lvs,nor,rus,swe"
# Full language list: high-resource codes first, then low-resource codes.
SMUGRI = SMUGRI_HIGH + "," + SMUGRI_LOW
import pycountry | |
# All MADLAD target-language tokens ('<2xx>' style), as emitted by the
# MADLAD-400 tokenizer. Kept verbatim; only the scraping artifacts were removed.
MADLAD_CODES = ['<2meo>', '<2lo>', '<2Grek>', '<2ada>', '<2ps>', '<2arn>', '<2Armn>', '<2to>', '<2raj>', '<2bas>', '<2ny>', '<2>', '<2zza>', '<2Thai>', '<2kaa_Latn>', '<2yap>', '<2en_xx_simple>', '<2ta>', '<2bg_Latn>', '<2mkn>', '<2lhu>', '<2gu_Latn>', '<2nzi>', '<2uz>', '<2pis>', '<2cfm>', '<2min>', '<2fon>', '<2tn>', '<2msi>', '<2sw>', '<2Tfng>', '<2teo>', '<2taj>', '<2pap>', '<2sd>', '<2Jpan>', '<2tca>', '<2sr>', '<2an>', '<2fr>', '<2gor>', '<2az>', '<2qvi>', '<2pck>', '<2cak>', '<2ltg>', '<2sah>', '<2tly_IR>', '<2ts>', '<2yo>', '<2hne>', '<2bzj>', '<2tuc>', '<2sh>', '<2da>', '<2gui>', '<2translate>', '<2et>', '<2sja>', '<2nhe>', '<2scn>', '<2dje>', '<2pt>', '<2nog>', '<2fil>', '<2mai>', '<2lb>', '<2bm>', '<2Guru>', '<2gom>', '<2hr>', '<2kg>', '<2uk>', '<2rw>', '<2izz>', '<2Telu>', '<2wuu>', '<2Deva>', '<2or>', '<2is>', '<2om>', '<2iso>', '<2sn>', '<2kjh>', '<2tbz>', '<2suz>', '<2bjn>', '<2lv>', '<2mfe>', '<2tcy>', '<2tyz>', '<2ksw>', '<2nds_NL>', '<2ms>', '<2mam>', '<2ubu>', '<2hil>', '<2mh>', '<2gl>', '<2bew>', '<2ilo>', '<2kbd>', '<2toj>', '<2quf>', '<2jam>', '<2Beng>', '<2tyv>', '<2lmo>', '<2ace>', '<2cab>', '<2sq>', '<2ug>', '<2kac>', '<2ay>', '<2mag>', '<2Arab>', '<2mrj>', '<2cs>', '<2bci>', '<2doi>', '<2zu>', '<2ndc_ZW>', '<2smt>', '<2ho>', '<2ss>', '<2he>', '<2twu>', '<2kjg>', '<2pag>', '<2Latn>', '<2gym>', '<2sus>', '<2zh_Latn>', '<2mps>', '<2lg>', '<2ko>', '<2se>', '<2guc>', '<2mr>', '<2mwl>', '<2dwr>', '<2din>', '<2ffm>', '<2maz>', '<2nia>', '<2nl>', '<2Knda>', '<2jv>', '<2noa>', '<2udm>', '<2kr>', '<2de>', '<2ar>', '<2ZW>', '<2dln>', '<2mn>', '<2ml>', '<2crh>', '<2ha>', '<2ks>', '<2qvc>', '<2fur>', '<2myv>', '<2nv>', '<2ak>', '<2Gujr>', '<2cce>', '<2nso>', '<2sg>', '<2rmc>', '<2mas>', '<2mni>', '<2frp>', '<2my>', '<2xal>', '<2th>', '<2bik>', '<2bho>', '<2inb>', '<2Mlym>', '<2oj>', '<2back_translated>', '<2tet>', '<2gsw>', '<2ff>', '<2hy>', '<2otq>', '<2el>', '<2agr>', '<2br>', '<2alt>', '<2tzo>', '<2chm>', '<2transliterate>', '<2hu>', '<2btx>',
'<2vi>', '<2iba>', '<2bg>', '<2gub>', '<2li>', '<2ace_Arab>', '<2qub>', '<2ktu>', '<2bru>', '<2bbc>', '<2ca>', '<2hvn>', '<2sat_Latn>', '<2ku>', '<2shn>', '<2djk>', '<2krc>', '<2io>', '<2ig>', '<2chk>', '<2sm>', '<2Mymr>', '<2Kore>', '<2ary>', '<2lu>', '<2fa>', '<2spp>', '<2af>', '<2ti>', '<2Tibt>', '<2emp>', '<2enq>', '<2kl>', '<2be>', '<2srn>', '<2ms_Arab_BN>', '<2kri>', '<2gd>', '<2mk>', '<2syr>', '<2kmz_Latn>', '<2CA>', '<2ium>', '<2abt>', '<2ngu>', '<2tab>', '<2it>', '<2ru>', '<2ann>', '<2msm>', '<2fo>', '<2ne>', '<2akb>', '<2kv>', '<2jac>', '<2ceb>', '<2ang>', '<2tdx>', '<2tr>', '<2kbp>', '<2mgh>', '<2az_RU>', '<2acf>', '<2tg>', '<2dov>', '<2pau>', '<2mg>', '<2fuv>', '<2nn>', '<2Hant>', '<2hui>', '<2ml_Latn>', '<2ja>', '<2lus>', '<2te>', '<2qu>', '<2rom>', '<2tsg>', '<2el_Latn>', '<2cr_Latn>', '<2ur>', '<2fi>', '<2shp>', '<2brx>', '<2laj>', '<2sda>', '<2lij>', '<2st>', '<2bn>', '<2zxx_xx_dtynoise>', '<2yua>', '<2no>', '<2fr_CA>', '<2miq>', '<2trp>', '<2es>', '<2ch>', '<2mass>', '<2os>', '<2bts>', '<2ady>', '<2lrc>', '<2seh>', '<2adh>', '<2new>', '<2mak>', '<2grc>', '<2nus>', '<2tzj>', '<2nut>', '<2gu>', '<2oc>', '<2ppk>', '<2Hans>', '<2tzh>', '<2si>', '<2wo>', '<2nyu>', '<2Hebr>', '<2mad>', '<2tll>', '<2kr_Arab>', '<2pon>', '<2mbt>', '<2kw>', '<2bjn_Arab>', '<2gn>', '<2eu>', '<2dz>', '<2kaa>', '<2crh_Latn>', '<2te_Latn>', '<2ky>', '<2kn_Latn>', '<2kum>', '<2fip>', '<2ksd>', '<2sk>', '<2NL>', '<2ctd_Latn>', '<2Khmr>', '<2gbm>', '<2Cans>', '<2haw>', '<2gag>', '<2Taml>', '<2cnh>', '<2bim>', '<2ms_Arab>', '<2Thaa>', '<2kha>', '<2tvl>', '<2Cyrl>', '<2chr>', '<2dtp>', '<2ba>', '<2nan_Latn_TW>', '<2ro>', '<2ctu>', '<2Ethi>', '<2zh>', '<2ln>', '<2ve>', '<2xh>', '<2skr>', '<2ber>', '<2niq>', '<2ibb>', '<2jvn>', '<2tks>', '<2av>', '<2ahk>', '<2tk>', '<2tt>', '<2ka>', '<2tsc>', '<2km>', '<2co>', '<2id>', '<2prs>', '<2rki>', '<2kmb>', '<2ks_Deva>', '<2ify>', '<2wal>', '<2arz>', '<2amu>', '<2rm>', '<2pa>', '<2RU>', '<2ce>', '<2hi>', '<2eo>', '<2taq>', '<2ga>', '<2qxr>',
'<2la>', '<2bi>', '<2rwo>', '<2dyu>', '<2zh_Hant>', '<2mt>', '<2bqc>', '<2bn_Latn>', '<2zne>', '<2szl>', '<2lt>', '<2sl>', '<2hif>', '<2alz>', '<2ber_Latn>', '<2ckb>', '<2wa>', '<2Cher>', '<2msb>', '<2gom_Latn>', '<2ru_Latn>', '<2crs>', '<2kk>', '<2gvl>', '<2qvz>', '<2bar>', '<2qup>', '<2bgp>', '<2bo>', '<2su>', '<2tzm>', '<2IR>', '<2sv>', '<2srm>', '<2rn>', '<2bus>', '<2jiv>', '<2awa>', '<2gv>', '<2knj>', '<2as>', '<2quc>', '<2en>', '<2sa>', '<2bug>', '<2quy>', '<2hi_Latn>', '<2nds>', '<2kek>', '<2mrw>', '<2kos>', '<2cy>', '<2ta_Latn>', '<2kn>', '<2nr>', '<2ape>', '<2bs>', '<2iu>', '<2nnb>', '<2Geor>', '<2rcf>', '<2meu>', '<2cac>', '<2cuk>', '<2bua>', '<2vec>', '<2so>', '<2fj>', '<2gof>', '<2koi>', '<2cv>', '<2guh>', '<2war>', '<2pl>', '<2cbk>', '<2kj>', '<2dv>', '<2mdf>', '<2fy>', '<2am>', '<2sc>', '<2taq_Tfng>', '<2mi>', '<2zap>', '<2mqy>', '<2yi>', '<2kwi>', '<2hmn>', '<2tiv>', '<2sxn>', '<2hus>', '<2ban>', '<2nij>', '<2tlh>', '<2Orya>', '<2quh>', '<2ee>', '<2ht>', '<2bum>', '<2stq>']
# All NLLB-200 language codes, each of the form '<iso639-3>_<script>'.
# Kept verbatim; only the scraping artifacts were removed.
NLLB_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn',
'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']
# Model-type identifiers, used to pick the matching language-code conversion scheme.
MDL_NLLB = "MDL_NLLB"
MDL_MADLAD = "MDL_MADLAD"
MDL_NEUROTOLGE = "MDL_NEUROTÕLGE"
MDL_LLAMA = "MDL_LLAMA"
# ISO-639-3 code -> script, derived from NLLB's own codes ('eng_Latn' -> {'eng': 'Latn'}).
_iso3_to_script = dict(nllb_code.split("_") for nllb_code in NLLB_CODES)
# ISO-639-3 code -> full NLLB code ('eng' -> 'eng_Latn').
iso3_to_nllb = {code: f"{code}_{script}" for code, script in _iso3_to_script.items()}
# Macro-language / alternative codes mapped onto the NLLB variant actually present.
iso3_to_nllb['lav'] = "lvs_Latn"  # Latvian -> Standard Latvian
iso3_to_nllb['nor'] = "nob_Latn"  # Norwegian -> Bokmål
iso3_to_nllb['yid'] = "ydd_Hebr"  # Yiddish -> Eastern Yiddish
# Low-resource languages absent from NLLB, given NLLB-style codes by hand:
# Latin-script ones ...
for lang in "fkv izh krl liv lud olo sje sju sma sme smj smn sms vep vot vro".split():
    iso3_to_nllb[lang] = f"{lang}_Latn"
# ... and Cyrillic-script ones.
for lang in "kca koi kpv mdf mhr mns mrj myv sjd udm".split():
    iso3_to_nllb[lang] = f"{lang}_Cyrl"
# Language code -> Joshi resource class ("0", "1" or "2+"); unknown codes map to "?".
_rev_joshi = defaultdict(lambda: "?")
# Joshi class 0: extremely low-resource languages.
for k in "krl,sma,vep,smj,smn,lud,liv,izh,vot,kca,sms,sje,mns,fkv,sju,sjd".split(","):
    _rev_joshi[k] = "0"
# Joshi class 1: low-resource languages with some data.
for k in "kpv,sme,mhr,udm,olo,myv,mdf,vro,mrj,koi".split(","):
    _rev_joshi[k] = "1"
# High-resource support languages are lumped together as "2+".
for k in SMUGRI_HIGH.split(","):
    _rev_joshi[k] = "2+"
def guess_script(lang):
    """Fallback script guess for codes missing from iso3_to_nllb.

    Currently a stub that always returns the placeholder "Unk"; the `lang`
    argument is accepted for a future real implementation.
    """
    return "Unk"
def get_high_set():
    """Return the set of high-resource SMUGRI codes.

    'deu' and 'swe' are removed — presumably not used in this setting
    (NOTE(review): confirm why these two are excluded).
    """
    return set(SMUGRI_HIGH.split(",")) - {"deu", "swe"}
def clean_lang(raw_lang):
    """Normalize a language code in any style to its bare base code.

    Strips the MADLAD '<2...>' wrapper and any '_Script' suffix:
    '<2en>' -> 'en', 'eng_Latn' -> 'eng', 'en' -> 'en'.
    """
    if "<2" in raw_lang:
        # MADLAD-style token: drop the leading '<2' and the trailing '>'.
        raw_lang = raw_lang[2:-1]
    # Drop a script/region suffix if present; a no-op for plain codes.
    return raw_lang.split("_")[0]
def any_to_base(lang):
    """Look up the pycountry language entry for a code in any style.

    Tries the ISO-639-1 (two-letter) lookup first, then ISO-639-3.
    Returns None when pycountry knows neither form of the cleaned code.
    """
    clang = clean_lang(lang)
    entry = pycountry.languages.get(alpha_2=clang)
    if entry is not None:
        return entry
    return pycountry.languages.get(alpha_3=clang)
def base_to_nllb(lang_entry=None, lang_code=None):
    """Convert a pycountry entry (or bare ISO-639-3 code) to an NLLB code.

    If lang_code is not given it is taken from lang_entry.alpha_3. Codes not
    covered by iso3_to_nllb get a guessed script suffix via guess_script.
    """
    if lang_code is None:
        lang_code = lang_entry.alpha_3
    try:
        return iso3_to_nllb[lang_code]
    except KeyError:
        # Unknown language: fall back to a (currently stubbed) script guess.
        script = guess_script(lang_code)
        return f"{lang_code}_{script}"
def base_to_madlad(lang_entry=None, lang_code=None):
    """Convert a pycountry entry (or bare code) to a MADLAD '<2xx>' token.

    MADLAD prefers ISO-639-1 where it exists, so the two-letter code is used
    when the entry has one; otherwise the three-letter code.
    """
    if lang_code is None:
        if hasattr(lang_entry, 'alpha_2'):
            lang_code = lang_entry.alpha_2
        else:
            lang_code = lang_entry.alpha_3
    return f"<2{lang_code}>"
def any_to_something(lang, conv_func):
    """Convert a code in any style via conv_func (base_to_nllb / base_to_madlad).

    When pycountry does not know the language, the cleaned raw code is passed
    through to the converter instead of a pycountry entry.
    """
    base = any_to_base(lang)
    if base is None:
        return conv_func(None, clean_lang(lang))
    return conv_func(base)
def run_test(src_list, tgt_list, conv_func, msg_prefix, verbose=False):
    """Convert every code in src_list and check the result is a valid tgt_list code.

    Prints a one-line summary prefixed with msg_prefix; with verbose=True, also
    prints each failing conversion and each code that raised KeyError.
    Returns the (ok_count, fail_count, err_count) triple (new, backward-compatible:
    previous callers ignored the implicit None).
    """
    ok_count = 0
    fail_count = 0
    err_count = 0
    targets = set(tgt_list)  # O(1) membership instead of scanning the list per code
    for raw_code in src_list:
        try:
            converted = conv_func(raw_code)
        except KeyError:
            err_count += 1
            if verbose:
                print("ERR:", raw_code)
            continue
        if converted in targets:
            ok_count += 1
        else:
            fail_count += 1
            if verbose:
                print("FAIL:", converted)
    print(f"{msg_prefix}: {ok_count} good, {fail_count} fail, {err_count} err")
    return ok_count, fail_count, err_count
def any_to_madlad(lang):
    """Convert a language code in any style to a MADLAD '<2xx>' token."""
    return any_to_something(lang, base_to_madlad)
def any_to_nllb(lang):
    """Convert a language code in any style to an NLLB 'xxx_Script' code."""
    return any_to_something(lang, base_to_nllb)
def any_to_neurotolge(lang):
    """Convert a language code in any style to the form Neurotõlge expects.

    Uses the ISO-639-3 code, except Latvian where 'lv' replaces 'lvs'.
    NOTE(review): raises AttributeError when pycountry does not know the code
    (any_to_base returns None) — confirm callers only pass known languages.
    """
    code = any_to_base(lang).alpha_3
    return code if code != 'lvs' else 'lv'
def any_to_mdl_type(mdl_type, lang):
    """Convert a language code into the representation the given model type uses.

    NLLB -> 'xxx_Script', MADLAD -> '<2xx>'; LLAMA and None pass the code
    through unchanged. Raises ValueError for an unknown model type.
    """
    if mdl_type == MDL_NLLB:
        return any_to_nllb(lang)
    elif mdl_type == MDL_MADLAD:
        return any_to_madlad(lang)
    elif mdl_type is None or mdl_type == MDL_LLAMA:
        # Decoder-only LLMs (and "no model") use the raw code as-is.
        return lang
    else:
        raise ValueError(f"Unknown mdl_type {mdl_type}")
def langs_to_madlad(lang_set):
    """Convert a collection of language codes to MADLAD tokens; None -> empty list."""
    if lang_set is None:
        return []
    return [any_to_madlad(lang) for lang in lang_set]
def langs_to_nllb(lang_set):
    """Convert a collection of language codes to NLLB codes; None -> empty list."""
    if lang_set is None:
        return []
    return [any_to_nllb(lang) for lang in lang_set]
if __name__ == "__main__":
    # Self-test: every code of each scheme should convert into a valid code of
    # the target scheme (and round-trip within its own scheme).
    run_test(NLLB_CODES, MADLAD_CODES, any_to_madlad, "NLLB to MADLAD")
    run_test(NLLB_CODES, NLLB_CODES, any_to_nllb, "NLLB to NLLB")
    run_test(MADLAD_CODES, NLLB_CODES, any_to_nllb, "MADLAD TO NLLB")
    run_test(MADLAD_CODES, MADLAD_CODES, any_to_madlad, "MADLAD TO MADLAD")
def is_nllb(object):
    """
    Check if the object is an NLLB model or tokenizer.

    Matches on the class name: NLLB checkpoints use the M2M100 architecture,
    so both "m2m100" and "nllb" substrings count.
    """
    name = object.__class__.__name__.lower()
    return "m2m100" in name or "nllb" in name
def is_madlad(object):
    """
    Check if the object is a MADLAD model or tokenizer.

    MADLAD is T5-based, so any class name containing "t5" counts.
    """
    return "t5" in object.__class__.__name__.lower()
def is_dec_only_llm(obj):
    """Check if the object is a decoder-only LLM (Llama/Gemma) or its fast tokenizer."""
    lcname = obj.__class__.__name__.lower()
    return any(k in lcname for k in ["pretrainedtokenizerfast", "llama", "gemma"])
def get_mdl_type(obj):
    """Classify a model/tokenizer object into one of the MDL_* type constants.

    First unwraps objects exposing a `.module` attribute (presumably
    DistributedDataParallel-style wrappers — confirm against callers), then
    dispatches on the class name. Raises ValueError for unrecognized objects.
    """
    obj = obj.module if hasattr(obj, "module") else obj
    if is_nllb(obj):
        return MDL_NLLB
    elif is_madlad(obj):
        return MDL_MADLAD
    elif is_dec_only_llm(obj):
        return MDL_LLAMA
    else:
        raise ValueError(f"Object {str(obj)[:200]} is not supported")
def langs_to_mdl_type(mdl_type, lang_set):
    """Convert a collection of language codes into the given model type's scheme.

    LLAMA passes the collection through unchanged; NLLB/MADLAD return lists of
    converted codes. Raises ValueError for an unsupported model type.
    """
    if mdl_type == MDL_NLLB:
        return langs_to_nllb(lang_set)
    elif mdl_type == MDL_MADLAD:
        return langs_to_madlad(lang_set)
    elif mdl_type == MDL_LLAMA:
        return lang_set
    else:
        raise ValueError(f"Model type {mdl_type} is not supported")
def get_joshi_class(lang_code):
    """Return the Joshi resource class ("0", "1", "2+") of a language, "?" if unknown.

    The code may be in any style; it is normalized via pycountry first.
    """
    entry = any_to_base(lang_code)
    if entry is None:
        return "?"
    # _rev_joshi is a defaultdict, so unknown-but-valid codes also yield "?".
    return _rev_joshi[entry.alpha_3]
def lang_set_maybe_smugri(lang_def):
    """Expand the shorthands 'smugri', 'smugri-low', 'smugri-high' to language sets.

    Any other input is treated as a comma-separated list of codes.
    Returns the codes as a set.
    """
    shorthands = {
        "smugri-low": SMUGRI_LOW,
        "smugri-high": SMUGRI_HIGH,
        "smugri": SMUGRI,
    }
    return set(shorthands.get(lang_def, lang_def).split(","))
def smugri_back(lang_list):
    """Collapse a collection of language codes back to a SMUGRI shorthand.

    Returns 'smugri-low' / 'smugri-high' / 'smugri-full' when the sorted input
    matches the corresponding constant, otherwise the sorted comma-joined codes.
    NOTE(review): returns 'smugri-full' while lang_set_maybe_smugri accepts
    'smugri' — confirm the asymmetry is intended.
    """
    joined = ",".join(sorted(lang_list))
    if joined == SMUGRI_LOW:
        return "smugri-low"
    elif joined == SMUGRI_HIGH:
        return "smugri-high"
    # BUGFIX: SMUGRI concatenates high- before low-resource codes and is thus
    # not globally sorted, so comparing against it directly never matched the
    # full set; compare against its sorted join instead.
    elif joined == ",".join(sorted(SMUGRI.split(","))):
        return "smugri-full"
    else:
        return joined