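# Training loop for language-model fine-tuning built around Hugging Face Accelerate:
# gradient accumulation over sub-batches, linear LR scheduling, periodic checkpointing,
# and optional per-phase timing logs.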
import os
import torch

from accelerate import Accelerator
from datetime import datetime
from transformers import get_scheduler

from aux import SameLineLogger, log
from data import DataState, BatchingIterator
from modelops import save_all_models, report_devices
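

# Yields the parameters of every model listed in the coupling specs, so that a single
# optimizer can update several coupled models at once. Unused in this module; presumably
# imported by callers that train coupled models.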
def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()
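

# Checkpointable container of per-sub-batch training losses; it exposes
# state_dict/load_state_dict so Accelerate can save and restore it.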
class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx):
        self.data.append((loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']
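

# Wraps a model, tokenizer, and batching iterator together with an Accelerator instance
# and runs the training loop: sub-batch splitting for gradient accumulation, loss/LR
# logging, and periodic checkpoint saving.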
class SwitchingAccelerator:
    def __init__(self, train_set, train_kwargs, model, tokenizer, preinit_acc=None):
        self.kwargs = train_kwargs
        self.train_set_iter = BatchingIterator(train_set, self.kwargs.batch_size, tokenizer, train_kwargs.max_length)
        self.model = model
        self.tokenizer = tokenizer

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff(preinit_acc)
        self._init_time_keepers()
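
    # The _tk_* helpers collect wall-clock timings for named phases of the training
    # step; they are only active when log_steps < 0 and only on the main process.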
    def _init_time_keepers(self):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            t = datetime.now()
            self._tk_zero = t - t
            self._tk_stats = {}
            self._tk_time = {}

    def _add_timekeeper(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            self._tk_stats[msg] = []
            self._tk_time[msg] = None

    def _add_timekeepers(self, msgs):
        for msg in msgs:
            self._add_timekeeper(msg)

    def _tk_start(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is None
            self._tk_time[msg] = datetime.now()

    def _tk_stop(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is not None
            this_time = datetime.now() - self._tk_time[msg]
            self._tk_time[msg] = None
            self._tk_stats[msg].append(this_time)

            log(f"{msg} took {this_time}, avg time: " +
                f" {sum(self._tk_stats[msg], self._tk_zero) / len(self._tk_stats[msg])}" +
                f" over {len(self._tk_stats[msg])} samples")
    def __handle_accum(self):
        assert self.kwargs.batch_size % (self.accelerator.num_processes * self.kwargs.nr_sents_per_gpu) == 0, \
            "batch size must be divisible by the number of processes times the number of sentences per GPU"
        accum_steps = int((self.kwargs.batch_size / self.accelerator.num_processes) / self.kwargs.nr_sents_per_gpu)

        self.accelerator.gradient_accumulation_steps = accum_steps

        log(f"Nr sents/GPU: {self.kwargs.nr_sents_per_gpu}, accum steps: {accum_steps}, " +
            f"nr. procs: {self.accelerator.num_processes}, batch size: {self.kwargs.batch_size}",
            accelerator=self.accelerator)

    def ___get_train_scalars(self):
        epoch_len = len(self.train_set_iter)
        train_len = epoch_len * self.kwargs.epochs
        num_warmup = 0  # int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        return train_len, num_warmup

    def __init_opt_lr_and_what_else(self):
        train_len, num_warmup = self.___get_train_scalars()

        opt = torch.optim.AdamW(self.model.parameters(), lr=self.kwargs.lr)
        numtr = train_len * self.accelerator.num_processes
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup, num_training_steps=numtr)

        self.optimizer, self.lr_scheduler, self.model = self.accelerator.prepare(opt, lr_scheduler, self.model)
        self.accelerator.register_for_checkpointing(self.data_state, self.train_loss_list)

    def _init_acc_and_stuff(self, preinit_acc=None):
        # self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps,
        #                                kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
        if preinit_acc is None:
            self.accelerator = Accelerator()
        else:
            self.accelerator = preinit_acc

        self.__handle_accum()
        self.__init_opt_lr_and_what_else()

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)

    def train(self, dry_run=False):
        try:
            self._main_loop(dry_run)
        except Exception as e:
            # In multi-process runs the interleaved stack traces are hard to read,
            # so only the main process re-raises:
            if self.accelerator.is_main_process:
                raise e

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.model)

        return unwr_coupled_model
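
    # Selects this process's slice of the globally collated batch for the given
    # sub-batch index and moves the tensors to the local device.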
    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        # log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}
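
    # Computes the per-process sub-batch size and the number of accumulation steps;
    # -(a // -b) is ceiling division.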
    def _get_split_batch_params(self):
        batch_nr_snts = self.kwargs.batch_size
        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes
        sub_batch_size = self.kwargs.nr_sents_per_gpu
        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        # log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)

        return sub_batch_size, nr_steps, proc_batch_nr_snts

    def _report_mem_every_once_in_a_while(self, sub_batch_idx, epoch_batch_idx, batch_dim):
        if sub_batch_idx == 0:
            report_devices(f"training memory usage (batch size: {self.kwargs.batch_size} / {batch_dim[1]})",
                           self.accelerator, self.model)
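
    # Main training loop: iterates over epochs and batches, splits each batch into
    # per-GPU sub-batches, accumulates gradients over them, then steps the optimizer
    # and LR scheduler once per full batch.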
    def _main_loop(self, dry_run):
        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set_iter), self.kwargs.epochs, self.data_state)
            logger.line_start()
        else:
            logger = None

        self.model.train()
        self.train_set_iter.thats_where(self.data_state)

        tks = "full_batch", "prep_inputs", "forward", "backward", "upd_step"
        tk_batch, tk_prep, tk_fw, tk_bk, tk_step = tks
        self._add_timekeepers(tks)

        with self.accelerator.accumulate(self.model):
            for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
                for batch, epoch_batch_idx in self.train_set_iter:
                    if dry_run:
                        log(f"Dry run, batch width: {batch['input_ids'].size()}")
                    else:
                        self._report_mem_every_once_in_a_while(0, epoch_batch_idx, batch['input_ids'].size())
                        sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params()

                        self._tk_start(tk_batch)
                        loss = None

                        for sub_batch_idx in range(nr_steps):
                            self._tk_start(tk_prep)
                            inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)
                            # Causal LM objective: labels duplicate the inputs;
                            # tensors are duplicated with clone(), not copy().
                            inputs['labels'] = inputs['input_ids'].clone()
                            self._tk_stop(tk_prep)

                            self._tk_start(tk_fw)
                            outputs = self.model(**inputs)
                            loss = outputs.loss
                            self._tk_stop(tk_fw)

                            self.train_loss_list.append(loss.item(), sub_batch_idx, epoch_batch_idx, _epoch_idx)

                            self._tk_start(tk_bk)
                            self.accelerator.backward(loss)
                            self._tk_stop(tk_bk)

                        self._tk_start(tk_step)
                        self.optimizer.step()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad()
                        self._tk_stop(tk_step)

                        self._tk_stop(tk_batch)

                        # assert self.accelerator.sync_gradients, "It is not time to sync gradients yet."

                        self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()
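
    # Returns the mean over parameters of the mean absolute gradient, or -1 if no
    # parameter currently has a gradient.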
    def get_total_grad(self):
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.model.parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result / grad_count if grad_count != 0 else -1
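
    # Per-batch bookkeeping: logs loss, LR, and gradient magnitude every log_steps
    # batches, and saves a checkpoint every save_steps batches or at the end of an epoch.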
    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set_iter)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len
        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process \
                and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()
            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        # self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

                self._save_all(global_batch_idx, epoch_i)

                logger.line_start()
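
    # Writes a checkpoint directory named after the epoch and global batch index,
    # refusing to overwrite an existing one, and records the current data-iterator
    # position so training can be resumed from it.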
    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set_iter)
        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else "full"))
        this_location = os.path.join(self.kwargs.save_location, ckpt_name)

        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set_iter.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.model)

        save_all_models(this_location, model_to_save, self.tokenizer, trainer=self.accelerator)
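

# Stand-alone smoke test: loads a local model, builds an AdamW optimizer and a linear
# scheduler, and checks the learning rate before and after accelerator.prepare().
# The model path and input file are specific to the author's setup.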
def test_this_damn_thing():
    # testing
    import json

    from torch.optim import AdamW
    from modelops import hf_tok
    from transformers import AutoModelForCausalLM, AutoTokenizer

    mdl_id = "models/llama3.2-1b"

    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)

    with open("tmpx.json", "r") as f:
        training_data_raw = json.load(f)

    optimizer = AdamW(model.parameters(), lr=5e-6)
    print("Initial 0:", optimizer.param_groups[0]['lr'])  # Should be 5e-06

    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=2445
    )

    accel = Accelerator()
    p_optimizer, p_lr_scheduler, p_model = accel.prepare(optimizer, scheduler, model)

    print("Initial 1:", p_lr_scheduler.get_last_lr())  # Should be [5e-6]

    """
    for _ in range(2):
        optimizer.step()
        scheduler.step()
        print("Step:", scheduler.get_last_lr())
    """


if __name__ == "__main__":
    test_this_damn_thing()