#!/usr/bin/env python3

import json
#import os
import sys
import torch
#import re
import math

from torch.utils.data import IterableDataset
from collections import namedtuple, defaultdict
from random import randrange, shuffle, randint
#from pathlib import Path

#from aux import log
#from langconv import any_to_madlad, any_to_nllb, is_nllb, is_madlad, get_mdl_type, any_to_mdl_type, is_dec_only_llm, \
#    base_to_nllb
#from tokops import tokenizeit

# TrPair = namedtuple('TrPair', ["src_lang", "tgt_lang", "input", "output"])


""" | |
def prep_llm_input(ljmftpl): | |
#{'task': 'translate' / 'approx-translate' / 'generate', | |
# 'src_segm': src_segm, | |
# 'tgt_segm': tgt_segm, | |
# 'src_lang': src_lang, | |
# 'tgt_lang': tgt_lang} | |
# it's a tuple | |
if "src_segm" in ljmftpl and "task" in ljmftpl: | |
if ljmftpl['task'] in {'translate', 'approx-translate'}: | |
return (f"{ljmftpl['src_segm']}\n=====\n{ljmftpl['task']} from {ljmftpl['src_lang']}; " + | |
f"to {ljmftpl['tgt_lang']}:\n{ljmftpl['tgt_segm']}") | |
elif ljmftpl['task'] == 'generate': | |
return f"{ljmftpl['src_segm']}\n=====\nis in {ljmftpl['src_lang']};" | |
# it's a string | |
else: | |
return ljmftpl | |
def make_path_compatible(filename): | |
return filename.replace("/", "_").replace(":", "-") | |
def do_list_in_batches(data, batch_size): | |
i = 0 | |
while i < len(data): | |
yield data[i:i + batch_size] | |
i += batch_size | |
""" | |
""" | |
def do_bins_in_batches(bins, batch_size, sort_by_length): | |
result_list = [] | |
for src_k in bins: | |
for tgt_k in bins[src_k]: | |
if src_k == 0 or tgt_k == 0: | |
result_list += [(e, src_k, tgt_k) for e in do_list_in_batches(bins[src_k][tgt_k], batch_size)] | |
shuffle(result_list) | |
return result_list | |
def _post_proc(text, lang): | |
if lang == 'liv' and "’" in text and "O’R" not in text: | |
return text.replace("’", "") | |
else: | |
return text | |
def clean_entry(entry, leave_out): | |
result = {k: _post_proc(entry[k], k) for k in entry if entry[k].strip() and k not in leave_out} | |
return result | |
def load_json_data(path, leave_out={}, skip_cats=True, load_mono=True): | |
with open(path, 'r') as f: | |
data = json.load(f) | |
if skip_cats: | |
# skip categories | |
resx = [clean_entry(entry, leave_out) | |
for cat in data for entry in cat['sentences']] | |
res = [e for e in resx if e] | |
else: | |
raise NotImplementedError | |
# resx = {cat['source']: [clean_entry(entry, leave_out) for entry in cat['sentences']] for cat in data} | |
# res = {k: resx[k] for k in resx if resx[k]} | |
return res | |
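# Generator over directed translation pairs: for each entry it walks all (l1, l2) language
# combinations, applies the leave_out / leave_only filters, prepends a dialect tag when an
# "<l2>-dia" key is present, skips sentences listed in exclude_set and drops pairs whose
# script looks wrong (snt_is_fishy).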
def get_tr_pairs(raw_data=None, filename=None, leave_out=None, leave_only=None, model_type=None, exclude_set=None):
    if filename is not None:
        raw_data = load_json_data(filename)
    if raw_data is None:
        raise ValueError("Neither file nor data are provided")

    i = 0
    log("Loading data")
    for tup in raw_data:
        for l1 in tup:
            for l2 in tup:
                if l1 != l2 and "dia" not in l1 and "dia" not in l2:
                    if leave_out is None or f"{l1}-{l2}" not in leave_out:
                        if leave_only is None or f"{l1}-{l2}" in leave_only:
                            i += 1
                            if not i % 1000000:
                                log(f"Loaded {i/1000000}M pairs")

                            dia_key = f"{l2}-dia"
                            if exclude_set is None or (tup[l1] not in exclude_set[l1] and tup[l2] not in exclude_set[l2]):
                                input = tup[l1]
                                if dia_key in tup:
                                    input = f"<{tup[dia_key]}> {input}"

                                conv_l1 = any_to_mdl_type(model_type, l1)
                                conv_l2 = any_to_mdl_type(model_type, l2)

                                if not snt_is_fishy(input, conv_l1) and not snt_is_fishy(tup[l2], conv_l2):
                                    yield TrPair(conv_l1, conv_l2, input, tup[l2])


def split_by_lang(filename, model_type):
    result = defaultdict(list)

    # if filename is not None:
    #     tr_pairs = load_json_datax(filename)
    tr_pairs = get_tr_pairs(filename=filename, model_type=model_type)

    for tup in tr_pairs:
        #for l1 in tup:
        #    for l2 in tup:
        #        if l1 != l2 and not "dia" in l1 and not "dia" in l2:
        l1 = tup.src_lang
        l2 = tup.tgt_lang
        lp = f"{l1}-{l2}"
        result[lp].append((tup.input, tup.output))

    return result


def data_iter_for_tok_train(raw_data, langs_to_include):
    for tup in raw_data:
        for lang in tup:
            if lang in langs_to_include:
                yield tup[lang]


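# Maps every language to the set of coupling-spec (bin) indices whose lang_set contains it.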
def lang_bin_mapping(coupling_specs):
    lang_to_idx = dict()

    for i, spec_pair in enumerate(coupling_specs):
        for lang in spec_pair.lang_set:
            if lang not in lang_to_idx:
                lang_to_idx[lang] = {i}
            else:
                lang_to_idx[lang].add(i)

    return lang_to_idx


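# Randomly picks one (src_bin, tgt_bin) index pair, excluding (1, 1); judging by
# report_update_stats below, a batch routed to bin 1 on both sides would update nothing.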
def mix_and_sample_idxs_carefully(src_idxs, tgt_idxs):
    idx_pairs = [(s, t) for s in src_idxs for t in tgt_idxs if not (s == 1 and t == 1)]

    if len(idx_pairs) == 0:
        result = (None, None)
    else:
        pair_idx = randrange(len(idx_pairs))
        result = idx_pairs[pair_idx]

    # debug(f"src lang: {tr_pair.src_lang}, tgt_lang: {tr_pair.tgt_lang}, idx list: {idx_pairs}, result: {result}")

    return result


def inject_bin_indices(batch, src_k, tgt_k):
    batch['input_ids'][0,0] += src_k << 30
    batch['labels'][0,0] += tgt_k << 30


def get_data_cache_location(cache_meta_path, idx):
    cache_folder, cache_file = os.path.split(cache_meta_path)

    if cache_folder:
        Path(cache_folder).mkdir(parents=True, exist_ok=True)

    if cache_meta_path.endswith(".json"):
        return cache_meta_path[:-5] + f"_{idx:04}.pt"
    else:
        raise ValueError(f"Expected a json file for the cache meta-location ({cache_meta_path})")


def make_gen_text(src_lang, tgt_lang, input_text, output_text=None, tok=None):
    if input_text.startswith("<"):
        posit = input_text.find(">") + 1
        dialect = input_text[1:posit-1]
        diatxt = f", variety: {dialect}"
        txt = input_text[posit+1:]
    else:
        dialect = None
        diatxt = ""
        txt = input_text

    return (f"Translate:\n== From: {src_lang}\n== To: {tgt_lang}{diatxt}\n== Input: {txt}\n== Output: " +
            ("" if (output_text is None or tok is None) else f"{output_text}{tok.eos_token}"))


class MultilingualBatchingCachingDataset:
    def _post_proc_bins(self, bins):
        for src_k in bins:
            for tgt_k in bins[src_k]:
                while len(bins[src_k][tgt_k]) % self.args.batch_size != 0:
                    rnd_elem_idx = randrange(len(bins[src_k][tgt_k]))
                    rnd_elem = bins[src_k][tgt_k][rnd_elem_idx]
                    bins[src_k][tgt_k].append(rnd_elem)

                if self.args.sort_by_len:
                    bins[src_k][tgt_k] = sorted(bins[src_k][tgt_k], key=lambda e: len(e.input))
                else:
                    shuffle(bins[src_k][tgt_k])

        return bins

    def _get_idxs(self, tr_pair):
        src_idxs = self._lang_to_idx[tr_pair.src_lang]
        tgt_idxs = self._lang_to_idx[tr_pair.tgt_lang]

        return mix_and_sample_idxs_carefully(src_idxs, tgt_idxs)

    def _fill_bins(self):
        bins = defaultdict(lambda: defaultdict(list))

        for tr_pair in get_tr_pairs(filename=self.filename, model_type=self.model_type, exclude_set=self.exclude_set):
            src_bin_idx, tgt_bin_idx = self._get_idxs(tr_pair)

            if src_bin_idx is not None and tgt_bin_idx is not None:
                bins[src_bin_idx][tgt_bin_idx].append(tr_pair)

        return self._post_proc_bins(bins)

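    # Logs which share of the binned samples updates the bin-0 ("coupled") encoder and/or
    # decoder, plus the resulting encoder/decoder sample counts.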
    def report_update_stats(self, bins):
        total = 0
        totalx = 0
        updates = 0
        duds = 0

        enc_count = 0
        dec_count = 0

        for src_k in bins:
            for tgt_k in bins[src_k]:
                l = len(bins[src_k][tgt_k])

                total += l

                if src_k == 0 or tgt_k == 0:
                    totalx += l
                    updates += l * (1 - (src_k + tgt_k) / 2)

                    enc_count += l * (1 - src_k)
                    dec_count += l * (1 - tgt_k)

                if src_k == 1 and tgt_k == 1:
                    duds += 1

        # log(str(self._lang_to_idx))
        log(f"### Ratio of coupled model updates: {100 * updates / total:.2f}% ({100 * updates / totalx:.2f}%); " +
            f"frozen meaningless updates: {100 * duds / total:.2f}%; " +
            f"enc samples: {enc_count}, dec samples: {dec_count}")

    def tokenize_input(self, cplspec, input_list, rawbatch):
        src_tokenizer = cplspec.tokenizer
        src_tokenizer.src_lang = rawbatch[0].src_lang
        #prep_batch_grouped = src_tokenizer(text=input_list, return_tensors="pt",
        #                                   padding="longest", truncation=True, max_length=self.args.max_snt_len)
        prep_batch_grouped = tokenizeit((src_tokenizer, cplspec.postokenizer), input_list, self.args.max_snt_len, False)

        if is_nllb(src_tokenizer):
            src_lang_list = [any_to_nllb(e.src_lang) for e in rawbatch]
            src_lang_vec = src_tokenizer.convert_tokens_to_ids(src_lang_list)
            prep_batch_grouped['input_ids'][:, 0] = torch.tensor(src_lang_vec)

        return prep_batch_grouped

    def tokenize_output(self, tgttokenizer, tgtposttok, rawbatch):
        outputs = [e.output for e in rawbatch]
        tgttokenizer.tgt_lang = rawbatch[0].tgt_lang
        #labels = tgttokenizer(text_target=outputs, return_tensors="pt",
        #                      padding="longest", truncation=True, max_length=self.args.max_snt_len)
        labels = tokenizeit((tgttokenizer, tgtposttok), outputs, self.args.max_snt_len, True)

        if is_nllb(tgttokenizer):
            tgt_lang_list = [any_to_nllb(e.tgt_lang) for e in rawbatch]
            tgt_lang_vec = tgttokenizer.convert_tokens_to_ids(tgt_lang_list)
            labels['input_ids'][:, 0] = torch.tensor(tgt_lang_vec)

        return labels

    def tokenize_gen_batch(self, raw_batch):
        tokenizer = self.coupling_specs[0].tokenizer
        tokenizer.pad_token = '<|reserved_special_token_0|>'
        tokenizer.padding_side = 'left'

        texts = [make_gen_text(e.src_lang, e.tgt_lang, e.input, e.output, tokenizer) for e in raw_batch]

        #batch = tokenizer(texts, return_tensors="pt", max_length=512, truncation=True, add_special_tokens=True, padding=True)
        batch = tokenizeit((tokenizer, self.coupling_specs[0].postokenizer), texts, self.args.max_snt_len, False)

        return batch

    def tokenize_and_pad(self, raw_batch, src_k, tgt_k):
        tgt_tokenizer = self.coupling_specs[tgt_k].tokenizer
        tgt_postok = self.coupling_specs[tgt_k].postokenizer

        if is_madlad(tgt_tokenizer):
            inputs = [f"{any_to_madlad(e.tgt_lang)} {e.input}" for e in raw_batch]
        else:
            inputs = [e.input for e in raw_batch]

        prep_batch_grouped = self.tokenize_input(self.coupling_specs[src_k], inputs, raw_batch)
        labels = self.tokenize_output(tgt_tokenizer, tgt_postok, raw_batch)

        prep_batch_grouped['labels'] = labels['input_ids']
        # inject_bin_indices(prep_batch_grouped, src_k, tgt_k)

        #split_prep_batch = [{k: prep_batch_grouped[k][i] for k in prep_batch_grouped}
        #                    for i, trp in enumerate(raw_batch)]

        return prep_batch_grouped

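    # Tokenizes the binned batches and writes them out in shards of args.shard_size
    # batches; each shard gets its own .pt file and cache_path receives the JSON shard list.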
    def _bins_to_tokenized_batched_cached_data(self, bins, cache_path):
        shard_i = 0
        batch_i = 0
        total_i = 0

        metainfo = []
        data = []

        log("Tokenizing data")
        for raw_batch, src_k, tgt_k in do_bins_in_batches(bins, self.args.batch_size, self.args.sort_by_len):
            batch_i += 1
            if not batch_i % 10000:
                log(f"Tokenized {batch_i + shard_i * self.args.shard_size} batches (shard {shard_i})")

            if is_dec_only_llm(self.coupling_specs[tgt_k].tokenizer):
                prepared_batch = self.tokenize_gen_batch(raw_batch)
                data.append((prepared_batch, total_i))
            else:
                prepared_batch = self.tokenize_and_pad(raw_batch, src_k, tgt_k)
                data.append((prepared_batch, src_k, tgt_k, total_i))

            if batch_i >= self.args.shard_size:
                shard_i += 1
                batch_i = 0

                fn = self._save_cache_file(data, cache_path, shard_i)
                metainfo.append({'shard_filename': fn, 'shard_size': len(data)})

                del data
                data = []

            total_i += 1

        if len(data) > 0:
            fn = self._save_cache_file(data, cache_path, shard_i + 1)
            metainfo.append({'shard_filename': fn, 'shard_size': len(data)})

        with open(cache_path, 'w') as f:
            json.dump(metainfo, f)

        del data

    @staticmethod
    def _save_cache_file(data, cache_location, idx):
        cache_location = get_data_cache_location(cache_location, idx)

        if os.path.exists(cache_location):
            raise Exception("Cache already exists")

        torch.save(data, cache_location)
        log(f"Saved data into cache (shard {idx})")

        return cache_location

    def set_model_type(self):
        result = None

        for spec_tuple in self.coupling_specs:
            this_type = get_mdl_type(spec_tuple.tokenizer)
            if result is None:
                result = this_type
            else:
                assert result == this_type, "in this implementation model types (NLLB/MADLAD/...) must be the same for all included models"

        return result

    def __init__(self, tr_file, coupling_specs, args):
        self.args = args
        self.filename = tr_file
        self.coupling_specs = coupling_specs
        self.exclude_set = _dev_to_dict(args.exclude_set) if args.exclude_set is not None else None
        self.model_type = self.set_model_type()

        # init lang to idx
        self._lang_to_idx = lang_bin_mapping(coupling_specs)

    def load_and_cache_data(self, cache_path):
        # collect data into bins and cache it
        bins = self._fill_bins()

        self.report_update_stats(bins)

        self._bins_to_tokenized_batched_cached_data(bins, cache_path)
"""

""" | |
class DataState:
    def __init__(self, elem_idx=0, shard_idx=0, epoch_idx=None):
        self.elem_idx = elem_idx
        self.shard_idx = shard_idx
        self.epoch_idx = epoch_idx

    def state_dict(self):
        return {'elem_idx': self.elem_idx, 'shard_idx': self.shard_idx, 'epoch_idx': self.epoch_idx}

    def load_state_dict(self, state_dict):
        self.elem_idx = state_dict['elem_idx']
        self.shard_idx = state_dict['shard_idx']
        self.epoch_idx = state_dict['epoch_idx']

    def copy_from(self, src_ds, epoch_idx=None):
        self.shard_idx = src_ds.shard_idx
        self.elem_idx = src_ds.elem_idx

        if epoch_idx is not None:
            self.epoch_idx = epoch_idx

    def __str__(self):
        return 'DataState(elem_idx={}, shard_idx={}, epoch_idx={})'.format(self.elem_idx, self.shard_idx, self.epoch_idx)

    def __repr__(self):
        return self.__str__()


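# IterableDataset over a flat list of segments: shuffles the list once, yields tokenized
# batches of a fixed size (the last batch is filled up by re-sampling its own elements)
# and can report / restore its position through DataState.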
class BatchingIterator(IterableDataset):
    def __init__(self, segment_list, batch_size, tokenizer, max_len=8000):
        self.data = segment_list
        shuffle(self.data)

        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_len = max_len

        self.curr_elem_idx = 0
        self.data_len = math.ceil(len(self.data) / self.batch_size)

    def __len__(self):
        return self.data_len

    def __iter__(self):
        self.curr_elem_idx = 0
        return self

    def where_are_we(self):
        return DataState(shard_idx=0, elem_idx=self.curr_elem_idx)

    def thats_where(self, data_state):
        self.curr_elem_idx = data_state.elem_idx

    def _get_properly_sized_segment_list(self):
        i = self.curr_elem_idx * self.batch_size
        segment_list = self.data[i:i + self.batch_size]

        if len(segment_list) < self.batch_size:
            orig_len = len(segment_list)
            while len(segment_list) < self.batch_size:
                segment_list.append(segment_list[randint(0, orig_len - 1)])

        return segment_list

    def _tokenize(self, segment_list):
        #{'task': 'translate',
        # 'src_segm': src_segm,
        # 'tgt_segm': tgt_segm,
        # 'src_lang': src_lang,
        # 'tgt_lang': tgt_lang}
        prepped_segm_list = [prep_llm_input(s) for s in segment_list]

        self.tokenizer.pad_token = '<|reserved_special_token_0|>'
        tokenized_batch = self.tokenizer(prepped_segm_list, return_tensors="pt", max_length=self.max_len,
                                         truncation=True, add_special_tokens=True, padding=True)

        return tokenized_batch, self.curr_elem_idx + 1

    def __next__(self):
        if self.curr_elem_idx >= self.data_len:
            raise StopIteration
        else:
            segment_list = self._get_properly_sized_segment_list()
            batch = self._tokenize(segment_list)
            self.curr_elem_idx += 1
            return batch
"""

""" | |
class MultilingualDatasetIterator(IterableDataset):
    def _load_metafile(self, cache_metafile):
        with open(cache_metafile, 'r') as f:
            self.metainfo = json.load(f)

        self.data_len = sum([e['shard_size'] for e in self.metainfo])

    def _init_curr_shard(self):
        cache_location = self.metainfo[self.curr_shard_idx]['shard_filename']
        self.curr_shard_data = torch.load(cache_location, weights_only=False)

        assert len(self.curr_shard_data) == self.metainfo[self.curr_shard_idx]['shard_size']

    def __init__(self, filename):
        self.curr_shard_idx = 0
        self.curr_elem_idx = 0
        self.prev_shard_sum_len = 0

        if filename is not None:
            self._load_metafile(filename)

    def __iter__(self):
        self._init_curr_shard()
        return self

    def where_are_we(self):
        return DataState(shard_idx=self.curr_shard_idx, elem_idx=self.curr_elem_idx)

    def thats_where(self, data_state):
        self.curr_shard_idx = data_state.shard_idx
        self.curr_elem_idx = data_state.elem_idx

        self.prev_shard_sum_len = sum([e['shard_size'] for i, e in enumerate(self.metainfo) if i < self.curr_shard_idx])

    def __next__(self):
        try:
            result_data = self.curr_shard_data[self.curr_elem_idx]
            self.curr_elem_idx += 1
        except IndexError:
            self.prev_shard_sum_len += self.metainfo[self.curr_shard_idx]['shard_size']
            self.curr_shard_idx += 1

            if self.curr_shard_idx >= len(self.metainfo):
                self.__init__(None)
                raise StopIteration
            else:
                self._init_curr_shard()
                self.curr_elem_idx = 0

                result_data = self.curr_shard_data[self.curr_elem_idx]
                self.curr_elem_idx += 1

        index_in_epoch = self.prev_shard_sum_len + self.curr_elem_idx

        return result_data, index_in_epoch

    def __len__(self):
        return self.data_len


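# CLI helper: writes all translation pairs as TSV to stdout and per-language statistics
# to stderr. Note that upd_lc is not defined in this file; it presumably comes from one
# of the helper modules whose imports are commented out above.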
def dump_to_stdout():
    filename = sys.argv[1]
    lc_src = defaultdict(int)
    tot_len = 0
    tot_count = 0

    for tr_pair in get_tr_pairs(filename=filename):
        print(tr_pair.src_lang + "\t" + tr_pair.input + "\t" + tr_pair.tgt_lang + "\t" + tr_pair.output)

        tot_len += upd_lc(lc_src, tr_pair.src_lang, tr_pair.input)
        tot_len += upd_lc(lc_src, tr_pair.tgt_lang, tr_pair.output)
        tot_count += 2

    totes = sum(lc_src.values())

    for k in sorted(lc_src):
        sys.stderr.write(f"{k}: {100*lc_src[k]/totes:.1f}%\n")

    sys.stderr.write(f"Avg length: {tot_len/float(tot_count):.1f}\n")


def do_stats(filename):
    stats = defaultdict(int)
    raw_data = load_json_data(filename)

    for data in raw_data:
        langs = sorted([k for k in data.keys() if data[k].strip() != ""])
        stats["-".join(langs)] += 1

    for k in stats:
        print(k, stats[k])


def lang_from_name(filename):
    return filename.split(".")[-1]


def moses_to_json(file1, file2):
    result = list()

    l1 = lang_from_name(file1)
    l2 = lang_from_name(file2)

    with open(file1, "r") as h1, open(file2, "r") as h2:
        for line1 in h1:
            line2 = h2.readline()
            result.append({l1: line1.strip(), l2: line2.strip()})

    return result


def multi_moses_to_json(output_file, init_json, input_file_tuples):
    try:
        with open(init_json, "r") as h:
            result = json.load(h)
    except:
        result = list()

    for input_file_tuple in input_file_tuples:
        this_result = moses_to_json(*input_file_tuple)
        result.append({"source": f"{input_file_tuple[0]}-{input_file_tuple[1]}", "sentences": this_result})

    with open(output_file, "w") as f:
        json.dump(result, f, indent=2, sort_keys=True)


def group_tuples(input_tuples):
    return [(input_tuples[2 * i], input_tuples[2 * i + 1]) for i in range(int(len(input_tuples) / 2))]


def combine_two_jsons(json_target, json_addition):
    for k in json_addition:
        if k in json_target:
            json_target[k] += json_addition[k]
        else:
            json_target[k] = json_addition[k]


def combine_jsons(filelist):
    result = dict()

    for filename in filelist:
        data = json.load(open(filename))
        combine_two_jsons(result, data)

    print(json.dumps(result))


def _dev_to_dict(filename):
    result = defaultdict(lambda: defaultdict(int))

    for dev_sample in load_json_data(filename):
        for lang in dev_sample:
            if "dia" not in lang:
                result[lang][dev_sample[lang]] = 1

    return result


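# Train/dev contamination check: counts how many dev-set sentences (per language) also
# occur in the training file, and how many times.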
def check_cross_pollination(small_path, large_path):
    print("preparing dev set")
    dct = _dev_to_dict(small_path)

    print("reading train set")
    for train_sample in load_json_data(large_path):
        for lang in train_sample:
            if "dia" not in lang and lang in dct:
                snt = train_sample[lang]
                if snt in dct[lang]:
                    dct[lang][snt] += 1

    print("---------------------")
    print("contamination report:")
    print("---------------------")

    for lang in dct:
        total = 0
        counts = 0
        freqs = 0

        for snt in dct[lang]:
            total += 1
            if dct[lang][snt] > 1:
                counts += 1
                freqs += (dct[lang][snt] - 1)

        print(f"{lang}: contaminated: {counts} ({100*counts/float(total):.1f}%), total occurrence: {freqs}")


def char_class(c):
    lc = c.lower()

    if re.match("[a-z]", lc):
        return "latn"
    elif re.match("[а-я]", lc):
        return "cyrl"
    else:
        return "other"


def snt_is_fishy(snt_raw, lang, detailed=False):
    snt = re.sub(r'^<[^>]+> ', '', snt_raw)
    snt_db = defaultdict(int)

    for c in snt:
        c_c = char_class(c)
        snt_db[c_c] += 1

    tot = snt_db['latn'] + snt_db['cyrl']

    if tot > 0:
        if snt_db['latn'] / tot > 0.7:
            this_is = 'latn'
        elif snt_db['cyrl'] / tot > 0.7:
            this_is = 'cyrl'
        else:
            this_is = 'mix'

        should_be = any_to_nllb(lang).split("_")[1].lower()

        if should_be != this_is:
            return (True, this_is, should_be) if detailed else True

    return (False, None, None) if detailed else False


def script_stats():
    db = defaultdict(lambda: defaultdict(int))
    # corp = []

    for raw_line in sys.stdin:
        lang, snt_raw = raw_line.strip().split("\t")

        is_fishy, this_is, should_be = snt_is_fishy(snt_raw, lang, detailed=True)

        if is_fishy:
            print(f"{lang}: should be {should_be}, is actually {this_is}:\n{snt_raw}")


def get_full_lang(lang, tupl):
    dia_key = f"{lang}-dia"

    if dia_key in tupl:
        return f"{lang}, {tupl[dia_key]}"
    else:
        return lang


def convert_json_to_json(src_json, dest_json):
    raw_data = load_json_data(src_json)
    output_data = []

    for tupl in raw_data:
        for l1 in tupl:
            for l2 in tupl:
                if l1 != l2 and "dia" not in l1 and "dia" not in l2:
                    src_segm = tupl[l1]
                    tgt_segm = tupl[l2]
                    src_lang = get_full_lang(l1, tupl)
                    tgt_lang = get_full_lang(l2, tupl)

                    output_data.append({'task': 'translate',
                                        'src_segm': src_segm,
                                        'tgt_segm': tgt_segm,
                                        'src_lang': src_lang,
                                        'tgt_lang': tgt_lang})

    with open(dest_json, "w") as f:
        json.dump(output_data, f, indent=2)
"""

if __name__ == "__main__":
    # check_cross_pollination(sys.argv[1], sys.argv[2])
    # multi_moses_to_json(sys.argv[1], sys.argv[2], group_tuples(sys.argv[3:]))
    # combine_jsons(sys.argv[1:])
    # do_stats("data/train.json")
    # dump_to_stdout()
    # script_stats()
    # convert_json_to_json(sys.argv[1], sys.argv[2])
    pass