#!/usr/bin/env python3
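"""Train or fine-tune decoder (causal LM) models: load a HuggingFace model and tokenizer,
read a JSON list of training examples, train via SwitchingAccelerator, and save the result."""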
import os
import json
import torch
import sys
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from accel import SwitchingAccelerator
from modelops import hf_tok, save_all_models
from aux import log, CmdlineArgs
from data import do_list_in_batches
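
# Parse command-line arguments: three positional args (model ID, save location, training file)
# plus keyword options for batch size, learning rate, logging/saving intervals, etc.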
def _cmdline_args():
description = """Train or tune decoder models"""
result = CmdlineArgs(description,
pos_arg_list=["mdl_id", "save_location", "train_file"],
pos_arg_types=[str, str, str],
kw_arg_dict={ "continue_training": False, "save_steps": 100, "lr": 1.5e-5,
"batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
"max_length": 3000 })
    # refuse to overwrite an existing save location unless we are continuing a previous run:
if not result.continue_training and os.path.exists(result.save_location):
raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")
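    # nr_sents_per_gpu == 0 means no per-GPU cap: process the full batch size on each GPU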
if result.nr_sents_per_gpu == 0:
result.nr_sents_per_gpu = result.batch_size
return result
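
# Read the training set: a JSON file containing a list of training examples.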
def load_json_list(json_file):
with open(json_file, "r") as f:
data = json.load(f)
return data
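
# Load a causal LM in bfloat16; when an Accelerator is given, place the weights directly on its device.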
def load_hf_model(mdl_id, accelerator=None):
if accelerator is None:
model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
else:
model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16, device_map=accelerator.device)
return model
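
# Load the tokenizer matching the model.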
def load_hf_tokenizer(mdl_id):
tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
return tokenizer
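
# Main entry point, kept inside a function so no state leaks into module-level globals.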
def _no_globals_main():
accelerator = Accelerator()
try:
args = _cmdline_args()
log(f"Num proc: {accelerator.num_processes}, proc ID: {accelerator.process_index}")
log("loading model", accelerator=accelerator)
mdl = load_hf_model(args.mdl_id)
log("loading tokenizer", accelerator=accelerator)
tok = load_hf_tokenizer(args.mdl_id)
log("loading data", accelerator=accelerator, all_threads=True)
train_set = load_json_list(args.train_file)
log("training", accelerator=accelerator)
acc_trainer = SwitchingAccelerator(train_set, args, mdl, tok, preinit_acc=accelerator)
upd_model = acc_trainer.train()
log("saving", accelerator=accelerator)
save_all_models(args.save_location, upd_model, tok)
except Exception as e:
        # with multiple processes the interleaved stack traces are hard to read, so only the main process re-raises:
if accelerator.is_main_process:
raise e
if __name__ == "__main__":
#sys.argv = "_ models/llama3.2-1b models/newmdl tmp.json".split()
#sys.argv = "_ models/llama3.2-1b models/newmdl2 tmpx.json batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1".split()
_no_globals_main()