"""
import os
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from transformers import get_scheduler
from aux import SameLineLogger, log
from data import DataState
from langconv import is_dec_only_llm
from modelops import save_all_models, report_devices
from translate import encode
raise NotImplementedError("This is a backup package, do not run or import from it")
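
# Chain the parameters of all coupled models into one iterable, so that a
# single optimizer can update every model at once.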
def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()
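
# Checkpoint-friendly loss history: implements the state_dict/load_state_dict
# protocol so that Accelerate can save and restore it with the training state.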
class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, src_k, tgt_k):
        self.data.append((loss_val, src_k, tgt_k))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']
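
# Wraps HF Accelerate for training coupled models: prepares the optimizer,
# scheduler and models, splits each global batch across processes and
# sub-batches, and handles logging and periodic checkpointing.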
class SwitchingAccelerator:
    def __init__(self, coupling_specs, train_set, train_kwargs):
        self.coupling_specs = coupling_specs
        self.train_set = train_set
        self.kwargs = train_kwargs

        self.is_generative = is_dec_only_llm(self.coupling_specs[0].tokenizer)

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff()
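
    # Set up the Accelerator, the optimizer over all coupled models and the LR
    # scheduler, prepare everything, and optionally reload a saved state.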
    def _init_acc_and_stuff(self):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps)
        self.accelerator = Accelerator()

        epoch_len = len(self.train_set)
        train_len = epoch_len * self.kwargs.epochs
        # Linear schedule with 1% of the total training steps used for warmup.
        num_warmup = int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        opt = torch.optim.AdamW(chain_params(self.coupling_specs), lr=self.kwargs.lr)
        # After prepare() the scheduler is stepped once per process at every
        # optimizer step, hence the num_processes factor in the step budget.
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup,
                                     num_training_steps=train_len * self.accelerator.num_processes)

        models = [s.model for s in self.coupling_specs]
        self.optimizer, self.lr_scheduler, *self.models = self.accelerator.prepare(opt, lr_scheduler, *models)

        self.accelerator.register_for_checkpointing(self.lr_scheduler, self.data_state, self.train_loss_list)

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)
    def train(self):
        try:
            self._main_loop()
        except Exception:
            # In multi-process scenarios the interleaved stack traces are hard
            # to read, so only the main process re-raises:
            if self.accelerator.is_main_process:
                raise

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.models[0])

        return unwr_coupled_model, self.train_loss_list
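
    # Generative (decoder-only) batches carry no bin indices; translation
    # batches come with source- and target-bin indices for model selection.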
    def _split_batch_and_bin_idxs(self, batch_with_idxs):
        if self.is_generative:
            batch, _ = batch_with_idxs
            src_k = 0
            tgt_k = 0
        else:
            batch, src_k, tgt_k, _ = batch_with_idxs

        return batch, src_k, tgt_k
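
    # Slice out this process's current sub-batch from the global batch and
    # move the tensors to the local device.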
    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}
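
    # Decide how to split the per-process batch into sub-batches: either a
    # fixed number of sentences, or as many sentences as fit the word budget.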
    def _get_split_batch_params(self, batch):
        batch_nr_snts = batch['input_ids'].size(0)
        snt_nr_words = batch['input_ids'].size(1)

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by the number of processes."
        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        if self.kwargs.nr_snts_in_batch > 0:
            sub_batch_size = self.kwargs.nr_snts_in_batch
        else:
            sub_batch_size = max(1, self.kwargs.nr_words_in_batch // snt_nr_words)

        # Ceiling division: the last sub-batch may be smaller than the rest.
        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        return sub_batch_size, nr_steps, proc_batch_nr_snts
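
    # Iterate over epochs and batches; each batch is processed in sub-batches
    # whose gradients accumulate until the optimizer step in
    # _step_and_perhaps_save. Resumes from the position saved in data_state.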
    def _main_loop(self):
        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set), self.kwargs.epochs)
            logger.line_start()
        else:
            logger = None

        self.models[0].train()
        self.train_set.thats_where(self.data_state)

        for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
            for batch_with_bin_idxs, epoch_batch_idx in self.train_set:
                batch, src_k, tgt_k = self._split_batch_and_bin_idxs(batch_with_bin_idxs)
                sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params(batch)

                loss = None

                for sub_batch_idx in range(nr_steps):
                    inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)

                    if self.is_generative:
                        # Decoder-only LM: the inputs double as the labels.
                        inputs['labels'] = inputs['input_ids']
                        outputs = self.models[0](**inputs)
                    else:
                        # Encoder-decoder: encode with the source-bin model and
                        # decode with the target-bin model.
                        encoder_vecs = encode(self.models[src_k], inputs)
                        outputs = self.models[tgt_k](attention_mask=inputs['attention_mask'], labels=inputs['labels'],
                                                     encoder_outputs=encoder_vecs)

                    loss = outputs.loss

                    # Report memory usage once per batch, a few sub-batches in.
                    if sub_batch_idx == 5:
                        batch_size = sum(inputs[k].size(0) * inputs[k].size(1) for k in ('input_ids', 'labels', 'attention_mask'))
                        report_devices(f"training memory usage (batch size: {batch_size}; inputs: " +
                                       f"snts {inputs['input_ids'].size(0)} X words {inputs['input_ids'].size(1)})",
                                       self.accelerator, self.models[0])

                    self.train_loss_list.append(loss.item(), src_k, tgt_k)
                    self.accelerator.backward(loss)

                    # Release device memory held by this sub-batch before the next one.
                    for k in inputs:
                        inputs[k] = inputs[k].to('cpu')

                self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, loss.item())

        if self.accelerator.is_main_process:
            logger.line_break()
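
    # Mean absolute gradient over all parameters that have one, for logging;
    # returns -1 before any backward pass.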
    def get_total_grad(self):
        result = 0
        grad_count = 0

        for p in self.models[0].parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1

        return result / grad_count if grad_count > 0 else -1
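
    # Take an optimizer and scheduler step, log at the configured frequency,
    # and save a checkpoint every save_steps batches and at each epoch end.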
    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        self.optimizer.step()
        self.lr_scheduler.step()
        self.accelerator.wait_for_everyone()

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()
            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()

            log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)
            self._save_all(global_batch_idx, epoch_i)

            # logger only exists on the main process (it is None elsewhere).
            if self.accelerator.is_main_process:
                logger.line_start()
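
    # Save a uniquely named checkpoint, refusing to overwrite an existing one.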
    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else "full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.models[0])
        save_all_models(this_location, model_to_save, self.coupling_specs[0].tokenizer,
                        self.coupling_specs, trainer=self.accelerator)
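
# A minimal usage sketch (hypothetical names; the coupling specs are objects
# with .model and .tokenizer attributes, as assumed throughout this module):
#
#   acc = SwitchingAccelerator(coupling_specs, train_set, train_kwargs)
#   trained_model, loss_list = acc.train()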
"""