#!/usr/bin/env python3

import sys
import json

from collections import defaultdict

from benchmark import get_hyp_cache_dir, translate_all_hyps
from inference import load_and_init_module_config
from legacy.langconv import get_high_set, any_to_mdl_type, get_mdl_type
from accelerate import Accelerator
from aux import log


def load_raw_data(path):
    with open(path, 'r') as f:
        return json.load(f)


def save_raw_data(path, data):
    with open(path, 'w') as f:
        json.dump(data, f, indent=2)


def apply_func_to_hires_snts(snt_set, func):
    """Call func(tupl, lang, revlang) for every high-resource language
    present in a sentence tuple, paired with every high-resource language
    missing from it."""
    high_set = get_high_set()

    for tupl in snt_set:
        langs = [k for k in tupl if "-dia" not in k and k in high_set]

        if langs:
            revlangs = high_set - set(langs)

            for revlang in revlangs:
                for lang in langs:
                    # depending on the callback, this either collects
                    # tupl[lang] for translation from lang to revlang,
                    # or stores the finished translation as tupl[revlang]
                    func(tupl, lang, revlang)


def report_part_stats(part, part_index, num_parts):
    hi_set = get_high_set()

    num_snts = len(part['sentences'])

    hires_langs = {k for k in part['sentences'][0] if "dia" not in k and k in hi_set}
    num_hires_langs = len(hires_langs)

    langs_to_do = hi_set - hires_langs
    num_to_translate = num_hires_langs * len(langs_to_do)

    log(f"Part {part_index + 1}/{num_parts}; {num_snts} sentences, "
        f"num hires: {num_hires_langs}, to translate: {num_to_translate}")

    return num_snts * num_hires_langs, num_snts * num_to_translate


def add_hires_synth_data(mdl_id, corpus_in, corpus_out, dry=False):
    accelerator = Accelerator()

    log("Loading data", accelerator)
    data = load_raw_data(corpus_in)

    log("Loading model", accelerator)
    if dry:
        main_model, module_config = None, None
        mdl_type = None
    else:
        main_model, module_config = load_and_init_module_config(mdl_id, accelerator)
        mdl_type = get_mdl_type(main_model)

    if accelerator.is_main_process:
        _ = get_hyp_cache_dir(mdl_id, create=True)

    num_parts = len(data)
    tot_snt = 0
    tot_tr = 0

    for i, part in enumerate(data):
        tr_dict = defaultdict(lambda: defaultdict(lambda: None))

        num_snt, num_tr = report_part_stats(part, i, num_parts)
        tot_snt += num_snt
        tot_tr += num_tr

        if not dry:
            def _transfer(tup, src, tgt):
                srcm = any_to_mdl_type(mdl_type, src)
                tgtm = any_to_mdl_type(mdl_type, tgt)

                lp = f"{srcm}-{tgtm}"
                inp_snt = tup[src]

                # this "touches" the value: if it was not there before, it is
                # None now (and gets picked up for translation); if it was
                # there, we copy the translation into the tuple
                if tr_dict[lp][inp_snt] is not None:
                    tup[tgt] = tr_dict[lp][inp_snt]

            # first pass: collect sentences to translate
            apply_func_to_hires_snts(part['sentences'], _transfer)

            in_tr_dict_list = {lp: sorted(tr_dict[lp].items()) for lp in tr_dict}

            log(f"Translating part {i + 1}/{num_parts}", accelerator)
            translate_all_hyps(in_tr_dict_list, module_config, mdl_id, f"{corpus_in}-{i}", accelerator)

            log(f"Collecting part {i + 1}/{num_parts}", accelerator)
            # the cache key must match the one used for translating above
            out_tr_dict_list = translate_all_hyps(in_tr_dict_list, module_config, mdl_id, f"{corpus_in}-{i}")

            for lp in out_tr_dict_list:
                for inp, outp in out_tr_dict_list[lp]:
                    tr_dict[lp][inp] = outp

            # second pass: put translations back into the data structure
            log(f"Integrating part {i + 1}/{num_parts}", accelerator)
            apply_func_to_hires_snts(part['sentences'], _transfer)

    log(f"Total sentences: {tot_snt}, total to generate: {tot_tr}", accelerator)

    if not dry:
        log("Saving data", accelerator)
        save_raw_data(corpus_out, data)


if __name__ == '__main__':
    try:
        mdl_id_param = sys.argv[1]
        corpus_param = sys.argv[2]
        corpus_output_param = sys.argv[3]
    except IndexError:
        mdl_id_param = "models/nllb600m"
        corpus_param = "data/flt.json"
        corpus_output_param = "data/fltout.json"

    # any fourth argument switches on dry-run mode (report stats only)
    dry_run = len(sys.argv) > 4

    add_hires_synth_data(mdl_id_param, corpus_param, corpus_output_param, dry_run)
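
# ---------------------------------------------------------------------------
# Usage sketch. The script filename here is an assumption; the paths are the
# fallback defaults defined above:
#
#   python add_hires_synth_data.py models/nllb600m data/flt.json data/fltout.json
#
# Any fourth argument triggers a dry run that only reports sentence counts:
#
#   python add_hires_synth_data.py models/nllb600m data/flt.json data/fltout.json dry
#
# The input corpus is expected to be a JSON list of parts, each with a
# 'sentences' list of {lang-code: sentence} dicts (inferred from the code
# above, not from an external spec).
# ---------------------------------------------------------------------------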