import os
import torch
from accelerate import Accelerator
from datetime import datetime
from transformers import get_scheduler
from aux import SameLineLogger, log
from data import DataState, BatchingIterator
from modelops import save_all_models, report_devices
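
# Training loop for causal LM fine-tuning built on HF Accelerate: batching, gradient accumulation
# over per-GPU sub-batches, optional step timing, periodic logging and resumable checkpointing.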

# Helper for jointly optimizing several coupled models: chains the parameters of every
# spec.model in coupling_specs into one iterable (not used by SwitchingAccelerator below,
# which optimizes a single model).
def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()
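
# Records one (loss, sub_batch_idx, epoch_batch_idx, epoch_idx) tuple per sub-batch and exposes
# state_dict/load_state_dict, so it can be handed to Accelerator.register_for_checkpointing and
# restored together with the rest of the training state.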
class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx):
        self.data.append((loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']
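
# Wraps a HF Accelerate training run: builds the batching iterator, optimizer and LR scheduler,
# splits each global batch into per-GPU sub-batches with gradient accumulation, and periodically
# logs progress and writes checkpoints that can be resumed via kwargs.continue_training.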
class SwitchingAccelerator:
    def __init__(self, train_set, train_kwargs, model, tokenizer, preinit_acc=None):
        self.kwargs = train_kwargs
        self.train_set_iter = BatchingIterator(train_set, self.kwargs.batch_size, tokenizer, train_kwargs.max_length)
        self.model = model
        self.tokenizer = tokenizer

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff(preinit_acc)
        self._init_time_keepers()
    # Step-timing helpers: timing is only collected when log_steps is negative ("timing mode"),
    # and only on the main process; otherwise all _tk_* calls are no-ops.
    def _init_time_keepers(self):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            t = datetime.now()
            self._tk_zero = t - t  # zero timedelta, used as the start value when summing durations
            self._tk_stats = {}
            self._tk_time = {}

    def _add_timekeeper(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            self._tk_stats[msg] = []
            self._tk_time[msg] = None

    def _add_timekeepers(self, msgs):
        for msg in msgs:
            self._add_timekeeper(msg)

    def _tk_start(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is None
            self._tk_time[msg] = datetime.now()

    def _tk_stop(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is not None
            this_time = datetime.now() - self._tk_time[msg]
            self._tk_time[msg] = None
            self._tk_stats[msg].append(this_time)

            log(f"{msg} took {this_time}, avg time: " +
                f" {sum(self._tk_stats[msg], self._tk_zero) / len(self._tk_stats[msg])}" +
                f" over {len(self._tk_stats[msg])} samples")
    def __handle_accum(self):
        assert self.kwargs.batch_size % (self.accelerator.num_processes * self.kwargs.nr_sents_per_gpu) == 0, \
            "batch size must be divisible by number of processes and number of segments per GPU"

        accum_steps = int((self.kwargs.batch_size / self.accelerator.num_processes) / self.kwargs.nr_sents_per_gpu)
        self.accelerator.gradient_accumulation_steps = accum_steps

        log(f"Nr sents/GPU: {self.kwargs.nr_sents_per_gpu}, accum steps: {accum_steps}, " +
            f"nr. procs: {self.accelerator.num_processes}, batch size: {self.kwargs.batch_size}",
            accelerator=self.accelerator)
    def ___get_train_scalars(self):
        epoch_len = len(self.train_set_iter)
        train_len = epoch_len * self.kwargs.epochs
        num_warmup = 0  # int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        return train_len, num_warmup
    def __init_opt_lr_and_what_else(self):
        train_len, num_warmup = self.___get_train_scalars()

        opt = torch.optim.AdamW(self.model.parameters(), lr=self.kwargs.lr)

        numtr = train_len * self.accelerator.num_processes
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup, num_training_steps=numtr)

        self.optimizer, self.lr_scheduler, self.model = self.accelerator.prepare(opt, lr_scheduler, self.model)

        self.accelerator.register_for_checkpointing(self.data_state, self.train_loss_list)
    def _init_acc_and_stuff(self, preinit_acc=None):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
        if preinit_acc is None:
            self.accelerator = Accelerator()
        else:
            self.accelerator = preinit_acc

        self.__handle_accum()
        self.__init_opt_lr_and_what_else()

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)
    def train(self, dry_run=False):
        try:
            self._main_loop(dry_run)
        except Exception as e:
            # In multi-process runs the interleaved stack traces are hard to read, so only the
            # main process re-raises; the other processes fall through to the barrier below.
            if self.accelerator.is_main_process:
                raise e

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.model)

        return unwr_coupled_model
    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        #log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}
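
    # Splits this process's share of the global batch into sub-batches of nr_sents_per_gpu
    # sentences; -(a // -b) is ceiling division, so a final, smaller sub-batch is included
    # when the share does not divide evenly.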
    def _get_split_batch_params(self):
        batch_nr_snts = self.kwargs.batch_size

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        sub_batch_size = self.kwargs.nr_sents_per_gpu
        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        #log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)

        return sub_batch_size, nr_steps, proc_batch_nr_snts
    def _report_mem_every_once_in_a_while(self, sub_batch_idx, epoch_batch_idx, batch_dim):
        if sub_batch_idx == 0:
            report_devices(f"training memory usage (batch size: {self.kwargs.batch_size} / {batch_dim[1]})",
                           self.accelerator, self.model)
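
    # Main training loop: epochs -> full batches from BatchingIterator -> per-GPU sub-batches.
    # Forward/backward for every sub-batch runs under accelerator.accumulate(self.model), so gradient
    # synchronization is deferred to the accumulation boundary; per-sub-batch losses and (optionally)
    # step timings are recorded, and logging/saving decisions are made once per full batch.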
    def _main_loop(self, dry_run):
        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set_iter), self.kwargs.epochs, self.data_state)
            logger.line_start()
        else:
            logger = None

        self.model.train()
        self.train_set_iter.thats_where(self.data_state)

        tks = "full_batch", "prep_inputs", "forward", "backward", "upd_step"
        tk_batch, tk_prep, tk_fw, tk_bk, tk_step = tks
        self._add_timekeepers(tks)

        with self.accelerator.accumulate(self.model):
            for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
                for batch, epoch_batch_idx in self.train_set_iter:
                    if dry_run:
                        log(f"Dry run, batch width: {batch['input_ids'].size()}")
                    else:
                        self._report_mem_every_once_in_a_while(0, epoch_batch_idx, batch['input_ids'].size())

                        sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params()

                        self._tk_start(tk_batch)
                        loss = None
                        for sub_batch_idx in range(nr_steps):
                            self._tk_start(tk_prep)
                            inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)
                            # causal-LM objective: labels are a clone of the input IDs
                            # (torch tensors have no copy() method; the model shifts labels internally)
                            inputs['labels'] = inputs['input_ids'].clone()
                            self._tk_stop(tk_prep)

                            self._tk_start(tk_fw)
                            outputs = self.model(**inputs)
                            loss = outputs.loss
                            self._tk_stop(tk_fw)

                            self.train_loss_list.append(loss.item(), sub_batch_idx, epoch_batch_idx, _epoch_idx)

                            self._tk_start(tk_bk)
                            self.accelerator.backward(loss)
                            self._tk_stop(tk_bk)

                            self._tk_start(tk_step)
                            self.optimizer.step()
                            self.lr_scheduler.step()
                            self.optimizer.zero_grad()
                            self._tk_stop(tk_step)
                        self._tk_stop(tk_batch)

                        #assert self.accelerator.sync_gradients, "It is not time to sync gradients yet."

                        self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()
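
    # Mean absolute gradient, averaged over the parameters that currently have a gradient
    # (-1 if none do); only used for progress logging.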
    def get_total_grad(self):
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.model.parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result / grad_count if grad_count != 0 else -1
    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set_iter)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process \
                and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()
            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        #self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

                self._save_all(global_batch_idx, epoch_i)

                logger.line_start()
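
    # Checkpoint directories are named checkpoint-eNN-bNNNNNNN, or checkpoint-eNN-full when the
    # save coincides with the end of an epoch; the current position in the data stream is copied
    # into self.data_state first, so that a later load_state() can resume from the same place.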
    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set_iter)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else "full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set_iter.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.model)

        save_all_models(this_location, model_to_save, self.tokenizer, trainer=self.accelerator)
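

# A minimal usage sketch (not part of the original pipeline): it shows how SwitchingAccelerator is
# wired together. The train_kwargs fields below are assumptions inferred from the attributes read
# in this file (batch_size, max_length, lr, epochs, nr_sents_per_gpu, log_steps, save_steps,
# save_location, continue_training, mdl_id); the empty train_set is a placeholder for whatever
# corpus BatchingIterator actually expects.
def example_usage():
    from types import SimpleNamespace
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from modelops import hf_tok

    mdl_id = "models/llama3.2-1b"
    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)

    train_kwargs = SimpleNamespace(
        batch_size=64, max_length=1024, lr=5e-6, epochs=1, nr_sents_per_gpu=8,
        log_steps=100, save_steps=1000, save_location="checkpoints",
        continue_training=False, mdl_id=mdl_id)

    trainer = SwitchingAccelerator(train_set=[], train_kwargs=train_kwargs, model=model, tokenizer=tokenizer)
    return trainer.train(dry_run=True)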
def test_this_damn_thing():
    # ad-hoc smoke test for optimizer / scheduler preparation
    import json
    from torch.optim import AdamW
    from modelops import hf_tok
    from transformers import AutoModelForCausalLM, AutoTokenizer

    mdl_id = "models/llama3.2-1b"

    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)

    with open("tmpx.json", "r") as f:
        training_data_raw = json.load(f)

    optimizer = AdamW(model.parameters(), lr=5e-6)
    print("Initial 0:", optimizer.param_groups[0]['lr'])  # should be 5e-6

    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=2445
    )

    accel = Accelerator()
    p_optimizer, p_lr_scheduler, p_model = accel.prepare(optimizer, scheduler, model)

    print("Initial 1:", p_lr_scheduler.get_last_lr())  # should be [5e-6]

    """
    for _ in range(2):
        optimizer.step()
        scheduler.step()
        print("Step:", scheduler.get_last_lr())
    """
if __name__ == "__main__":
    test_this_damn_thing()