File size: 3,733 Bytes
76b1ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3

import os
import torch

from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM

from legacy.accel import SwitchingAccelerator
from accelerate import Accelerator
from data import MultilingualDatasetIterator
from aux import log, CmdlineArgs
from legacy.langconv import lang_set_maybe_smugri, is_dec_only_llm
from modelops import mdl_param_count, to_cpl_spec, hf_tok
from tokops import load_tokenizer


def freeze_model(model):
    """Disable gradient tracking for every parameter of ``model``, in place.

    Used in this script on the anchor model so its weights are excluded from
    gradient computation during training.
    """
    # parameters() yields the same tensors as named_parameters(); the names were unused.
    for param in model.parameters():
        param.requires_grad = False


def load_hf_mdl_and_tok(mdl_id, tok_id=None, verbose=False, *, torch_dtype=torch.bfloat16):
    """Load a Hugging Face model together with its tokenizer.

    The model class is chosen from the tokenizer: if ``is_dec_only_llm`` is true
    for the first element returned by ``load_tokenizer``, the model is loaded with
    ``AutoModelForCausalLM``, otherwise with ``AutoModelForSeq2SeqLM``.

    Args:
        mdl_id: model identifier or path passed to ``from_pretrained``.
        tok_id: tokenizer identifier or path; defaults to ``mdl_id``.
        verbose: if true, log the parameter count and vocabulary size.
        torch_dtype: dtype the weights are loaded in (keyword-only; defaults to
            ``torch.bfloat16``, the previously hard-coded value).

    Returns:
        ``(model, tokenizer)``, where ``tokenizer`` is whatever ``load_tokenizer``
        returns; callers in this file index it as ``tokenizer[0]`` / ``tokenizer[1]``.
    """
    if tok_id is None:
        tok_id = mdl_id

    tokenizer = load_tokenizer(tok_id)

    # Decoder-only vs encoder-decoder is decided by the tokenizer, not the model config.
    model_cls = AutoModelForCausalLM if is_dec_only_llm(tokenizer[0]) else AutoModelForSeq2SeqLM
    # hf_tok comes from modelops; presumably a Hugging Face access token (may be None).
    model = model_cls.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch_dtype)

    if verbose:
        mdl_size, _ = mdl_param_count(model)
        log(f"Loaded {mdl_id} with {mdl_size} params, voc size {model.config.vocab_size}")

    return model, tokenizer


def _cmdline_args():
    """Parse and validate the command-line arguments for training.

    Positional arguments: ``mdl_id``, ``save_location``, ``train_pretok_file``
    and ``langs`` (the latter converted with ``lang_set_maybe_smugri``).
    Keyword options and their defaults are listed in ``kw_args`` below.

    Returns:
        The ``CmdlineArgs`` object, with ``anchor_langs`` converted when given.

    Raises:
        ValueError: if neither or both of ``nr_snts_in_batch`` /
            ``nr_words_in_batch`` are positive.
        FileExistsError: if ``save_location`` already exists and
            ``continue_training`` is not set.
    """
    description = """Train or tune models"""

    pos_args = ["mdl_id", "save_location", "train_pretok_file", "langs"]
    pos_types = [str, str, str, lang_set_maybe_smugri]

    kw_args = { "anchor_mdl_id": None, "anchor_langs": None, "continue_training": False,
                "save_steps": 100000, "lr": 1.5e-5, "nr_snts_in_batch": 0, "nr_words_in_batch": 0,
                "log_steps": 100, "epochs": 4 }

    args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)

    # post-process the arguments

    # anchor_langs is a keyword option (not covered by pos_types), so it is
    # converted here; presumably it arrives as a raw string.
    if args.anchor_langs is not None:
        args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)

    # Exactly one batch-size unit must be chosen (sentences XOR words).
    if (args.nr_snts_in_batch > 0) == (args.nr_words_in_batch > 0):
        raise ValueError("Specify the batch size either in words or in sentences.")

    # Refuse to clobber an existing output location unless explicitly continuing.
    if not args.continue_training and os.path.exists(args.save_location):
        raise FileExistsError(f"Save location '{args.save_location}' already exists, don't want to overwrite.")

    return args


def yes_i_called_this_function_do_main():
    """Script entry point: parse arguments, load models, and run training.

    The main ("coupled") model's coupling spec is built with ``args.save_location``
    as its location. If ``anchor_mdl_id`` is given, that model is loaded,
    frozen, and its spec (built with ``args.anchor_mdl_id`` as location) is
    appended. Training is delegated to ``SwitchingAccelerator``. Saving of the
    result is not performed here; the call below is commented out.
    """
    args = _cmdline_args()

    # A plain Accelerator, used in this function only for process info and
    # process-aware logging.
    log_acc = Accelerator()
    log(f"Num proc: {log_acc.num_processes}, proc ID: {log_acc.process_index}")

    log("loading coupled model and tokenizer", accelerator=log_acc)
    coupled_mdl, coupled_tok = load_hf_mdl_and_tok(args.mdl_id, verbose=True)
    specs = to_cpl_spec(args.langs, coupled_mdl, coupled_tok[0], coupled_tok[1], args.save_location)

    if args.anchor_mdl_id:
        log("loading anchor model and tokenizer", accelerator=log_acc)
        anchor_mdl, anchor_tok = load_hf_mdl_and_tok(args.anchor_mdl_id, verbose=True)
        # Anchor weights get requires_grad=False; presumably only the coupled model is updated.
        freeze_model(anchor_mdl)
        specs += to_cpl_spec(args.anchor_langs, anchor_mdl, anchor_tok[0], anchor_tok[1], args.anchor_mdl_id)

    trainer = SwitchingAccelerator(specs, MultilingualDatasetIterator(args.train_pretok_file), args)

    updated_mdl, losses = trainer.train()

    # TODO: saving is currently disabled; previous call was:
    # save_all_models(args.save_location, updated_mdl, coupled_tok, specs, losses, trainer=trainer.accelerator)


if __name__ == "__main__":
    # Example invocations. Positional args are mdl_id, save_location,
    # train_pretok_file and langs, followed by optional key=value options:
    #   python <this script> models/smol models/smol_next data/smugri4a-dev.json-tokcache/thiscache.json smugri log_steps=1 lr=1e-5
    #   python <this script> models/llama3.2-1b models/llama-tuned data/smugri4a-dev.json-tokcache/llama.json smugri
    yes_i_called_this_function_do_main()