#!/usr/bin/env python3
from .promptops import PF_SMUGRI_MT
from .aux import log, CmdlineArgs
from .data import load_training_data

import json
import os, socket, torch

from datetime import datetime
from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    logging,
    TrainerCallback
)
""" | |
1/3 This simply reads in command-line arguments | |
""" | |
def _cmdline_args(): | |
description = """Train or tune decoder models""" | |
result = CmdlineArgs(description, | |
pos_arg_list=["mdl_id", "save_location", "train_file"], | |
pos_arg_types=[str, str, str], | |
kw_arg_dict={ "continue_training": False, "save_steps": 100, "lr": 1.5e-5, | |
"batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4, | |
"max_length": 2000, "prompt_format": PF_SMUGRI_MT, | |
"deepspeed": "none"}) | |
# if the directory args.save_location already exists, raise an exception: | |
if not result.continue_training and os.path.exists(result.save_location): | |
raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.") | |
if result.nr_sents_per_gpu == 0: | |
result.nr_sents_per_gpu = result.batch_size | |
if result.deepspeed == "none": | |
result.deepspeed = None | |
return result | |
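# For illustration, a hypothetical invocation (an assumption: the exact syntax is
# defined by CmdlineArgs in .aux, which is not shown here; the package/module name,
# model id and paths below are placeholders):
#
#   python -m pkg.train meta-llama/Meta-Llama-3.1-8B models/smugri-mt data/train.json \
#       batch_size=512 nr_sents_per_gpu=2 deepspeed=ds_config.json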
""" | |
2/3 This here is used in training in order to report timing and predictions | |
""" | |
class StepTimerCallback(TrainerCallback):
    def __init__(self):
        self._step_start = None
        self.lengths = []
        self.abs_start = datetime.now()
        self.actual_first_step = None
        # a zero-length timedelta, used as the start value when summing timedeltas
        self.zero = self.abs_start - self.abs_start

    def on_step_begin(self, args, state, control, **kwargs):
        # called right before each training step
        self._step_start = datetime.now()

    def on_step_end(self, args, state, control, **kwargs):
        # remember how many steps were already done before this run started
        # (nonzero when resuming from a checkpoint)
        if self.actual_first_step is None:
            self.actual_first_step = state.global_step - 1

        # called right after each training step
        now = datetime.now()
        elapsed = now - self._step_start
        tot_elapsed = now - self.abs_start

        self.lengths.append(elapsed)
        avg = sum(self.lengths, start=self.zero) / len(self.lengths)

        remaining = state.max_steps - state.global_step
        # predict the remaining time from the average duration of the steps timed in this run
        prediction = (tot_elapsed / (state.global_step - self.actual_first_step)) * remaining

        # you could use logging.get_logger(...) instead of print
        print(f"[step {state.global_step}/{state.max_steps}] took {elapsed}, avg {avg}; approx {prediction} remaining")
""" | |
3/3 Finally, the filling of TrainingArguments and the launching of Trainer: | |
""" | |
def get_training_args(cmdline_args, acc):
    world_size = acc.num_processes

    assert cmdline_args.batch_size % (cmdline_args.nr_sents_per_gpu * world_size) == 0, \
        "Batch size must be divisible by (nr of sents per GPU * nr of GPUs)"
    accum_steps = cmdline_args.batch_size // (cmdline_args.nr_sents_per_gpu * world_size)

    log(f"Nr of processes (GPUs): {world_size}, per-device batch: {cmdline_args.nr_sents_per_gpu}, accum. steps: {accum_steps}")

    if cmdline_args.deepspeed is not None:
        with open(cmdline_args.deepspeed, "r") as f:
            dpspd = json.load(f)

        # overwrite the batch-size entries in the dictionary with the current values,
        # so that the JSON file does not need to be updated every time
        dpspd['train_batch_size'] = cmdline_args.batch_size
        dpspd['train_micro_batch_size_per_gpu'] = cmdline_args.nr_sents_per_gpu
        dpspd['gradient_accumulation_steps'] = accum_steps

        log(f"Using deepspeed with config {dpspd}")
    else:
        dpspd = None

    tr_args = TrainingArguments(
        output_dir=cmdline_args.save_location,
        per_device_train_batch_size=cmdline_args.nr_sents_per_gpu,
        gradient_accumulation_steps=accum_steps,
        num_train_epochs=cmdline_args.epochs,
        save_steps=cmdline_args.save_steps,  # note: ignored while save_strategy="epoch" below
        save_total_limit=10,
        logging_steps=cmdline_args.log_steps,
        deepspeed=dpspd,
        learning_rate=cmdline_args.lr,
        save_strategy="epoch",
        disable_tqdm=True,
        report_to="none",

        # Optional but often helpful on LUMI/ROCm if you enable it in your args:
        bf16=True,
        ddp_find_unused_parameters=False,
        #dataloader_num_workers=1,
        #group_by_length=True,
        log_level="debug",
        #gradient_checkpointing=True,
        #dataloader_persistent_workers=True
    )

    return tr_args
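# A minimal sketch of a DeepSpeed config file that could be passed via the
# "deepspeed" argument (illustrative, not from the original repo; any valid config
# works, since the three batch-related keys are overwritten above). With the
# defaults (batch_size=1024, nr_sents_per_gpu=4) on 8 GPUs, the accumulation
# works out to 1024 / (4 * 8) = 32 steps:
#
#   {
#     "train_batch_size": 1024,
#     "train_micro_batch_size_per_gpu": 4,
#     "gradient_accumulation_steps": 32,
#     "bf16": { "enabled": true },
#     "zero_optimization": { "stage": 2 }
#   }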
def load_model(mdl_id, device, accelerator=None, attention="flash_attention_2"):
    log("Load model", accelerator=accelerator)
    model = AutoModelForCausalLM.from_pretrained(mdl_id,
                                                 low_cpu_mem_usage=False,
                                                 torch_dtype=torch.bfloat16,
                                                 attn_implementation=attention)
    # disable the KV cache: it only helps generation, not training
    model.config.use_cache = False
    model = model.to(device)
    log(f"Model loaded on device: {model.device}.", accelerator=accelerator)
    return model
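# If flash-attn is not installed (e.g. no matching ROCm/CUDA build), the attention
# argument can select one of transformers' built-in implementations, "sdpa" or
# "eager", instead. A sketch (the model id is just a placeholder):
#
#   model = load_model("meta-llama/Meta-Llama-3.1-8B", "cpu", attention="sdpa")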
def load_tokenizer(mdl_id, accelerator=None):
    log("Load tokenizer", accelerator=accelerator)
    tokenizer = AutoTokenizer.from_pretrained(mdl_id)

    # LLaMA 3.x: no pad token by default, so repurpose one of the reserved special tokens
    if tokenizer.pad_token is None:
        tokenizer.pad_token = "<|reserved_special_token_100|>"

    return tokenizer
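# A sanity check worth doing for tokenizers other than LLaMA 3.x (a sketch, not
# part of the original flow): the chosen pad token should map to a real token id
# and differ from eos, otherwise padding can mask the end-of-sequence signal:
#
#   tok = load_tokenizer(mdl_id)
#   assert tok.pad_token_id is not None and tok.pad_token_id != tok.eos_token_id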
def simple_train():
    cmd_args = _cmdline_args()

    acc = Accelerator()
    device = acc.device  # it seems that the accelerator loses/changes this info later

    training_args = get_training_args(cmd_args, acc)

    tokenizer = load_tokenizer(cmd_args.mdl_id, acc)
    model = load_model(cmd_args.mdl_id, device, acc)

    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    log("Load data", accelerator=acc)
    tokenized_train_data = load_training_data(cmd_args.train_file, tokenizer, cmd_args)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,  # padding to a multiple of 8 helps tensor-core throughput
    )

    log("Preparing to train", accelerator=acc)

    # only the main process reports timing, to avoid duplicated output
    clbks = [StepTimerCallback()] if acc.is_main_process else []

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_data,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=clbks,
    )

    logging.set_verbosity_debug()

    log("Starting training", accelerator=acc)
    trainer.train(resume_from_checkpoint=cmd_args.continue_training)

    log("Done, saving model", accelerator=acc)
    trainer.save_model()
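# Outside of SLURM, the same pipeline can also be started with HF Accelerate's
# launcher (a sketch; the package/module name is a placeholder, since this file
# uses relative imports and must be run as a module):
#
#   accelerate launch --num_processes=8 -m pkg.train <mdl_id> <save_location> <train_file>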
def env_stuff():
    # map SLURM's environment variables to the ones torch.distributed expects
    os.environ.setdefault("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))
    os.environ.setdefault("RANK", os.environ.get("SLURM_PROCID", "0"))
    os.environ.setdefault("WORLD_SIZE", os.environ.get("SLURM_NTASKS", "1"))
    os.environ.setdefault("MASTER_ADDR", os.environ.get("SLURM_LAUNCH_NODE_IPADDR", "127.0.0.1"))
    os.environ.setdefault("MASTER_PORT", "29500")  # pick an open port

    # Optional: make sure each process selects its own GPU
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    try:
        log(
            f"host={socket.gethostname()} "
            f"RANK={os.environ['RANK']}/{os.environ['WORLD_SIZE']} "
            f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
            f"HIP_VISIBLE_DEVICES={os.environ.get('HIP_VISIBLE_DEVICES')} "
            f"ROCR_VISIBLE_DEVICES={os.environ.get('ROCR_VISIBLE_DEVICES')} "
            f"cuda_count={torch.cuda.device_count()} curr_dev={torch.cuda.current_device()}"
        )
    except AssertionError:
        log(
            f"host={socket.gethostname()} "
            f"RANK={os.environ['RANK']}/{os.environ['WORLD_SIZE']} "
            f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
            f"HIP_VISIBLE_DEVICES={os.environ.get('HIP_VISIBLE_DEVICES')} "
            f"ROCR_VISIBLE_DEVICES={os.environ.get('ROCR_VISIBLE_DEVICES')} "
            f"no cuda"
        )
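# On a SLURM cluster (e.g. LUMI) the mapping above lets each task find its rank;
# a sketch of a matching launch, one task per GPU (node/GPU counts and the
# package/module name are placeholders):
#
#   srun --nodes=1 --ntasks=8 --gpus-per-node=8 \
#       python -m pkg.train <mdl_id> <save_location> <train_file>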
""" | |
This replaces the trainer, in order to | |
print out the final batch when training, | |
and commit harakiri. So only for temporary | |
debugging-related usage | |
""" | |
class LoggingKillingTrainer(Trainer): | |
def compute_loss(self, model, inputs, **kwargs): | |
log(f"Here is the batch for training: {inputs}") | |
raise NotImplementedError | |
#return super().compute_loss(model, inputs, **kwargs) | |
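# To use it, swap Trainer for LoggingKillingTrainer in simple_train(): the first
# batch gets logged and the run then stops with a NotImplementedError.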
if __name__ == "__main__":
    env_stuff()
    simple_train()