#!/usr/bin/env python3
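"""Tokenizer utilities: train a new SentencePiece tokenizer or extend a pre-trained
HF tokenizer with new language codes and previously-UNK tokens, and keep the model's
embedding matrix in sync with the resulting vocabulary."""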
import os
import sentencepiece as spm
import json
from transformers import AutoTokenizer
from transformers.models.nllb import NllbTokenizer
from transformers.models.t5 import T5Tokenizer
from collections import defaultdict
from aux import log
from legacy.langconv import langs_to_madlad, langs_to_nllb, is_nllb, is_madlad, is_dec_only_llm
from modelops import hf_tok
def test_tok(tok, snt, lang):
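    """Debug helper: tokenize `snt` with `lang` as source language and print the ids,
    the tokens from tokenize(), and the tokens recovered from the ids."""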
tok.src_lang = lang
    out = tok(text=snt)
print(out['input_ids'])
print(tok.tokenize(snt))
print(tok.convert_ids_to_tokens(out['input_ids']))
print("-")
def get_stupid_correction(mdl_id):
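    """Return the model-specific offset between len(tokenizer) and the size of the
    model's embedding matrix (empirical values, assumed to match the checkpoints used)."""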
l_mdl_id = mdl_id.lower()
if "m2m" in l_mdl_id:
correction = 108
elif "nllb" in l_mdl_id:
correction = 2
else:
correction = 0
return correction
def tsv_to_json_vocab(location):
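    """Convert a SentencePiece .vocab TSV file (token<TAB>score) into a JSON
    token-to-index mapping, reserving the first indices for special tokens.
    Returns the path of the written JSON file."""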
new_location = location + ".json"
with open(location, "r") as f, open(new_location, "w") as w:
        # special tokens first (fairseq/NLLB-style dictionary order assumed here)
        idx_dict = { "<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3 }
for line in f:
tok, _ = line.strip().split("\t")
if tok not in idx_dict:
idx_dict[tok] = len(idx_dict)
json.dump(idx_dict, w)
return new_location
def get_unk_toks(tokenizer, corpus, verbose=False):
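    """Tokenize `corpus` and return the distinct tokens that map to the tokenizer's UNK id.
    With verbose=True, also report vocabulary usage and the UNK percentage of the corpus."""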
unk_id = tokenizer.unk_token_id
unk_toks = defaultdict(int)
all_toks = set()
total_count = 0
unk_count = 0
with open(corpus, "r", encoding='utf-8') as f:
for snt in f:
toks = tokenizer.tokenize(snt.strip())
ids = tokenizer.convert_tokens_to_ids(toks)
for t, i in zip(toks, ids):
if i == unk_id:
unk_toks[t] += 1
unk_count += 1
total_count += 1
all_toks.add(t)
if verbose:
print(f"Tokenizer vocab size: {tokenizer.vocab_size}, nr of actually used tokens: {len(all_toks)}")
print(f"Corpus token count: {total_count}, UNK token percentage: {100*unk_count/total_count:.2f}%")
return list(unk_toks)
def get_top_toks(tokenizer, corpus, num_top_toks):
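    """Return the `num_top_toks` most frequent tokens produced by `tokenizer` on `corpus`."""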
freq_count = defaultdict(int)
with open(corpus, "r", encoding='utf-8') as f:
for snt in f:
toks = tokenizer.tokenize(snt.strip())
for t in toks:
freq_count[t] += 1
    sorted_toks = sorted(freq_count, key=freq_count.get, reverse=True)
    return sorted_toks[:num_top_toks]
def extend_tok_langs(tokenizer, lang_set_raw):
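    """Register the given language codes as additional special tokens, converting them
    to the naming scheme of the model family (NLLB or MADLAD). Decoder-only LLM
    tokenizers are left unchanged."""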
if is_nllb(tokenizer):
lang_set = langs_to_nllb(lang_set_raw)
elif is_madlad(tokenizer):
lang_set = langs_to_madlad(lang_set_raw)
elif is_dec_only_llm(tokenizer):
return
else:
raise NotImplementedError
if 'additional_special_tokens' in tokenizer.special_tokens_map:
orig_langs = tokenizer.special_tokens_map['additional_special_tokens']
orig_lang_set = set(orig_langs)
addable_langs = list(set(lang_set) - orig_lang_set)
else:
orig_langs = []
addable_langs = lang_set
tokenizer.add_special_tokens({'additional_special_tokens': orig_langs + addable_langs})
def wrap_tok_in_correct_class(location, base_model_id, lang_set):
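    """Wrap a trained SentencePiece model in the tokenizer class matching the base model
    (NllbTokenizer for NLLB, T5Tokenizer for MADLAD/T5), with language codes added as
    special tokens."""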
l_base_mdl_id = base_model_id.lower()
if "nllb" in l_base_mdl_id:
nllb_lang_set = langs_to_nllb(lang_set)
return NllbTokenizer(location + ".model", additional_special_tokens=nllb_lang_set)
elif "madlad" in l_base_mdl_id or "t5" in l_base_mdl_id:
madlad_lang_set = langs_to_madlad(lang_set)
return T5Tokenizer(location + ".model", additional_special_tokens=madlad_lang_set)
else:
raise ValueError("Incompatible model type for tokenizer")
def remove_tmp_spm_files(location):
for tmp_file in (".vocab", ".model"):
os.remove(location + tmp_file)
def learn_spm_tokenizer(corpus, save_location, base_model_id, vocab_size, lang_set=None):
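    """Train a SentencePiece BPE model on `corpus` and return it wrapped in the tokenizer
    class matching `base_model_id`; the temporary SPM files are removed afterwards."""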
tmp_location = os.path.join(save_location, "sentencepiece.bpe.tmp")
os.makedirs(save_location, exist_ok=True)
spm.SentencePieceTrainer.train(input=corpus, model_prefix=tmp_location, vocab_size=vocab_size)
tok = wrap_tok_in_correct_class(tmp_location, base_model_id, lang_set)
remove_tmp_spm_files(tmp_location)
return tok
def do_new_tok(tokargs):
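    """Train a new tokenizer, reducing the requested vocab size by the model-specific
    correction (the correction is added back when the embeddings are resized)."""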
correction = get_stupid_correction(tokargs.mdl_id)
voc_size = tokargs.vocab_size - correction
location = tokargs.save_location
return learn_spm_tokenizer(tokargs.tok_train_file, location, base_model_id=tokargs.tok_mdl_id,
vocab_size=voc_size, lang_set=tokargs.new_langs)
def remove_known_toks(toks, tokenizer):
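    """Filter out tokens that are already present in the tokenizer's vocabulary."""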
    vocab = tokenizer.get_vocab()
    return [t for t in toks if t not in vocab]
def _handle_new_tokenizer(args):
assert args.new_langs is not None, "lang_set must be provided"
assert args.tok_train_file is not None, "tok_train_file must be provided"
args.vocab_size = int(args.vocab_size)
log("Training new tokenizer")
tokenizer = do_new_tok(args)
return tokenizer
def get_postoken_filename(save_location):
return os.path.join(save_location, "postokens.json")
def save_postokens(added_tokens, location):
if added_tokens is not None:
os.makedirs(location, exist_ok=True)
with open(get_postoken_filename(location), "w") as f:
json.dump(added_tokens, f)
def _handle_adding_tokens(tokenizer, toks_to_add, args):
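    """Assign ids beyond the current tokenizer length to `toks_to_add` and persist the
    mapping ("postokens") next to the saved model; returns the mapping, or None if there
    is nothing to add."""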
if len(toks_to_add) == 0:
return None
log(f"Adding tokens: {toks_to_add}")
base_idx = len(tokenizer)
added_tok_dict = { t: (base_idx + i) for i, t in enumerate(toks_to_add) }
added_tok_rev_dict = { int(i): t for t, i in added_tok_dict.items() }
comb_dict = { 'tok2idx': added_tok_dict, 'idx2tok': added_tok_rev_dict }
save_postokens(comb_dict, args.save_location)
return comb_dict
def _handle_existing_tokenizer(args):
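    """Reuse a pre-trained tokenizer, optionally extending it with new language codes and
    with corpus tokens that it currently maps to UNK. Merging two tokenizers is not
    supported yet."""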
log("Reusing existing tokenizer")
tokenizer, added_tokens = load_tokenizer(args.tok_mdl_id)
if args.new_langs is not None:
log("Extending existing tokenizer with languages")
extend_tok_langs(tokenizer, args.new_langs)
if args.merge_tokenizers or args.merge_tok_mdl_id:
"""
assert args.tok_train_file is not None, "For merging tokenizers a text file must be provided" \
+ " to find the top N tokens to merge"
assert args.merge_tokenizers is not None and args.merge_tok_mdl_id is not None, \
"Both merge_tokenizers and merge_tok_mdl_id must be provided"
"""
raise NotImplementedError("Merging is currently not supported")
added_tok_count = 0
if args.tok_train_file:
if args.merge_tokenizers:
"""
merge_tok_max = int(args.merge_tokenizers)
log(f"Extending existing tokenizer ({args.merge_tok_mdl_id}) with up to {merge_tok_max} top tokens" +
f" from another tokenizer and corpus ({args.tok_train_file})")
new_tok = AutoTokenizer.from_pretrained(args.merge_tok_mdl_id, token=hf_tok)
toks_to_maybe_add = get_top_toks(new_tok, args.tok_train_file, merge_tok_max)
"""
raise NotImplementedError("Merging is currently not supported")
else:
log(f"Extending existing tokenizer with UNK tokens from corpus ({args.tok_train_file})")
toks_to_maybe_add = get_unk_toks(tokenizer, args.tok_train_file, verbose=True)
toks_to_add = remove_known_toks(toks_to_maybe_add, tokenizer)
added_tok_count = len(toks_to_add)
added_tokens = _handle_adding_tokens(tokenizer, toks_to_add, args)
return tokenizer, added_tok_count, added_tokens
def train_or_extend_tokenizer_and_upd_model(args, model):
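    """Entry point: either train a new SentencePiece tokenizer (if args.vocab_size is set)
    or extend the pre-trained one, then resize the model's embeddings to match."""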
if hasattr(args, "vocab_size") and args.vocab_size:
# train a new sentence-piece tokenizer
tokenizer = _handle_new_tokenizer(args)
added_tok_count = 0
added_dict = None
else:
# save the pre-trained model's tokenizer, possibly adding new languages and tokens
tokenizer, added_tok_count, added_dict = _handle_existing_tokenizer(args)
upd_amt = get_stupid_correction(args.mdl_id)
new_len = len(tokenizer) + added_tok_count
model.resize_token_embeddings(new_len + upd_amt)
return tokenizer, added_dict
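# Example usage (sketch; assumes `args` carries the fields referenced above and that the
# model is a HF checkpoint exposing resize_token_embeddings):
#   model = AutoModelForSeq2SeqLM.from_pretrained(args.mdl_id, token=hf_tok)
#   tokenizer, added = train_or_extend_tokenizer_and_upd_model(args, model)
#   tokenizer.save_pretrained(args.save_location)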
def load_tokenizer(tok_mdl_id):
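    """Load a HF tokenizer and, if present, the JSON mapping of post-hoc added tokens."""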
orig_tokenizer = AutoTokenizer.from_pretrained(tok_mdl_id, token=hf_tok)
postoken_file = get_postoken_filename(tok_mdl_id)
if os.path.exists(postoken_file):
with open(postoken_file, "r") as f:
postokens = json.load(f)
else:
postokens = None
return orig_tokenizer, postokens
def tokenize_batch(tokenizer, sntlist, maxlen=8000):
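    """Tokenize a batch of sentences into padded tensors (EOS used as padding token) and
    copy input_ids into labels for LM-style training."""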
#tokenizer.pad_token = '<|reserved_special_token_0|>'
tokenizer.pad_token = tokenizer.eos_token
output = tokenizer(sntlist, return_tensors="pt", max_length=maxlen, truncation=True, add_special_tokens=True,
padding=True)
output["labels"] = output["input_ids"].detach().clone()
return output
"""
def detokenizeit(toktup, tok_ids):
#return toktup[0].decode(tok_ids, skip_special_tokens=True)
toks = []
for tok_id_tensor in tok_ids:
tok_id = tok_id_tensor.item()
try:
if tok_id not in toktup[0].added_tokens_decoder:
toks.append(toktup[0].convert_ids_to_tokens(tok_id))
except IndexError:
toks.append(toktup[1]['idx2tok'][str(tok_id)])
result = "".join(toks).replace("▁", " ")[1:]
return result, toks
def detokenizemany(toktup, tok_mtx):
result = [detokenizeit(toktup, tok_ids)[0] for tok_ids in tok_mtx]
return result
def run_tokenizer_testing():
args = CmdlineArgs("Test a tokenizer: tokenize & de-tokenize some text and check if these match",
pos_arg_list=["tok_mdl_id", "txt_file"])
#tokenizer = AutoTokenizer.fromm_pretrained(args.tok_mdl_id, token=hf_tok) if os.path.exists()
toktup = load_tokenizer(args.tok_mdl_id)
success = 0
failure = 0
with open(args.txt_file, "r", encoding="utf-8") as f:
snts = f.read().split("\n")
toks = tokenizeit(toktup, snts, 1024, False)
for i, snt in enumerate(snts):
tok_ids = toks['input_ids'][i]
#detoks = toktup[0].decode(tok_ids, skip_special_tokens=True)
detoks, tok_strs = detokenizeit(toktup, tok_ids)
if detoks != snt:
failure += 1
#log(f"Tokens: {toktup[0].convert_ids_to_tokens(tok_ids)}")
log(f"Tokens: {tok_strs}")
log(f"Test failed:\n{snt} !=\n{detoks}")
else:
success += 1
i += 1
log(f"Test result: {success} successful / {failure} failed")
if __name__ == "__main__":
sys.argv = ['', 'models/nllbxt', 'data/tok-test.txt']
run_tokenizer_testing()
"""