#!/usr/bin/env python3
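"""Train or tune decoder (causal LM) models with Hugging Face Accelerate."""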
import json
import os
import sys  # only used by the commented-out debug invocations under __main__

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

from accel import SwitchingAccelerator
from aux import log, CmdlineArgs
from modelops import hf_tok, save_all_models


def _cmdline_args():
    description = """Train or tune decoder models"""
    result = CmdlineArgs(description,
                         pos_arg_list=["mdl_id", "save_location", "train_file"],
                         pos_arg_types=[str, str, str],
                         kw_arg_dict={"continue_training": False, "save_steps": 100, "lr": 1.5e-5,
                                      "batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
                                      "max_length": 3000})

    # refuse to overwrite an existing save location unless continuing a previous run
    if not result.continue_training and os.path.exists(result.save_location):
        raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")

    # nr_sents_per_gpu == 0 means "no sub-batching": put the full batch on each GPU
    if result.nr_sents_per_gpu == 0:
        result.nr_sents_per_gpu = result.batch_size

    return result
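
# CmdlineArgs is a project-local helper (aux.py); judging by the debug invocations
# under __main__, keyword options are passed as key=value tokens, e.g. batch_size=16.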


def load_json_list(json_file):
    # load the entire training set in one go: a single JSON array parsed into a Python list
    with open(json_file, "r") as f:
        data = json.load(f)
    return data
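
# Assumed input format (a sketch only; the actual per-sample schema is defined by
# SwitchingAccelerator in accel.py, not by this script):
#   [
#     {"prompt": "...", "response": "..."},
#     ...
#   ]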


def load_hf_model(mdl_id, accelerator=None):
    if accelerator is None:
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    else:
        # place the model directly on this process's device
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16,
                                                     device_map=accelerator.device)
    return model
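
# Note: _no_globals_main() below calls load_hf_model() without an accelerator, so
# device placement is presumably deferred to SwitchingAccelerator; the device_map
# branch is only exercised by callers that pass one in.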


def load_hf_tokenizer(mdl_id):
    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    return tokenizer
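
# Caveat (an assumption, not something this repo states): some decoder tokenizers,
# e.g. Llama's, ship without a pad token; if the trainer pads batches, something
# like `tokenizer.pad_token = tokenizer.eos_token` may be needed before tokenizing.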


def _no_globals_main():
    accelerator = Accelerator()
    try:
        args = _cmdline_args()
        log(f"Num proc: {accelerator.num_processes}, proc ID: {accelerator.process_index}")

        log("loading model", accelerator=accelerator)
        mdl = load_hf_model(args.mdl_id)

        log("loading tokenizer", accelerator=accelerator)
        tok = load_hf_tokenizer(args.mdl_id)

        log("loading data", accelerator=accelerator, all_threads=True)
        train_set = load_json_list(args.train_file)

        log("training", accelerator=accelerator)
        acc_trainer = SwitchingAccelerator(train_set, args, mdl, tok, preinit_acc=accelerator)
        upd_model = acc_trainer.train()

        log("saving", accelerator=accelerator)
        save_all_models(args.save_location, upd_model, tok)
    except Exception:
        # in multi-process runs the interleaved stack traces are hard to read,
        # so only the main process re-raises; the other processes exit silently
        if accelerator.is_main_process:
            raise


if __name__ == "__main__":
    # debug invocations, kept for reference:
    # sys.argv = "_ models/llama3.2-1b models/newmdl tmp.json".split()
    # sys.argv = "_ models/llama3.2-1b models/newmdl2 tmpx.json batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1".split()
    _no_globals_main()
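
# Example launch (a sketch based on the debug lines above; the script filename
# "train.py" is hypothetical, and the launcher depends on your setup, e.g.
# `accelerate launch` for multi-GPU runs):
#
#   accelerate launch train.py models/llama3.2-1b models/newmdl train.json \
#       batch_size=16 nr_sents_per_gpu=1 epochs=1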