File size: 3,027 Bytes
76b1ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3

import os
import json
import torch
import sys

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

from accel import SwitchingAccelerator
from modelops import hf_tok, save_all_models

from aux import log, CmdlineArgs
from data import do_list_in_batches


def _cmdline_args():
    """Build and validate the ``CmdlineArgs`` settings object for this script.

    Three positional string arguments are declared (``mdl_id``,
    ``save_location``, ``train_file``) together with the keyword arguments
    and defaults listed in ``keyword_defaults`` below. How the values are
    actually read is up to ``CmdlineArgs`` (defined in ``aux``, not shown
    here -- presumably from ``sys.argv``).

    Returns:
        The ``CmdlineArgs`` instance, with ``nr_sents_per_gpu`` replaced by
        ``batch_size`` when it was given as 0.

    Raises:
        Exception: if ``save_location`` already exists on disk and
            ``continue_training`` is falsy.
    """
    keyword_defaults = {
        "continue_training": False,
        "save_steps": 100,
        "lr": 1.5e-5,
        "batch_size": 1024,
        "nr_sents_per_gpu": 4,
        "log_steps": 1,
        "epochs": 4,
        "max_length": 3000,
    }

    parsed = CmdlineArgs(
        "Train or tune decoder models",
        pos_arg_list=["mdl_id", "save_location", "train_file"],
        pos_arg_types=[str, str, str],
        kw_arg_dict=keyword_defaults,
    )

    # An existing output path is only acceptable when resuming a previous run.
    resuming = parsed.continue_training
    if not resuming and os.path.exists(parsed.save_location):
        raise Exception(f"Save location '{parsed.save_location}' already exists, don't want to overwrite.")

    # A value of 0 is shorthand for "same as batch_size".
    if parsed.nr_sents_per_gpu == 0:
        parsed.nr_sents_per_gpu = parsed.batch_size

    return parsed


def load_json_list(json_file):
    """Read ``json_file`` and return its parsed JSON content.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's locale encoding, so non-ASCII text loads the same way on
    every machine.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The deserialized top-level JSON value (whatever the file contains).

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)


def load_hf_model(mdl_id, accelerator=None):
    """Load a causal language model in bfloat16 via ``AutoModelForCausalLM``.

    Args:
        mdl_id: Model identifier or path handed to ``from_pretrained``.
        accelerator: Optional object exposing a ``device`` attribute. When
            given, that device is passed as ``device_map``; when ``None``,
            no ``device_map`` argument is passed at all.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    load_kwargs = {"token": hf_tok, "torch_dtype": torch.bfloat16}

    # Only request explicit placement when the caller supplied an accelerator.
    if accelerator is not None:
        load_kwargs["device_map"] = accelerator.device

    return AutoModelForCausalLM.from_pretrained(mdl_id, **load_kwargs)


def load_hf_tokenizer(mdl_id):
    """Return the tokenizer that ``AutoTokenizer`` loads for ``mdl_id``.

    The module-level ``hf_tok`` value is passed as the ``token`` argument.
    """
    return AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)


def _no_globals_main():
    """Run the script end to end: parse args, load, train and save.

    Steps, in order: build the ``Accelerator``, read the command-line
    settings, load the model and tokenizer for ``args.mdl_id``, load the
    training data from ``args.train_file``, train through
    ``SwitchingAccelerator`` and save the result to ``args.save_location``.

    Error handling: to keep multi-process output readable, only the main
    process re-raises a failure (and therefore prints a stack trace). Every
    other process prints a single-line summary to stderr and exits with
    status 1, so the failure is still visible to whatever launched it
    instead of being reported as a clean exit.

    Raises:
        Exception: whatever the pipeline raised, on the main process.
        SystemExit: with code 1, on non-main processes after a failure.
    """
    accelerator = Accelerator()

    try:
        args = _cmdline_args()

        log(f"Num proc: {accelerator.num_processes}, proc ID: {accelerator.process_index}")
        log("loading model", accelerator=accelerator)
        # No accelerator is passed here, so load_hf_model adds no device_map.
        mdl = load_hf_model(args.mdl_id)

        log("loading tokenizer", accelerator=accelerator)
        tok = load_hf_tokenizer(args.mdl_id)

        log("loading data", accelerator=accelerator, all_threads=True)
        train_set = load_json_list(args.train_file)

        log("training", accelerator=accelerator)

        acc_trainer = SwitchingAccelerator(train_set, args, mdl, tok, preinit_acc=accelerator)
        upd_model = acc_trainer.train()

        log("saving", accelerator=accelerator)
        save_all_models(args.save_location, upd_model, tok)
    except Exception as e:
        # in multiprocess scenarios it is hard to read the stack trace, so just show one:
        if accelerator.is_main_process:
            raise

        # Non-main processes stay terse (no traceback) but must not pretend
        # to have succeeded: report briefly and exit with a failure status.
        print(f"Process {accelerator.process_index} failed: {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Debugging aid: uncomment one of the assignments below to replace the real
    # command line with canned arguments. The first token stands in for the
    # script name, followed by mdl_id, save_location and train_file, and then
    # optional key=value settings.
    #sys.argv = "_ models/llama3.2-1b models/newmdl tmp.json".split()
    #sys.argv = "_ models/llama3.2-1b models/newmdl2 tmpx.json batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1".split()

    _no_globals_main()