#!/usr/bin/env python3
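# Data loading, batching and caching utilities for multilingual translation
# model training. Note that everything below except the imports and the
# __main__ stub is currently disabled by being wrapped in module-level
# string literals.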
import json
#import os
import sys
import torch
#import re
import math
from torch.utils.data import IterableDataset
from collections import namedtuple, defaultdict
from random import randrange, shuffle, randint
#from pathlib import Path
#from aux import log
#from langconv import any_to_madlad, any_to_nllb, is_nllb, is_madlad, get_mdl_type, any_to_mdl_type, is_dec_only_llm, \
# base_to_nllb
#from tokops import tokenizeit
# TrPair = namedtuple('TrPair', ["src_lang", "tgt_lang", "input", "output"])
"""
def prep_llm_input(ljmftpl):
    # {'task': 'translate' / 'approx-translate' / 'generate',
    #  'src_segm': src_segm,
    #  'tgt_segm': tgt_segm,
    #  'src_lang': src_lang,
    #  'tgt_lang': tgt_lang}
    # it's a dict with a task specification
    if "src_segm" in ljmftpl and "task" in ljmftpl:
        if ljmftpl['task'] in {'translate', 'approx-translate'}:
            return (f"{ljmftpl['src_segm']}\n=====\n{ljmftpl['task']} from {ljmftpl['src_lang']}; " +
                    f"to {ljmftpl['tgt_lang']}:\n{ljmftpl['tgt_segm']}")
        elif ljmftpl['task'] == 'generate':
            return f"{ljmftpl['src_segm']}\n=====\nis in {ljmftpl['src_lang']};"
    # it's a plain string
    else:
        return ljmftpl
def make_path_compatible(filename):
    return filename.replace("/", "_").replace(":", "-")
def do_list_in_batches(data, batch_size):
    i = 0
    while i < len(data):
        yield data[i:i + batch_size]
        i += batch_size
"""
"""
def do_bins_in_batches(bins, batch_size, sort_by_length):
    result_list = []
    for src_k in bins:
        for tgt_k in bins[src_k]:
            if src_k == 0 or tgt_k == 0:
                result_list += [(e, src_k, tgt_k) for e in do_list_in_batches(bins[src_k][tgt_k], batch_size)]
    shuffle(result_list)
    return result_list
def _post_proc(text, lang):
    if lang == 'liv' and "’" in text and "O’R" not in text:
        return text.replace("’", "")
    else:
        return text
def clean_entry(entry, leave_out):
    result = {k: _post_proc(entry[k], k) for k in entry if entry[k].strip() and k not in leave_out}
    return result
def load_json_data(path, leave_out={}, skip_cats=True, load_mono=True):
    with open(path, 'r') as f:
        data = json.load(f)
    if skip_cats:
        # skip categories
        resx = [clean_entry(entry, leave_out)
                for cat in data for entry in cat['sentences']]
        res = [e for e in resx if e]
    else:
        raise NotImplementedError
        # resx = {cat['source']: [clean_entry(entry, leave_out) for entry in cat['sentences']] for cat in data}
        # res = {k: resx[k] for k in resx if resx[k]}
    return res
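# For reference, load_json_data expects the structure that
# multi_moses_to_json below produces: a list of categories, each holding a
# "source" label and a list of "sentences", i.e. dicts mapping language
# codes to parallel segments (an optional "<lang>-dia" key carries a
# dialect/variety tag). A minimal sketch with hypothetical values:
#
# [
#     {"source": "train.en-train.et",
#      "sentences": [
#          {"en": "Hello!", "et": "Tere!", "et-dia": "standard"}
#      ]}
# ]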
def get_tr_pairs(raw_data=None, filename=None, leave_out=None, leave_only=None, model_type=None, exclude_set=None):
    if filename is not None:
        raw_data = load_json_data(filename)
    if raw_data is None:
        raise ValueError("Neither a filename nor raw data was provided")
    i = 0
    log("Loading data")
    for tup in raw_data:
        for l1 in tup:
            for l2 in tup:
                if l1 != l2 and "dia" not in l1 and "dia" not in l2:
                    if leave_out is None or f"{l1}-{l2}" not in leave_out:
                        if leave_only is None or f"{l1}-{l2}" in leave_only:
                            i += 1
                            if not i % 1000000:
                                log(f"Loaded {i / 1000000}M pairs")
                            dia_key = f"{l2}-dia"
                            if exclude_set is None or (tup[l1] not in exclude_set[l1] and tup[l2] not in exclude_set[l2]):
                                input = tup[l1]
                                if dia_key in tup:
                                    input = f"<{tup[dia_key]}> {input}"
                                conv_l1 = any_to_mdl_type(model_type, l1)
                                conv_l2 = any_to_mdl_type(model_type, l2)
                                if not snt_is_fishy(input, conv_l1) and not snt_is_fishy(tup[l2], conv_l2):
                                    yield TrPair(conv_l1, conv_l2, input, tup[l2])
def split_by_lang(filename, model_type):
    result = defaultdict(list)
    # if filename is not None:
    #     tr_pairs = load_json_datax(filename)
    tr_pairs = get_tr_pairs(filename=filename, model_type=model_type)
    for tup in tr_pairs:
        # for l1 in tup:
        #     for l2 in tup:
        #         if l1 != l2 and "dia" not in l1 and "dia" not in l2:
        l1 = tup.src_lang
        l2 = tup.tgt_lang
        lp = f"{l1}-{l2}"
        result[lp].append((tup.input, tup.output))
    return result
def data_iter_for_tok_train(raw_data, langs_to_include):
    for tup in raw_data:
        for lang in tup:
            if lang in langs_to_include:
                yield tup[lang]
def lang_bin_mapping(coupling_specs):
    lang_to_idx = dict()
    for i, spec_pair in enumerate(coupling_specs):
        for lang in spec_pair.lang_set:
            if lang not in lang_to_idx:
                lang_to_idx[lang] = {i}
            else:
                lang_to_idx[lang].add(i)
    return lang_to_idx
def mix_and_sample_idxs_carefully(src_idxs, tgt_idxs):
    # skip (1, 1): those are the "frozen meaningless updates" reported below
    idx_pairs = [(s, t) for s in src_idxs for t in tgt_idxs if not (s == 1 and t == 1)]
    if len(idx_pairs) == 0:
        result = (None, None)
    else:
        pair_idx = randrange(len(idx_pairs))
        result = idx_pairs[pair_idx]
    # debug(f"src lang: {tr_pair.src_lang}, tgt_lang: {tr_pair.tgt_lang}, idx list: {idx_pairs}, result: {result}")
    return result
def inject_bin_indices(batch, src_k, tgt_k):
    batch['input_ids'][0, 0] += src_k << 30
    batch['labels'][0, 0] += tgt_k << 30
def get_data_cache_location(cache_meta_path, idx):
    cache_folder, cache_file = os.path.split(cache_meta_path)
    if cache_folder:
        Path(cache_folder).mkdir(parents=True, exist_ok=True)
    if cache_meta_path.endswith(".json"):
        return cache_meta_path[:-5] + f"_{idx:04}.pt"
    else:
        raise ValueError(f"Expected a json file for the cache meta-location ({cache_meta_path})")
def make_gen_text(src_lang, tgt_lang, input_text, output_text=None, tok=None):
    if input_text.startswith("<"):
        posit = input_text.find(">") + 1
        dialect = input_text[1:posit - 1]
        diatxt = f", variety: {dialect}"
        txt = input_text[posit + 1:]
    else:
        dialect = None
        diatxt = ""
        txt = input_text
    return (f"Translate:\n== From: {src_lang}\n== To: {tgt_lang}{diatxt}\n== Input: {txt}\n== Output: " +
            ("" if (output_text is None or tok is None) else f"{output_text}{tok.eos_token}"))
class MultilingualBatchingCachingDataset:
    def _post_proc_bins(self, bins):
        for src_k in bins:
            for tgt_k in bins[src_k]:
                # pad each bin with random duplicates so it splits into full batches
                while len(bins[src_k][tgt_k]) % self.args.batch_size != 0:
                    rnd_elem_idx = randrange(len(bins[src_k][tgt_k]))
                    rnd_elem = bins[src_k][tgt_k][rnd_elem_idx]
                    bins[src_k][tgt_k].append(rnd_elem)
                if self.args.sort_by_len:
                    bins[src_k][tgt_k] = sorted(bins[src_k][tgt_k], key=lambda e: len(e.input))
                else:
                    shuffle(bins[src_k][tgt_k])
        return bins

    def _get_idxs(self, tr_pair):
        src_idxs = self._lang_to_idx[tr_pair.src_lang]
        tgt_idxs = self._lang_to_idx[tr_pair.tgt_lang]
        return mix_and_sample_idxs_carefully(src_idxs, tgt_idxs)

    def _fill_bins(self):
        bins = defaultdict(lambda: defaultdict(list))
        for tr_pair in get_tr_pairs(filename=self.filename, model_type=self.model_type, exclude_set=self.exclude_set):
            src_bin_idx, tgt_bin_idx = self._get_idxs(tr_pair)
            if src_bin_idx is not None and tgt_bin_idx is not None:
                bins[src_bin_idx][tgt_bin_idx].append(tr_pair)
        return self._post_proc_bins(bins)

    def report_update_stats(self, bins):
        total = 0
        totalx = 0
        updates = 0
        duds = 0
        enc_count = 0
        dec_count = 0
        for src_k in bins:
            for tgt_k in bins[src_k]:
                l = len(bins[src_k][tgt_k])
                total += l
                if src_k == 0 or tgt_k == 0:
                    totalx += l
                updates += l * (1 - (src_k + tgt_k) / 2)
                enc_count += l * (1 - src_k)
                dec_count += l * (1 - tgt_k)
                if src_k == 1 and tgt_k == 1:
                    duds += 1
        # log(str(self._lang_to_idx))
        log(f"### Ratio of coupled model updates: {100 * updates / total:.2f}% ({100 * updates / totalx:.2f}%); " +
            f"frozen meaningless updates: {100 * duds / total:.2f}%; " +
            f"enc samples: {enc_count}, dec samples: {dec_count}")
    def tokenize_input(self, cplspec, input_list, rawbatch):
        src_tokenizer = cplspec.tokenizer
        src_tokenizer.src_lang = rawbatch[0].src_lang
        # prep_batch_grouped = src_tokenizer(text=input_list, return_tensors="pt",
        #                                    padding="longest", truncation=True, max_length=self.args.max_snt_len)
        prep_batch_grouped = tokenizeit((src_tokenizer, cplspec.postokenizer), input_list, self.args.max_snt_len, False)
        if is_nllb(src_tokenizer):
            # overwrite the first token of each sequence with its own NLLB language code
            src_lang_list = [any_to_nllb(e.src_lang) for e in rawbatch]
            src_lang_vec = src_tokenizer.convert_tokens_to_ids(src_lang_list)
            prep_batch_grouped['input_ids'][:, 0] = torch.tensor(src_lang_vec)
        return prep_batch_grouped

    def tokenize_output(self, tgttokenizer, tgtposttok, rawbatch):
        outputs = [e.output for e in rawbatch]
        tgttokenizer.tgt_lang = rawbatch[0].tgt_lang
        # labels = tgttokenizer(text_target=outputs, return_tensors="pt",
        #                       padding="longest", truncation=True, max_length=self.args.max_snt_len)
        labels = tokenizeit((tgttokenizer, tgtposttok), outputs, self.args.max_snt_len, True)
        if is_nllb(tgttokenizer):
            tgt_lang_list = [any_to_nllb(e.tgt_lang) for e in rawbatch]
            tgt_lang_vec = tgttokenizer.convert_tokens_to_ids(tgt_lang_list)
            labels['input_ids'][:, 0] = torch.tensor(tgt_lang_vec)
        return labels

    def tokenize_gen_batch(self, raw_batch):
        tokenizer = self.coupling_specs[0].tokenizer
        tokenizer.pad_token = '<|reserved_special_token_0|>'
        tokenizer.padding_side = 'left'
        texts = [make_gen_text(e.src_lang, e.tgt_lang, e.input, e.output, tokenizer) for e in raw_batch]
        # batch = tokenizer(texts, return_tensors="pt", max_length=512, truncation=True, add_special_tokens=True, padding=True)
        batch = tokenizeit((tokenizer, self.coupling_specs[0].postokenizer), texts, self.args.max_snt_len, False)
        return batch

    def tokenize_and_pad(self, raw_batch, src_k, tgt_k):
        tgt_tokenizer = self.coupling_specs[tgt_k].tokenizer
        tgt_postok = self.coupling_specs[tgt_k].postokenizer
        if is_madlad(tgt_tokenizer):
            # MADLAD marks the target language with a token prepended to the source text
            inputs = [f"{any_to_madlad(e.tgt_lang)} {e.input}" for e in raw_batch]
        else:
            inputs = [e.input for e in raw_batch]
        prep_batch_grouped = self.tokenize_input(self.coupling_specs[src_k], inputs, raw_batch)
        labels = self.tokenize_output(tgt_tokenizer, tgt_postok, raw_batch)
        prep_batch_grouped['labels'] = labels['input_ids']
        # inject_bin_indices(prep_batch_grouped, src_k, tgt_k)
        # split_prep_batch = [{k: prep_batch_grouped[k][i] for k in prep_batch_grouped}
        #                     for i, trp in enumerate(raw_batch)]
        return prep_batch_grouped
    def _bins_to_tokenized_batched_cached_data(self, bins, cache_path):
        shard_i = 0
        batch_i = 0
        total_i = 0
        metainfo = []
        data = []
        log("Tokenizing data")
        for raw_batch, src_k, tgt_k in do_bins_in_batches(bins, self.args.batch_size, self.args.sort_by_len):
            batch_i += 1
            if not batch_i % 10000:
                log(f"Tokenized {batch_i + shard_i * self.args.shard_size} batches (shard {shard_i})")
            if is_dec_only_llm(self.coupling_specs[tgt_k].tokenizer):
                prepared_batch = self.tokenize_gen_batch(raw_batch)
                data.append((prepared_batch, total_i))
            else:
                prepared_batch = self.tokenize_and_pad(raw_batch, src_k, tgt_k)
                data.append((prepared_batch, src_k, tgt_k, total_i))
            if batch_i >= self.args.shard_size:
                shard_i += 1
                batch_i = 0
                fn = self._save_cache_file(data, cache_path, shard_i)
                metainfo.append({'shard_filename': fn, 'shard_size': len(data)})
                del data
                data = []
            total_i += 1
        if len(data) > 0:
            fn = self._save_cache_file(data, cache_path, shard_i + 1)
            metainfo.append({'shard_filename': fn, 'shard_size': len(data)})
        with open(cache_path, 'w') as f:
            json.dump(metainfo, f)
        del data
    @staticmethod
    def _save_cache_file(data, cache_location, idx):
        cache_location = get_data_cache_location(cache_location, idx)
        if os.path.exists(cache_location):
            raise Exception("Cache already exists")
        torch.save(data, cache_location)
        log(f"Saved data into cache (shard {idx})")
        return cache_location

    def set_model_type(self):
        result = None
        for spec_tuple in self.coupling_specs:
            this_type = get_mdl_type(spec_tuple.tokenizer)
            if result is None:
                result = this_type
            else:
                assert result == this_type, \
                    "in this implementation model types (NLLB/MADLAD/...) must be the same for all included models"
        return result

    def __init__(self, tr_file, coupling_specs, args):
        self.args = args
        self.filename = tr_file
        self.coupling_specs = coupling_specs
        self.exclude_set = _dev_to_dict(args.exclude_set) if args.exclude_set is not None else None
        self.model_type = self.set_model_type()
        # init lang to idx
        self._lang_to_idx = lang_bin_mapping(coupling_specs)

    def load_and_cache_data(self, cache_path):
        # collect data into bins and cache it
        bins = self._fill_bins()
        self.report_update_stats(bins)
        self._bins_to_tokenized_batched_cached_data(bins, cache_path)
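# A minimal usage sketch, assuming coupling_specs entries carry the fields
# used above (tokenizer, postokenizer, lang_set) and args provides
# batch_size, sort_by_len, max_snt_len, shard_size and exclude_set:
#
# ds = MultilingualBatchingCachingDataset("data/train.json", coupling_specs, args)
# ds.load_and_cache_data("cache/meta.json")
#
# This writes tokenized batch shards to cache/meta_0001.pt, ... and their
# index to cache/meta.json, which MultilingualDatasetIterator below reads.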
"""
"""
class DataState:
    def __init__(self, elem_idx=0, shard_idx=0, epoch_idx=None):
        self.elem_idx = elem_idx
        self.shard_idx = shard_idx
        self.epoch_idx = epoch_idx

    def state_dict(self):
        return {'elem_idx': self.elem_idx, 'shard_idx': self.shard_idx, 'epoch_idx': self.epoch_idx}

    def load_state_dict(self, state_dict):
        self.elem_idx = state_dict['elem_idx']
        self.shard_idx = state_dict['shard_idx']
        self.epoch_idx = state_dict['epoch_idx']

    def copy_from(self, src_ds, epoch_idx=None):
        self.shard_idx = src_ds.shard_idx
        self.elem_idx = src_ds.elem_idx
        if epoch_idx is not None:
            self.epoch_idx = epoch_idx

    def __str__(self):
        return 'DataState(elem_idx={}, shard_idx={}, epoch_idx={})'.format(self.elem_idx, self.shard_idx, self.epoch_idx)

    def __repr__(self):
        return self.__str__()
class BatchingIterator(IterableDataset):
    def __init__(self, segment_list, batch_size, tokenizer, max_len=8000):
        self.data = segment_list
        shuffle(self.data)
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.curr_elem_idx = 0
        self.data_len = math.ceil(len(self.data) / self.batch_size)

    def __len__(self):
        return self.data_len

    def __iter__(self):
        self.curr_elem_idx = 0
        return self

    def where_are_we(self):
        return DataState(shard_idx=0, elem_idx=self.curr_elem_idx)

    def thats_where(self, data_state):
        self.curr_elem_idx = data_state.elem_idx

    def _get_properly_sized_segment_list(self):
        i = self.curr_elem_idx * self.batch_size
        segment_list = self.data[i:i + self.batch_size]
        if len(segment_list) < self.batch_size:
            # pad the final short batch with random repeats of its own elements
            orig_len = len(segment_list)
            while len(segment_list) < self.batch_size:
                segment_list.append(segment_list[randint(0, orig_len - 1)])
        return segment_list

    def _tokenize(self, segment_list):
        # {'task': 'translate',
        #  'src_segm': src_segm,
        #  'tgt_segm': tgt_segm,
        #  'src_lang': src_lang,
        #  'tgt_lang': tgt_lang}
        prepped_segm_list = [prep_llm_input(s) for s in segment_list]
        self.tokenizer.pad_token = '<|reserved_special_token_0|>'
        tokenized_batch = self.tokenizer(prepped_segm_list, return_tensors="pt", max_length=self.max_len,
                                         truncation=True, add_special_tokens=True, padding=True)
        return tokenized_batch, self.curr_elem_idx + 1

    def __next__(self):
        if self.curr_elem_idx >= self.data_len:
            raise StopIteration
        else:
            segment_list = self._get_properly_sized_segment_list()
            batch = self._tokenize(segment_list)
            self.curr_elem_idx += 1
            return batch
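# A minimal usage sketch with a hypothetical HF tokenizer and a list of
# task dicts in the format expected by prep_llm_input:
#
# it = BatchingIterator(task_dicts, batch_size=8, tokenizer=tok)
# for tokenized_batch, next_idx in it:
#     pass  # each item is a padded batch plus the index to resume from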
"""
"""
class MultilingualDatasetIterator(IterableDataset):
    def _load_metafile(self, cache_metafile):
        with open(cache_metafile, 'r') as f:
            self.metainfo = json.load(f)
        self.data_len = sum([e['shard_size'] for e in self.metainfo])

    def _init_curr_shard(self):
        cache_location = self.metainfo[self.curr_shard_idx]['shard_filename']
        self.curr_shard_data = torch.load(cache_location, weights_only=False)
        assert len(self.curr_shard_data) == self.metainfo[self.curr_shard_idx]['shard_size']

    def __init__(self, filename):
        self.curr_shard_idx = 0
        self.curr_elem_idx = 0
        self.prev_shard_sum_len = 0
        if filename is not None:
            self._load_metafile(filename)

    def __iter__(self):
        self._init_curr_shard()
        return self

    def where_are_we(self):
        return DataState(shard_idx=self.curr_shard_idx, elem_idx=self.curr_elem_idx)

    def thats_where(self, data_state):
        self.curr_shard_idx = data_state.shard_idx
        self.curr_elem_idx = data_state.elem_idx
        self.prev_shard_sum_len = sum([e['shard_size'] for i, e in enumerate(self.metainfo) if i < self.curr_shard_idx])

    def __next__(self):
        try:
            result_data = self.curr_shard_data[self.curr_elem_idx]
            self.curr_elem_idx += 1
        except IndexError:
            # current shard exhausted: move on to the next one, or stop after the last
            self.prev_shard_sum_len += self.metainfo[self.curr_shard_idx]['shard_size']
            self.curr_shard_idx += 1
            if self.curr_shard_idx >= len(self.metainfo):
                self.__init__(None)
                raise StopIteration
            else:
                self._init_curr_shard()
                self.curr_elem_idx = 0
                result_data = self.curr_shard_data[self.curr_elem_idx]
                self.curr_elem_idx += 1
        index_in_epoch = self.prev_shard_sum_len + self.curr_elem_idx
        return result_data, index_in_epoch
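# A minimal usage sketch, reading the meta-file written by
# load_and_cache_data above:
#
# it = MultilingualDatasetIterator("cache/meta.json")
# for batch_data, index_in_epoch in it:
#     pass  # where_are_we() / thats_where() support mid-epoch resumption
#
# Note: dump_to_stdout below calls upd_lc, which is not defined anywhere in
# this file; judging from its call sites, a minimal assumed implementation
# would be:
#
# def upd_lc(lc, lang, snt):
#     lc[lang] += 1      # one more segment counted for this language
#     return len(snt)    # caller accumulates this into the total length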
def dump_to_stdout():
    filename = sys.argv[1]
    lc_src = defaultdict(int)
    tot_len = 0
    tot_count = 0
    for tr_pair in get_tr_pairs(filename=filename):
        print(tr_pair.src_lang + "\t" + tr_pair.input + "\t" + tr_pair.tgt_lang + "\t" + tr_pair.output)
        tot_len += upd_lc(lc_src, tr_pair.src_lang, tr_pair.input)
        tot_len += upd_lc(lc_src, tr_pair.tgt_lang, tr_pair.output)
        tot_count += 2
    totes = sum(lc_src.values())
    for k in sorted(lc_src):
        sys.stderr.write(f"{k}: {100*lc_src[k]/totes:.1f}%\n")
    sys.stderr.write(f"Avg length: {tot_len/float(tot_count):.1f}\n")
def do_stats(filename):
    stats = defaultdict(int)
    raw_data = load_json_data(filename)
    for data in raw_data:
        langs = sorted([k for k in data.keys() if data[k].strip() != ""])
        stats["-".join(langs)] += 1
    for k in stats:
        print(k, stats[k])
def lang_from_name(filename):
    return filename.split(".")[-1]
def moses_to_json(file1, file2):
    result = list()
    l1 = lang_from_name(file1)
    l2 = lang_from_name(file2)
    with open(file1, "r") as h1, open(file2, "r") as h2:
        for line1 in h1:
            line2 = h2.readline()
            result.append({l1: line1.strip(), l2: line2.strip()})
    return result
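# E.g. moses_to_json("train.en", "train.et") pairs the two files line by
# line into [{"en": ..., "et": ...}, ...], with the language codes taken
# from the file extensions (hypothetical filenames).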
def multi_moses_to_json(output_file, init_json, input_file_tuples):
    try:
        with open(init_json, "r") as h:
            result = json.load(h)
    except (OSError, json.JSONDecodeError):
        # start from scratch if the initial JSON is missing or unreadable
        result = list()
    for input_file_tuple in input_file_tuples:
        this_result = moses_to_json(*input_file_tuple)
        result.append({"source": f"{input_file_tuple[0]}-{input_file_tuple[1]}", "sentences": this_result})
    with open(output_file, "w") as f:
        json.dump(result, f, indent=2, sort_keys=True)
def group_tuples(input_tuples):
    return [(input_tuples[2 * i], input_tuples[2 * i + 1]) for i in range(len(input_tuples) // 2)]
def combine_two_jsons(json_target, json_addition):
    for k in json_addition:
        if k in json_target:
            json_target[k] += json_addition[k]
        else:
            json_target[k] = json_addition[k]
def combine_jsons(filelist):
    result = dict()
    for filename in filelist:
        with open(filename) as f:
            data = json.load(f)
        combine_two_jsons(result, data)
    # print the combined result (the original dumped it to a discarded string)
    print(json.dumps(result))
def _dev_to_dict(filename):
    result = defaultdict(lambda: defaultdict(int))
    for dev_sample in load_json_data(filename):
        for lang in dev_sample:
            if "dia" not in lang:
                result[lang][dev_sample[lang]] = 1
    return result
def check_cross_pollination(small_path, large_path):
print("preparing dev set")
dct = _dev_to_dict(small_path)
print("reading train set")
for train_sample in load_json_data(large_path):
for lang in train_sample:
if not "dia" in lang and lang in dct:
snt = train_sample[lang]
if snt in dct[lang]:
dct[lang][snt] += 1
print("---------------------")
print("contamination report:")
print("---------------------")
for lang in dct:
total = 0
counts = 0
freqs = 0
for snt in dct[lang]:
total += 1
if dct[lang][snt] > 1:
counts += 1
freqs += (dct[lang][snt] - 1)
print(f"{lang}: contaminated: {counts} ({100*counts/float(total):.1f}%), total occurrence: {freqs}")
def char_class(c):
    lc = c.lower()
    if re.match("[a-z]", lc):
        return "latn"
    elif re.match("[а-я]", lc):
        return "cyrl"
    else:
        return "other"
def snt_is_fishy(snt_raw, lang, detailed=False):
    # strip a leading dialect tag like "<dia> " before counting characters
    snt = re.sub(r'^<[^>]+> ', '', snt_raw)
    snt_db = defaultdict(int)
    for c in snt:
        c_c = char_class(c)
        snt_db[c_c] += 1
    tot = snt_db['latn'] + snt_db['cyrl']
    if tot > 0:
        if snt_db['latn'] / tot > 0.7:
            this_is = 'latn'
        elif snt_db['cyrl'] / tot > 0.7:
            this_is = 'cyrl'
        else:
            this_is = 'mix'
        should_be = any_to_nllb(lang).split("_")[1].lower()
        if should_be != this_is:
            return (True, this_is, should_be) if detailed else True
    return (False, None, None) if detailed else False
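# E.g. a mostly-Cyrillic sentence labelled as Estonian (NLLB code est_Latn)
# is flagged as fishy: over 70% of its script-bearing characters are
# Cyrillic, while the expected script from the NLLB code is "latn".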
def script_stats():
    db = defaultdict(lambda: defaultdict(int))
    # corp = []
    for raw_line in sys.stdin:
        lang, snt_raw = raw_line.strip().split("\t")
        is_fishy, this_is, should_be = snt_is_fishy(snt_raw, lang, detailed=True)
        if is_fishy:
            print(f"{lang}: should be {should_be}, is actually {this_is}:\n{snt_raw}")
def get_full_lang(lang, tupl):
dia_key = f"{lang}-dia"
if dia_key in tupl:
return f"{lang}, {tupl[dia_key]}"
else:
return lang
def convert_json_to_json(src_json, dest_json):
    raw_data = load_json_data(src_json)
    output_data = []
    for tupl in raw_data:
        for l1 in tupl:
            for l2 in tupl:
                if l1 != l2 and "dia" not in l1 and "dia" not in l2:
                    src_segm = tupl[l1]
                    tgt_segm = tupl[l2]
                    src_lang = get_full_lang(l1, tupl)
                    tgt_lang = get_full_lang(l2, tupl)
                    output_data.append({'task': 'translate',
                                        'src_segm': src_segm,
                                        'tgt_segm': tgt_segm,
                                        'src_lang': src_lang,
                                        'tgt_lang': tgt_lang})
    with open(dest_json, "w") as f:
        json.dump(output_data, f, indent=2)
"""
if __name__ == "__main__":
    # check_cross_pollination(sys.argv[1], sys.argv[2])
    # multi_moses_to_json(sys.argv[1], sys.argv[2], group_tuples(sys.argv[3:]))
    # combine_jsons(sys.argv[1:])
    # do_stats("data/train.json")
    # dump_to_stdout()
    # script_stats()
    # convert_json_to_json(sys.argv[1], sys.argv[2])
    pass