#!/usr/bin/env python3
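"""Training/tuning entry point.

Loads the main HuggingFace model and tokenizer (and optionally a frozen "anchor"
model), builds coupling specs for the given language sets, and trains on a
pre-tokenized multilingual dataset via SwitchingAccelerator.
"""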
import os
import torch
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
from legacy.accel import SwitchingAccelerator
from accelerate import Accelerator
from data import MultilingualDatasetIterator
from aux import log, CmdlineArgs
from legacy.langconv import lang_set_maybe_smugri, is_dec_only_llm
from modelops import mdl_param_count, to_cpl_spec, hf_tok
from tokops import load_tokenizer


def freeze_model(model):
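    """Disable gradient updates for every parameter of the given model."""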
    for n, p in model.named_parameters():
        p.requires_grad = False


def load_hf_mdl_and_tok(mdl_id, tok_id=None, verbose=False):
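    """Load a HuggingFace model together with its tokenizer.

    The tokenizer is loaded from tok_id (defaulting to mdl_id); decoder-only LLMs
    get a causal LM head, everything else a seq2seq head, both in bfloat16.
    """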
    if tok_id is None:
        tok_id = mdl_id

    tokenizer = load_tokenizer(tok_id)  # AutoTokenizer.from_pretrained(tok_id, token=hf_tok)

    if is_dec_only_llm(tokenizer[0]):
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)

    if verbose:
        mdl_size, _ = mdl_param_count(model)
        log(f"Loaded {mdl_id} with {mdl_size} params, voc size {model.config.vocab_size}")

    return model, tokenizer


def _cmdline_args():
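    """Parse and validate the command-line arguments for training/tuning."""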
    description = """Train or tune models"""

    pos_args = ["mdl_id", "save_location", "train_pretok_file", "langs"]
    pos_types = [str, str, str, lang_set_maybe_smugri]

    kw_args = {"anchor_mdl_id": None, "anchor_langs": None, "continue_training": False,
               "save_steps": 100000, "lr": 1.5e-5, "nr_snts_in_batch": 0, "nr_words_in_batch": 0,
               "log_steps": 100, "epochs": 4}

    args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)

    # post-process the arguments
    if args.anchor_langs is not None:
        args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)

    # exactly one of the two batch-size limits must be given
    if (args.nr_snts_in_batch > 0) == (args.nr_words_in_batch > 0):
        raise Exception("Specify the batch size either in words or in sentences (exactly one of the two).")

    # if the directory args.save_location already exists, raise an exception (unless continuing training)
    if not args.continue_training and os.path.exists(args.save_location):
        raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite.")

    return args


def yes_i_called_this_function_do_main():
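    """Load the model(s), build the coupling specs and run training."""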
    args = _cmdline_args()

    tmp_acc = Accelerator()
    log(f"Num proc: {tmp_acc.num_processes}, proc ID: {tmp_acc.process_index}")

    log("loading coupled model and tokenizer", accelerator=tmp_acc)
    main_model, main_tokenizer = load_hf_mdl_and_tok(args.mdl_id, verbose=True)
    coupling_specs = to_cpl_spec(args.langs, main_model, main_tokenizer[0], main_tokenizer[1], args.save_location)

    if args.anchor_mdl_id:
        log("loading anchor model and tokenizer", accelerator=tmp_acc)
        anchor_model, anchor_tokenizer = load_hf_mdl_and_tok(args.anchor_mdl_id, verbose=True)
        freeze_model(anchor_model)
        coupling_specs += to_cpl_spec(args.anchor_langs, anchor_model, anchor_tokenizer[0], anchor_tokenizer[1], args.anchor_mdl_id)

    train_set = MultilingualDatasetIterator(args.train_pretok_file)

    acc_trainer = SwitchingAccelerator(coupling_specs, train_set, args)
    upd_model, loss_list = acc_trainer.train()
    #save_all_models(args.save_location, upd_model, main_tokenizer, coupling_specs, loss_list, trainer=acc_trainer.accelerator)


if __name__ == "__main__":
    #sys.argv = ". models/smol models/smol_next data/smugri4a-dev.json-tokcache/thiscache.json smugri log_steps=1 lr=1e-5".split()
    #sys.argv = ". models/llama3.2-1b models/llama-tuned data/smugri4a-dev.json-tokcache/llama.json smugri".split()
    yes_i_called_this_function_do_main()