#!/usr/bin/env python3
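"""
Tokenizer utilities for multilingual MT models (NLLB, MADLAD/T5 and decoder-only LLMs):
training new SentencePiece tokenizers, extending existing tokenizers with language codes
and previously unseen tokens, and resizing the model's embeddings to match.
"""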
import os
import sentencepiece as spm
import json
from transformers import AutoTokenizer
from transformers.models.nllb import NllbTokenizer
from transformers.models.t5 import T5Tokenizer
from collections import defaultdict

from aux import log
from legacy.langconv import langs_to_madlad, langs_to_nllb, is_nllb, is_madlad, is_dec_only_llm
from modelops import hf_tok

def test_tok(tok, snt, lang):
    tok.src_lang = lang
    out = tok(text=snt)
    print(out['input_ids'])
    print(tok.tokenize(snt))
    print(tok.convert_ids_to_tokens(out['input_ids']))
    print("-")

def get_stupid_correction(mdl_id):
    l_mdl_id = mdl_id.lower()
    if "m2m" in l_mdl_id:
        correction = 108
    elif "nllb" in l_mdl_id:
        correction = 2
    else:
        correction = 0
    return correction
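
# Note on get_stupid_correction: the hard-coded offsets appear to compensate for the
# difference between len(tokenizer) and the embedding-matrix size of the corresponding
# pre-trained checkpoints (M2M-100 and NLLB reserve a few extra embedding rows); this
# explanation is an assumption, the values themselves are what matter.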

def tsv_to_json_vocab(location):
    new_location = location + ".json"
    with open(location, "r") as f, open(new_location, "w") as w:
        idx_dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
        for line in f:
            tok, _ = line.strip().split("\t")
            if tok not in idx_dict:
                idx_dict[tok] = len(idx_dict)
        json.dump(idx_dict, w)
    return new_location
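
# tsv_to_json_vocab expects the SentencePiece ".vocab" format, i.e. one
# "<token>\t<score>" pair per line, and turns it into a token-to-index JSON mapping
# with the first four indices reserved for the standard special tokens.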

def get_unk_toks(tokenizer, corpus, verbose=False):
    unk_id = tokenizer.unk_token_id
    unk_toks = defaultdict(int)
    all_toks = set()

    total_count = 0
    unk_count = 0

    with open(corpus, "r", encoding='utf-8') as f:
        for snt in f:
            toks = tokenizer.tokenize(snt.strip())
            ids = tokenizer.convert_tokens_to_ids(toks)
            for t, i in zip(toks, ids):
                if i == unk_id:
                    unk_toks[t] += 1
                    unk_count += 1
                total_count += 1
                all_toks.add(t)

    if verbose:
        print(f"Tokenizer vocab size: {tokenizer.vocab_size}, nr of actually used tokens: {len(all_toks)}")
        print(f"Corpus token count: {total_count}, UNK token percentage: {100*unk_count/total_count:.2f}%")

    return list(unk_toks)
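
# Example usage (hypothetical model id and corpus path):
#   tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", token=hf_tok)
#   unknown = get_unk_toks(tok, "data/train-corpus.txt", verbose=True)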

def get_top_toks(tokenizer, corpus, num_top_toks):
    freq_count = defaultdict(int)
    with open(corpus, "r", encoding='utf-8') as f:
        for snt in f:
            toks = tokenizer.tokenize(snt.strip())
            for t in toks:
                freq_count[t] += 1
    # tokens sorted by descending corpus frequency
    sorted_toks = sorted(freq_count.keys(), key=lambda x: -freq_count[x])
    return sorted_toks[:num_top_toks]

def extend_tok_langs(tokenizer, lang_set_raw):
    if is_nllb(tokenizer):
        lang_set = langs_to_nllb(lang_set_raw)
    elif is_madlad(tokenizer):
        lang_set = langs_to_madlad(lang_set_raw)
    elif is_dec_only_llm(tokenizer):
        return
    else:
        raise NotImplementedError

    if 'additional_special_tokens' in tokenizer.special_tokens_map:
        orig_langs = tokenizer.special_tokens_map['additional_special_tokens']
        orig_lang_set = set(orig_langs)
        addable_langs = list(set(lang_set) - orig_lang_set)
    else:
        orig_langs = []
        addable_langs = lang_set

    tokenizer.add_special_tokens({'additional_special_tokens': orig_langs + addable_langs})
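
# Note: add_special_tokens with 'additional_special_tokens' replaces the stored list
# (at least with the default replace_additional_special_tokens=True in recent
# transformers versions), which is presumably why the original language codes are
# passed again together with the newly added ones.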

def wrap_tok_in_correct_class(location, base_model_id, lang_set):
    l_base_mdl_id = base_model_id.lower()

    if "nllb" in l_base_mdl_id:
        nllb_lang_set = langs_to_nllb(lang_set)
        return NllbTokenizer(location + ".model", additional_special_tokens=nllb_lang_set)
    elif "madlad" in l_base_mdl_id or "t5" in l_base_mdl_id:
        madlad_lang_set = langs_to_madlad(lang_set)
        return T5Tokenizer(location + ".model", additional_special_tokens=madlad_lang_set)
    else:
        raise ValueError("Incompatible model type for tokenizer")

def remove_tmp_spm_files(location):
    for tmp_file in (".vocab", ".model"):
        os.remove(location + tmp_file)

def learn_spm_tokenizer(corpus, save_location, base_model_id, vocab_size, lang_set=None):
    tmp_location = os.path.join(save_location, "sentencepiece.bpe.tmp")
    os.makedirs(save_location, exist_ok=True)

    spm.SentencePieceTrainer.train(input=corpus, model_prefix=tmp_location, vocab_size=vocab_size)

    tok = wrap_tok_in_correct_class(tmp_location, base_model_id, lang_set)
    remove_tmp_spm_files(tmp_location)

    return tok
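
# SentencePieceTrainer.train writes "<model_prefix>.model" and "<model_prefix>.vocab";
# the .model file is wrapped into the matching transformers tokenizer class above and
# the temporary files are removed afterwards.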

def do_new_tok(tokargs):
    correction = get_stupid_correction(tokargs.mdl_id)
    voc_size = tokargs.vocab_size - correction
    location = tokargs.save_location

    return learn_spm_tokenizer(tokargs.tok_train_file, location, base_model_id=tokargs.tok_mdl_id,
                               vocab_size=voc_size, lang_set=tokargs.new_langs)

def remove_known_toks(toks, tokenizer):
    return [t for t in toks if t not in tokenizer.get_vocab()]

def _handle_new_tokenizer(args):
    assert args.new_langs is not None, "lang_set must be provided"
    assert args.tok_train_file is not None, "tok_train_file must be provided"

    args.vocab_size = int(args.vocab_size)
    log("Training new tokenizer")
    tokenizer = do_new_tok(args)

    return tokenizer

def get_postoken_filename(save_location):
    return os.path.join(save_location, "postokens.json")


def save_postokens(added_tokens, location):
    if added_tokens is not None:
        os.makedirs(location, exist_ok=True)
        with open(get_postoken_filename(location), "w") as f:
            json.dump(added_tokens, f)

def _handle_adding_tokens(tokenizer, toks_to_add, args):
    if len(toks_to_add) == 0:
        return None

    log(f"Adding tokens: {toks_to_add}")

    base_idx = len(tokenizer)
    added_tok_dict = {t: (base_idx + i) for i, t in enumerate(toks_to_add)}
    added_tok_rev_dict = {int(i): t for t, i in added_tok_dict.items()}
    comb_dict = {'tok2idx': added_tok_dict, 'idx2tok': added_tok_rev_dict}

    save_postokens(comb_dict, args.save_location)

    return comb_dict
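
# Note: the extra tokens are only recorded in postokens.json via save_postokens and are
# not registered with the tokenizer object itself; train_or_extend_tokenizer_and_upd_model
# below accounts for them by adding added_tok_count when resizing the model embeddings.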

def _handle_existing_tokenizer(args):
    log("Reusing existing tokenizer")
    tokenizer, added_tokens = load_tokenizer(args.tok_mdl_id)

    if args.new_langs is not None:
        log("Extending existing tokenizer with languages")
        extend_tok_langs(tokenizer, args.new_langs)

    if args.merge_tokenizers or args.merge_tok_mdl_id:
        """
        assert args.tok_train_file is not None, "For merging tokenizers a text file must be provided" \
            + " to find the top N tokens to merge"
        assert args.merge_tokenizers is not None and args.merge_tok_mdl_id is not None, \
            "Both merge_tokenizers and merge_tok_mdl_id must be provided"
        """
        raise NotImplementedError("Merging is currently not supported")

    added_tok_count = 0

    if args.tok_train_file:
        if args.merge_tokenizers:
            """
            merge_tok_max = int(args.merge_tokenizers)
            log(f"Extending existing tokenizer ({args.merge_tok_mdl_id}) with up to {merge_tok_max} top tokens" +
                f" from another tokenizer and corpus ({args.tok_train_file})")
            new_tok = AutoTokenizer.from_pretrained(args.merge_tok_mdl_id, token=hf_tok)
            toks_to_maybe_add = get_top_toks(new_tok, args.tok_train_file, merge_tok_max)
            """
            raise NotImplementedError("Merging is currently not supported")
        else:
            log(f"Extending existing tokenizer with UNK tokens from corpus ({args.tok_train_file})")
            toks_to_maybe_add = get_unk_toks(tokenizer, args.tok_train_file, verbose=True)

        toks_to_add = remove_known_toks(toks_to_maybe_add, tokenizer)
        added_tok_count = len(toks_to_add)

        added_tokens = _handle_adding_tokens(tokenizer, toks_to_add, args)

    return tokenizer, added_tok_count, added_tokens

def train_or_extend_tokenizer_and_upd_model(args, model):
    if hasattr(args, "vocab_size") and args.vocab_size:
        # train a new sentence-piece tokenizer
        tokenizer = _handle_new_tokenizer(args)
        added_tok_count = 0
        added_dict = None
    else:
        # reuse the pre-trained model's tokenizer, possibly adding new languages and tokens
        tokenizer, added_tok_count, added_dict = _handle_existing_tokenizer(args)

    upd_amt = get_stupid_correction(args.mdl_id)
    new_len = len(tokenizer) + added_tok_count

    model.resize_token_embeddings(new_len + upd_amt)

    return tokenizer, added_dict
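
# Sketch of a call site (hypothetical; the args object is expected to provide mdl_id,
# tok_mdl_id, save_location, new_langs, tok_train_file and, when training from scratch,
# vocab_size):
#   tokenizer, added_dict = train_or_extend_tokenizer_and_upd_model(args, model)
#   tokenizer.save_pretrained(args.save_location)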

def load_tokenizer(tok_mdl_id):
    orig_tokenizer = AutoTokenizer.from_pretrained(tok_mdl_id, token=hf_tok)
    postoken_file = get_postoken_filename(tok_mdl_id)

    if os.path.exists(postoken_file):
        with open(postoken_file, "r") as f:
            postokens = json.load(f)
    else:
        postokens = None

    return orig_tokenizer, postokens
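
# tokenize_batch (below) sets the pad token to EOS (handy for decoder-only models that
# come without a pad token) and copies input_ids into "labels" for language-model
# training; masking of padded positions in the loss is assumed to happen downstream.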

def tokenize_batch(tokenizer, sntlist, maxlen=8000):
    #tokenizer.pad_token = '<|reserved_special_token_0|>'
    tokenizer.pad_token = tokenizer.eos_token
    output = tokenizer(sntlist, return_tensors="pt", max_length=maxlen, truncation=True, add_special_tokens=True,
                       padding=True)
    output["labels"] = output["input_ids"].detach().clone()
    return output


"""
def detokenizeit(toktup, tok_ids):
    #return toktup[0].decode(tok_ids, skip_special_tokens=True)
    toks = []
    for tok_id_tensor in tok_ids:
        tok_id = tok_id_tensor.item()
        try:
            if tok_id not in toktup[0].added_tokens_decoder:
                toks.append(toktup[0].convert_ids_to_tokens(tok_id))
        except IndexError:
            toks.append(toktup[1]['idx2tok'][str(tok_id)])
    result = "".join(toks).replace("▁", " ")[1:]
    return result, toks


def detokenizemany(toktup, tok_mtx):
    result = [detokenizeit(toktup, tok_ids)[0] for tok_ids in tok_mtx]
    return result


def run_tokenizer_testing():
    args = CmdlineArgs("Test a tokenizer: tokenize & de-tokenize some text and check if these match",
                       pos_arg_list=["tok_mdl_id", "txt_file"])
    #tokenizer = AutoTokenizer.from_pretrained(args.tok_mdl_id, token=hf_tok) if os.path.exists()
    toktup = load_tokenizer(args.tok_mdl_id)

    success = 0
    failure = 0

    with open(args.txt_file, "r", encoding="utf-8") as f:
        snts = f.read().split("\n")
        toks = tokenizeit(toktup, snts, 1024, False)

        for i, snt in enumerate(snts):
            tok_ids = toks['input_ids'][i]
            #detoks = toktup[0].decode(tok_ids, skip_special_tokens=True)
            detoks, tok_strs = detokenizeit(toktup, tok_ids)

            if detoks != snt:
                failure += 1
                #log(f"Tokens: {toktup[0].convert_ids_to_tokens(tok_ids)}")
                log(f"Tokens: {tok_strs}")
                log(f"Test failed:\n{snt} !=\n{detoks}")
            else:
                success += 1
            i += 1

    log(f"Test result: {success} successful / {failure} failed")


if __name__ == "__main__":
    sys.argv = ['', 'models/nllbxt', 'data/tok-test.txt']
    run_tokenizer_testing()
"""