# coding=utf-8 # Copyright 2020-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. """ import contextlib import copy import functools import glob import importlib.metadata import inspect import json import math import os import random import re import shutil import sys import tempfile import time import warnings from collections.abc import Mapping from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union # Integrations must be imported before ML frameworks: # isort: off from transformers.integrations import ( get_reporting_integration_callbacks, hp_params, ) # isort: on import huggingface_hub.utils as hf_hub_utils import numpy as np import torch import torch.distributed as dist from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from transformers import __version__ from transformers.configuration_utils import PretrainedConfig from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend from transformers.image_processing_utils import BaseImageProcessor from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available from transformers.integrations.tpu import tpu_spmd_dataloader from transformers.modelcard import TrainingSummary from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) from transformers.optimization import Adafactor, get_scheduler from transformers.processing_utils import ProcessorMixin from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_2_3, ) from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_callback import ( CallbackHandler, DefaultFlowCallback, ExportableState, PrinterCallback, ProgressCallback, TrainerCallback, TrainerControl, TrainerState, ) from transformers.trainer_pt_utils import ( DistributedTensorGatherer, EvalLoopContainer, IterableDatasetShard, LabelSmoother, LayerWiseDummyOptimizer, LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, find_batch_size, get_model_param_count, get_module_class_from_name, get_parameter_names, nested_concat, nested_detach, nested_numpify, nested_xla_mesh_reduce, reissue_pt_warnings, remove_dummy_checkpoint, ) from transformers.trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, EvalLoopOutput, EvalPrediction, HPSearchBackend, HubStrategy, PredictionOutput, RemoveColumnsCollator, SaveStrategy, TrainerMemoryTracker, TrainOutput, check_target_module_exists, default_compute_objective, denumpify_detensorize, enable_full_determinism, find_executable_batch_size, get_last_checkpoint, has_length, neftune_post_forward_hook, number_of_arguments, seed_worker, set_seed, speed_metrics, ) from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments from transformers.utils import ( ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, XLA_FSDPV2_MIN_VERSION, PushInProgress, PushToHubMixin, can_return_loss, find_labels, is_accelerate_available, is_apex_available, is_bitsandbytes_available, is_datasets_available, is_galore_torch_available, is_grokadamw_available, is_in_notebook, is_ipex_available, is_liger_kernel_available, is_lomo_available, is_peft_available, is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_schedulefree_available, is_torch_compile_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, is_torch_neuroncore_available, is_torch_npu_available, is_torch_xla_available, is_torch_xpu_available, is_torchao_available, logging, strtobool, ) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.quantization_config import QuantizationMethod DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback if is_in_notebook(): from transformers.utils.notebook import NotebookProgressCallback DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback if is_apex_available(): from apex import amp if is_datasets_available(): import datasets if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met from torch_xla import __version__ as XLA_VERSION IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) if IS_XLA_FSDPV2_POST_2_2: import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr else: IS_XLA_FSDPV2_POST_2_2 = False if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as SMP_VERSION IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat else: IS_SAGEMAKER_MP_POST_1_10 = False if is_safetensors_available(): import safetensors.torch if is_peft_available(): from peft import PeftModel if is_accelerate_available(): from accelerate import Accelerator, skip_first_batches from accelerate import __version__ as accelerate_version from accelerate.state import AcceleratorState from accelerate.utils import ( DistributedDataParallelKwargs, DistributedType, load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer, ) DATA_SAMPLERS = [RandomSampler] if version.parse(accelerate_version) > version.parse("0.23.0"): from accelerate.data_loader import SeedableRandomSampler DATA_SAMPLERS += [SeedableRandomSampler] if is_deepspeed_available(): from accelerate.utils import DeepSpeedSchedulerWrapper if is_accelerate_available("0.28.0"): from accelerate.utils import DataLoaderConfiguration def _is_peft_model(model): if is_peft_available(): classes_to_check = (PeftModel,) if is_peft_available() else () # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): from peft import PeftMixedModel classes_to_check = (*classes_to_check, PeftMixedModel) return isinstance(model, classes_to_check) return False def _get_fsdp_ckpt_kwargs(): # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters): return {"adapter_only": True} else: return {} def safe_globals(): # Starting from version 2.4 PyTorch introduces a check for the objects loaded # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes # a default and requires allowlisting of objects being loaded. # See: https://github.com/pytorch/pytorch/pull/137602 # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals # See: https://github.com/huggingface/accelerate/pull/3036 if version.parse(torch.__version__).release < version.parse("2.6").release: return contextlib.nullcontext() np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype] # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for # all versions of numpy allowlist += [type(np.dtype(np.uint32))] return torch.serialization.safe_globals(allowlist) if TYPE_CHECKING: import optuna if is_datasets_available(): import datasets logger = logging.get_logger(__name__) logger.setLevel("INFO") # Name of the files used for checkpointing TRAINING_ARGS_NAME = "training_args.bin" TRAINER_STATE_NAME = "trainer_state.json" OPTIMIZER_NAME = "optimizer.pt" OPTIMIZER_NAME_BIN = "optimizer.bin" SCHEDULER_NAME = "scheduler.pt" SCALER_NAME = "scaler.pt" FSDP_MODEL_NAME = "pytorch_model_fsdp" DATA_PRINT_ONCE = True BATCH = None def print_batch(batch, tokenizer, args): global DATA_PRINT_ONCE global BATCH if batch is not None: BATCH = batch else: batch = BATCH DATA_PRINT_ONCE = True if batch is None: return if DATA_PRINT_ONCE: global_rank = torch.distributed.get_rank() f = open(os.path.join(args.output_dir, f"print_batch_{global_rank}.log"), "a") torch.set_printoptions(threshold=100_000) if "loss_mask" in batch and batch["loss_mask"] is not None: loss_mask = batch["loss_mask"] print(f"loss_mask {loss_mask} {loss_mask.size()}", file=f) if "position_ids" in batch and batch["position_ids"] is not None: position_ids = batch["position_ids"] print(f"position_ids {position_ids} {position_ids.size()}", file=f) if "attention_mask" in batch and batch["attention_mask"] is not None: attention_mask = batch["attention_mask"] if isinstance(attention_mask, list): attention_mask = attention_mask[0] print(f"attention_mask {attention_mask} {attention_mask.size()}", file=f) if "input_ids" in batch and batch["input_ids"] is not None: tokens = batch["input_ids"] print(f"tokens {tokens} {tokens.size()}", file=f) tokens_ = tokens.cpu().clone().detach() tokens_ = tokenizer.batch_decode(tokens_.tolist(), skip_special_tokens=False) print(f"tokens_ {tokens_[:]}", file=f) if "labels" in batch and batch["labels"] is not None: labels = batch["labels"] print(f"labels {labels} {labels.size()}", file=f) labels_ = labels.cpu().clone().detach() labels_[labels_==-100] = tokenizer("-", add_special_tokens=False).input_ids[0] labels_ = tokenizer.batch_decode(labels_.tolist(), skip_special_tokens=False) print(f"labels {labels_}", file=f) # labels__ = labels.cpu().clone().detach() # labels__[loss_mask.to(torch.int64)==0] = tokenizer("-", add_special_tokens=False).input_ids[0] # labels__ = tokenizer.batch_decode(labels__.tolist(), skip_special_tokens=False) # print(f"labels__ {labels__}", file=f) for k, v in batch.items(): if isinstance(v, torch.Tensor): print(f"{k} {v} {v.size()}", file=f) else: print(f"{k} {v}", file=f) f.close() DATA_PRINT_ONCE = False from transformers import Trainer as HFTrainer class Trainer(HFTrainer): def get_train_dataloader(self) -> DataLoader: """ Returns the training [`~torch.utils.data.DataLoader`]. Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed training if necessary) otherwise. Subclass and override this method if you want to inject some custom behavior. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") train_dataset = self.train_dataset data_collator = self.data_collator if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns(train_dataset, description="training") else: data_collator = self._get_collator_with_removed_columns(data_collator, description="training") dataloader_params = { "batch_size": self._train_batch_size, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, "persistent_workers": self.args.dataloader_persistent_workers, "multiprocessing_context": "spawn", } if not isinstance(train_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_train_sampler() dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["worker_init_fn"] = seed_worker dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. """ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is None: decay_parameters = self.get_decay_parameter_names(opt_model) if self.args.vision_model_lr_mult != 1.0 or self.args.vision_model_lr_decay_rate != 1.0: vision_parameters = [name for name, _ in opt_model.named_parameters() if "vision_model" in name] logger.info(f"{vision_parameters=}") else: vision_parameters = [] if self.args.mtp_model_lr_mult != 1.0: mtp_parameters = [] mtp_names = ["mtp"] num_nextn_predict_layers = self.model.config.num_nextn_predict_layers num_hidden_layers = self.model.config.num_hidden_layers for mtp_idx in range(num_nextn_predict_layers): layer_idx = num_hidden_layers - num_nextn_predict_layers + mtp_idx mtp_names.append(f"model.layers.{layer_idx}") for name, param in opt_model.named_parameters(): if any([x in name for x in mtp_names]): mtp_parameters.append(name) logger.info(f"{mtp_parameters=}") else: mtp_parameters = [] exclude_parameters = vision_parameters + mtp_parameters optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n not in exclude_parameters) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n not in exclude_parameters) ], "weight_decay": 0.0, }, ] if self.args.vision_model_lr_decay_rate != 1.0: for n, p in opt_model.named_parameters(): if p.requires_grad and n in vision_parameters: pass else: continue if n in decay_parameters: weight_decay = self.args.weight_decay else: weight_decay = 0.0 lr = self.args.learning_rate * get_vit_lr_decay_rate(n, opt_model.config.visual.num_hidden_layers, self.args.vision_model_lr_decay_rate) optimizer_grouped_parameters.append( { "params": [p], "weight_decay": weight_decay, "lr": lr, } ) logger.info(f"create_optimizer name {n} weight_decay {weight_decay} lr {lr}") elif self.args.vision_model_lr_mult != 1.0: optimizer_grouped_parameters.extend( [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in vision_parameters) ], "weight_decay": self.args.weight_decay, "lr": self.args.learning_rate * self.args.vision_model_lr_mult, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in vision_parameters) ], "weight_decay": 0.0, "lr": self.args.learning_rate * self.args.vision_model_lr_mult, }, ] ) logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.vision_model_lr_mult}") logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {0.0} lr_mult {self.args.vision_model_lr_mult}") if self.args.mtp_model_lr_mult != 1.0: optimizer_grouped_parameters.extend( [ { "params": [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in mtp_parameters) ], "weight_decay": self.args.weight_decay, "lr": self.args.learning_rate * self.args.mtp_model_lr_mult, }, { "params": [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in mtp_parameters) ], "weight_decay": 0.0, "lr": self.args.learning_rate * self.args.mtp_model_lr_mult, }, ] ) logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.mtp_model_lr_mult}") logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {0.0} lr_mult {self.args.mtp_model_lr_mult}") if self.optimizer_cls_and_kwargs is not None: optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs else: optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` # e.g. for GaLore optimizer. if "params" in optimizer_kwargs: optimizer_grouped_parameters = optimizer_kwargs.pop("params") # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` # e.g. for LOMO optimizer. if "model" in optimizer_kwargs: optimizer_grouped_parameters = optimizer_kwargs.pop("model") # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` # to avoid arguments conflicts. if "optimizer_dict" in optimizer_kwargs: optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() skipped = 0 for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) logger.debug(f"bitsandbytes: will optimize {module} in fp32") logger.info(f"skipped: {skipped/2**20}M params") if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer(self.optimizer) return self.optimizer def training_step( self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None ) -> torch.Tensor: """ Perform a training step on a batch of inputs. Subclass and override to inject custom behavior. Args: model (`nn.Module`): The model to train. inputs (`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument `labels`. Check your model's documentation for all accepted arguments. Return: `torch.Tensor`: The tensor with training loss on this batch. """ print_batch(inputs, self.processing_class, self.args) model.train() if hasattr(self.optimizer, "train") and callable(self.optimizer.train): self.optimizer.train() inputs = self._prepare_inputs(inputs) if is_sagemaker_mp_enabled(): loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) return loss_mb.reduce_mean().detach().to(self.args.device) with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) del inputs if ( self.args.torch_empty_cache_steps is not None and self.state.global_step % self.args.torch_empty_cache_steps == 0 ): if is_torch_xpu_available(): torch.xpu.empty_cache() elif is_torch_mlu_available(): torch.mlu.empty_cache() elif is_torch_musa_available(): torch.musa.empty_cache() elif is_torch_npu_available(): torch.npu.empty_cache() elif is_torch_mps_available(min_version="2.0"): torch.mps.empty_cache() elif is_torch_hpu_available(): logger.warning( "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()." ) else: torch.cuda.empty_cache() kwargs = {} # For LOMO optimizers you need to explicitly use the learnign rate if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: kwargs["learning_rate"] = self._get_learning_rate() if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training if self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: # Finally we need to normalize the loss for reporting if not self.model_accepts_loss_kwargs and self.compute_loss_func is None: loss = loss / self.args.gradient_accumulation_steps # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled # https://github.com/huggingface/transformers/pull/35808 if self.accelerator.distributed_type == DistributedType.DEEPSPEED: kwargs["scale_wrt_gas"] = False self.accelerator.backward(loss, **kwargs) return loss.detach() def get_batch_samples(self, epoch_iterator, num_batches): batch_samples = [] num_items_in_batch = None for _ in range(num_batches): try: while True: batch_sample = next(epoch_iterator) if "input_ids" in batch_sample: break batch_samples += [batch_sample] except StopIteration: break if len(batch_samples) > 0 and "labels" in batch_samples[0]: # For now we don't support object detection try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) except (TypeError, AttributeError): pass if self.args.average_tokens_across_devices and num_items_in_batch is not None: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() return batch_samples, num_items_in_batch def get_vit_lr_decay_rate(name, num_layers, lr_decay_rate): layer_id = num_layers + 1 if "vision_model." in name: if ".position_embedding." in name or ".conv1." in name: layer_id = 0 elif ".layers." in name: layer_id = int(name[name.find(".layers.") :].split(".")[2]) + 1 return lr_decay_rate ** (num_layers + 1 - layer_id)