""" | |
import os | |
import torch | |
from accelerate import Accelerator, DistributedDataParallelKwargs | |
from transformers import get_scheduler | |
from aux import SameLineLogger, log | |
from data import DataState | |
from langconv import is_dec_only_llm | |
from modelops import save_all_models, report_devices | |
from translate import encode | |
raise NotImplementedError("This is a backup package, do not run or import from it") | |
def chain_params(coupling_specs): | |
for spec in coupling_specs: | |
yield from spec.model.parameters() | |
class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, src_k, tgt_k):
        self.data.append((loss_val, src_k, tgt_k))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']
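
# TrainLossList implements the state_dict()/load_state_dict() pair that
# Accelerator.register_for_checkpointing expects, so the collected losses are saved and
# restored together with accelerator.save_state()/load_state() (see _init_acc_and_stuff below).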


class SwitchingAccelerator:
    def __init__(self, coupling_specs, train_set, train_kwargs):
        self.coupling_specs = coupling_specs
        self.train_set = train_set
        self.kwargs = train_kwargs

        self.is_generative = is_dec_only_llm(self.coupling_specs[0].tokenizer)

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff()

    def _init_acc_and_stuff(self):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps)
        self.accelerator = Accelerator()

        epoch_len = len(self.train_set)
        train_len = epoch_len * self.kwargs.epochs
        num_warmup = int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        opt = torch.optim.AdamW(chain_params(self.coupling_specs), lr=self.kwargs.lr)
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup,
                                     num_training_steps=train_len * self.accelerator.num_processes)
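
        # Illustration with made-up numbers: epoch_len = 1000 batches and epochs = 4 give
        # train_len = 4000 and num_warmup = 40; num_training_steps is additionally multiplied
        # by num_processes, presumably because the scheduler prepared by accelerate is advanced
        # num_processes times for every optimizer step.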

        models = [s.model for s in self.coupling_specs]
        self.optimizer, self.lr_scheduler, *self.models = self.accelerator.prepare(opt, lr_scheduler, *models)

        self.accelerator.register_for_checkpointing(self.lr_scheduler, self.data_state, self.train_loss_list)

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)
    def train(self):
        try:
            self._main_loop()
        except Exception:
            # in multi-process scenarios the interleaved stack traces are hard to read, so only show one:
            if self.accelerator.is_main_process:
                raise

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.models[0])

        return unwr_coupled_model, self.train_loss_list
    def _split_batch_and_bin_idxs(self, batch_with_idxs):
        if self.is_generative:
            batch, _ = batch_with_idxs
            src_k = 0
            tgt_k = 0
        else:
            batch, src_k, tgt_k, _ = batch_with_idxs

        return batch, src_k, tgt_k
    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        # clamp so that the final sub-batch does not spill over into the next process's share of the batch
        to_proc_idx = min(from_proc_idx + sub_batch_size, proc_batch_size * (self.accelerator.process_index + 1))

        #log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}
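
    # Illustration with made-up numbers: with proc_batch_size = 16, sub_batch_size = 7 and
    # process_index = 2, sub_batch_idx = 1 selects rows 39:46, i.e. the second 7-sentence
    # slice of the 16 rows (32..47) that belong to process 2.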

    def _get_split_batch_params(self, batch):
        batch_nr_snts = batch['input_ids'].size()[0]
        snt_nr_words = batch['input_ids'].size()[1]

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by the number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        if self.kwargs.nr_snts_in_batch > 0:
            sub_batch_size = self.kwargs.nr_snts_in_batch
        else:
            sub_batch_size = max(1, self.kwargs.nr_words_in_batch // snt_nr_words)

        #log(f"DEBUG: #words/snt {snt_nr_words} X #snt in sub batch {sub_batch_size} = {snt_nr_words*sub_batch_size} ~ {self.kwargs.nr_words_in_batch}", accelerator=self.accelerator)

        # ceiling division: number of sub-batches needed to cover this process's share of the batch
        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        #log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)

        return sub_batch_size, nr_steps, proc_batch_nr_snts
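
    # Illustration with made-up numbers: a gathered batch of 64 sentences of 128 tokens on
    # 4 processes gives proc_batch_nr_snts = 16; with nr_words_in_batch = 1000 the word-based
    # branch yields sub_batch_size = max(1, 1000 // 128) = 7 and nr_steps = ceil(16 / 7) = 3,
    # i.e. three forward/backward sub-steps per process for every gathered batch.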

    def _main_loop(self):
        #countdown_till_do_it_once = 0
        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set), self.kwargs.epochs)
            logger.line_start()
        else:
            logger = None

        self.models[0].train()
        self.train_set.thats_where(self.data_state)

        for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
            for batch_with_bin_idxs, epoch_batch_idx in self.train_set:
                batch, src_k, tgt_k = self._split_batch_and_bin_idxs(batch_with_bin_idxs)

                sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params(batch)

                loss = None

                # gradients are accumulated over the sub-batches; the optimizer steps once per gathered batch
                for sub_batch_idx in range(nr_steps):
                    inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)

                    if self.is_generative:
                        # decoder-only LM: next-token prediction on the inputs themselves
                        inputs['labels'] = inputs['input_ids']
                        outputs = self.models[0](**inputs)
                    else:
                        # coupled mode: encode with the source-bin model, decode with the target-bin model
                        encoder_vecs = encode(self.models[src_k], inputs)
                        outputs = self.models[tgt_k](attention_mask=inputs['attention_mask'], labels=inputs['labels'], encoder_outputs=encoder_vecs)

                    loss = outputs.loss

                    #if countdown_till_do_it_once > 0:
                    #    countdown_till_do_it_once -= 1
                    #elif countdown_till_do_it_once == 0:
                    if sub_batch_idx == 5:
                        batch_size = sum(inputs[k].size()[0] * inputs[k].size()[1] for k in 'input_ids labels attention_mask'.split(' '))
                        report_devices(f"training memory usage (batch size: {batch_size}; inputs: " +
                                       f"snts {inputs['input_ids'].size()[0]} X words {inputs['input_ids'].size()[1]})",
                                       self.accelerator, self.models[0])
                        countdown_till_do_it_once = 0

                    self.train_loss_list.append(loss.item(), src_k, tgt_k)

                    self.accelerator.backward(loss)

                    # free device memory held by this sub-batch before preparing the next one
                    for k in inputs:
                        inputs[k] = inputs[k].to('cpu')

                self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()
    def get_total_grad(self):
        # average of the per-tensor mean absolute gradients of the main model, for logging; -1 if no gradients yet
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.models[0].parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result / grad_count if grad_count > 0 else -1
    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        self.optimizer.step()
        self.lr_scheduler.step()

        self.accelerator.wait_for_everyone()

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()
            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

                self._save_all(global_batch_idx, epoch_i)

                logger.line_start()
    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else "full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.models[0])

        save_all_models(this_location, model_to_save, self.coupling_specs[0].tokenizer,
                        self.coupling_specs, trainer=self.accelerator)
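

# Minimal usage sketch (hypothetical variable names; coupling_specs, train_set and
# train_kwargs are assumed to be built elsewhere in the package, and the script is
# typically launched with `accelerate launch` so that one process per device is spawned):
#
#   acc_trainer = SwitchingAccelerator(coupling_specs, train_set, train_kwargs)
#   trained_model, loss_list = acc_trainer.train()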
""" |