#!/usr/bin/env python3

import os
import torch

from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
from legacy.accel import SwitchingAccelerator
from accelerate import Accelerator

from data import MultilingualDatasetIterator
from aux import log, CmdlineArgs
from legacy.langconv import lang_set_maybe_smugri, is_dec_only_llm
from modelops import mdl_param_count, to_cpl_spec, hf_tok
from tokops import load_tokenizer


def freeze_model(model):
    # Disable gradient updates for all parameters (used for the anchor model,
    # which should stay fixed during training).
    for _, p in model.named_parameters():
        p.requires_grad = False


def load_hf_mdl_and_tok(mdl_id, tok_id=None, verbose=False):
    # Load a HuggingFace model together with its tokenizer; decoder-only LLMs
    # get a causal LM head, everything else a seq2seq head.
    if tok_id is None:
        tok_id = mdl_id

    tokenizer = load_tokenizer(tok_id)  # AutoTokenizer.from_pretrained(tok_id, token=hf_tok)

    if is_dec_only_llm(tokenizer[0]):
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)

    if verbose:
        mdl_size, _ = mdl_param_count(model)
        log(f"Loaded {mdl_id} with {mdl_size} params, voc size {model.config.vocab_size}")

    return model, tokenizer


def _cmdline_args():
    description = """Train or tune models"""

    pos_args = ["mdl_id", "save_location", "train_pretok_file", "langs"]
    pos_types = [str, str, str, lang_set_maybe_smugri]

    kw_args = {
        "anchor_mdl_id": None,
        "anchor_langs": None,
        "continue_training": False,
        "save_steps": 100000,
        "lr": 1.5e-5,
        "nr_snts_in_batch": 0,
        "nr_words_in_batch": 0,
        "log_steps": 100,
        "epochs": 4,
    }

    args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)

    # post-process the arguments
    if args.anchor_langs is not None:
        args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)

    # exactly one of the two batch-size options must be set
    if (args.nr_snts_in_batch > 0) == (args.nr_words_in_batch > 0):
        raise Exception("Specify the batch size either in sentences or in words, but not both.")

    # refuse to overwrite an existing save location unless training is being continued
    if not args.continue_training and os.path.exists(args.save_location):
        raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite.")

    return args


def yes_i_called_this_function_do_main():
    args = _cmdline_args()

    tmp_acc = Accelerator()
    log(f"Num proc: {tmp_acc.num_processes}, proc ID: {tmp_acc.process_index}")

    log("loading coupled model and tokenizer", accelerator=tmp_acc)
    main_model, main_tokenizer = load_hf_mdl_and_tok(args.mdl_id, verbose=True)

    coupling_specs = to_cpl_spec(args.langs, main_model, main_tokenizer[0], main_tokenizer[1], args.save_location)

    if args.anchor_mdl_id:
        log("loading anchor model and tokenizer", accelerator=tmp_acc)
        anchor_model, anchor_tokenizer = load_hf_mdl_and_tok(args.anchor_mdl_id, verbose=True)
        freeze_model(anchor_model)

        coupling_specs += to_cpl_spec(args.anchor_langs, anchor_model, anchor_tokenizer[0], anchor_tokenizer[1], args.anchor_mdl_id)

    train_set = MultilingualDatasetIterator(args.train_pretok_file)

    acc_trainer = SwitchingAccelerator(coupling_specs, train_set, args)

    upd_model, loss_list = acc_trainer.train()

    #save_all_models(args.save_location, upd_model, main_tokenizer, coupling_specs, loss_list, trainer=acc_trainer.accelerator)


if __name__ == "__main__":
    #sys.argv = ". models/smol models/smol_next data/smugri4a-dev.json-tokcache/thiscache.json smugri log_steps=1 lr=1e-5".split()
    #sys.argv = ". models/llama3.2-1b models/llama-tuned data/smugri4a-dev.json-tokcache/llama.json smugri".split()
    yes_i_called_this_function_do_main()
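
# Illustrative command-line usage, a sketch based on the argument definitions in
# _cmdline_args() and the commented-out sys.argv examples above. The script name
# "train.py" and the concrete argument values (batch size, learning rate) are
# placeholders, not confirmed by this repository:
#
#   python train.py models/smol models/smol_next \
#       data/smugri4a-dev.json-tokcache/thiscache.json smugri \
#       nr_snts_in_batch=64 lr=1e-5 log_steps=100
#
# Positional arguments: mdl_id, save_location, train_pretok_file, langs; keyword
# arguments are passed as key=value pairs. Exactly one of nr_snts_in_batch /
# nr_words_in_batch must be set, and save_location must not already exist unless
# continue_training=True.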