FSDP2
Fully Sharded Data Parallel (FSDP2) shards the model parameters, gradients, and optimizer states across GPUs. Before computation, each GPU gathers the full parameters from all shards and frees them again once the computation is done. Sharding lets you train models larger than a single GPU's memory, at the cost of more communication than DDP. Use FSDP when your model or optimizer states don't fit on a single GPU.
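To make the data flow concrete, here is a toy, single-process sketch in plain Python of the three collective steps FSDP relies on: sharding, all-gather, and reduce-scatter. It is not how PyTorch implements FSDP. The helpers shard, all_gather, and reduce_scatter are hypothetical, each "rank" is just a list index, and gradients are averaged across ranks, following the usual data-parallel convention.

```python
# Toy, single-process simulation of FSDP's collectives (no GPUs, no
# torch.distributed). Helper names are hypothetical.

def shard(flat_params: list[float], world_size: int) -> list[list[float]]:
    """Split a flat parameter list into equal contiguous shards, one per rank."""
    assert len(flat_params) % world_size == 0, "toy version assumes an even split"
    size = len(flat_params) // world_size
    return [flat_params[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards: list[list[float]]) -> list[list[float]]:
    """Every rank receives the concatenation of all shards (the full parameters)."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_grads: list[list[float]]) -> list[list[float]]:
    """Average full-size gradients across ranks, then keep only each rank's slice."""
    n = len(per_rank_grads)
    averaged = [sum(vals) / n for vals in zip(*per_rank_grads)]
    return shard(averaged, n)


world_size = 3
params = [float(i) for i in range(6)]       # full model: 6 parameters
param_shards = shard(params, world_size)    # each rank stores only 2 of them

gathered = all_gather(param_shards)         # before compute: full params on every rank
assert all(g == params for g in gathered)

# Each rank computes full-size gradients on its own data shard (made-up values).
grads = [[1.0] * 6, [2.0] * 6, [3.0] * 6]
grad_shards = reduce_scatter(grads)         # each rank keeps 2 averaged gradients
print(grad_shards)  # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]

# Each rank steps the optimizer on its own shard only (plain SGD for brevity).
lr = 0.1
param_shards = [[p - lr * g for p, g in zip(ps, gs)]
                for ps, gs in zip(param_shards, grad_shards)]
```

Each rank ends up storing and updating one third of the parameters, gradients, and optimizer state; the full parameters exist only transiently after the all-gather.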
                    ┌─────────────────┐
                    │  training data  │
                    └────────┬────────┘
          ┌──────────────────┼──────────────────┐
          │ shard 0          │ shard 1          │ shard 2
          ▼                  ▼                  ▼
   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
   │    param    │    │    param    │    │    param    │
   │   shard 0   │    │   shard 1   │    │   shard 2   │
   │    GPU 0    │    │    GPU 1    │    │    GPU 2    │
   └──────┬──────┘    └──────┬──────┘    └──────┬──────┘
          │                  │                  │
          └──────── all-gather (params) ────────┘
                             │
                  full params on each GPU
                             │
          ┌──────────────────┼──────────────────┐
          ▼                  ▼                  ▼
       forward            forward            forward
          │                  │                  │
          └────── reduce-scatter (grads) ───────┘
                             │
          ┌──────────────────┼──────────────────┐
          ▼                  ▼                  ▼
     grad shard 0       grad shard 1       grad shard 2
    optim shard 0      optim shard 1      optim shard 2
        step               step               step

Sharding strategies
FSDP2 controls sharding with TrainingArguments.fsdp_config. Set fsdp=True to enable FSDP, and set reshard_after_forward in the FSDP config to choose the memory and throughput tradeoff.
| reshard_after_forward | behavior |
|---|---|
| true | reshard parameters after the forward pass to save more memory; they are all-gathered again for the backward pass |
| false | keep parameters gathered between forward and backward to avoid the re-all-gather, at the cost of higher peak memory |
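The tradeoff in the table can be illustrated with a toy count. This sketch assumes FSDP units run strictly in order with no prefetching and that every unit is resharded after its backward; real FSDP2 overlaps communication with compute and prefetches, so read the numbers as relative, not absolute. The function simulate_step is hypothetical.

```python
# Toy model of reshard_after_forward: count all-gathers per training step and
# the peak number of FSDP units whose full parameters are held at the same time.

def simulate_step(num_units: int, reshard_after_forward: bool) -> tuple[int, int]:
    gathered: set[int] = set()
    all_gathers = 0
    peak_gathered = 0

    def ensure_gathered(unit: int) -> None:
        nonlocal all_gathers, peak_gathered
        if unit not in gathered:
            all_gathers += 1              # one all-gather for this unit
            gathered.add(unit)
        peak_gathered = max(peak_gathered, len(gathered))

    for unit in range(num_units):             # forward pass, first to last unit
        ensure_gathered(unit)
        if reshard_after_forward:
            gathered.discard(unit)            # free full params right away

    for unit in reversed(range(num_units)):   # backward pass, last to first unit
        ensure_gathered(unit)                 # re-all-gather if it was freed
        gathered.discard(unit)                # free once its grads are computed

    return all_gathers, peak_gathered


for flag in (True, False):
    print(flag, simulate_step(num_units=4, reshard_after_forward=flag))
# True (8, 1)
# False (4, 4)
```

With four units, resharding doubles the all-gathers (8 versus 4) but holds at most one unit's full parameters at a time instead of all four.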
auto_wrap_policy controls how modules are wrapped into FSDP units. It defaults to "TRANSFORMER_BASED_WRAP", which wraps the model's transformer layers. Without wrapping ("NO_WRAP"), the entire model is one FSDP unit and you lose the memory benefit of sharding.
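A toy sketch of what the two policies mean for unit boundaries. The model layout, the DecoderLayer class name, and the fsdp_units function are hypothetical, and this is not the Accelerate or PyTorch wrapping code; it only shows which modules end up sharing an FSDP unit.

```python
# Toy illustration of auto_wrap_policy: group a hypothetical model's
# submodules into FSDP units, given as lists of module names.

MODEL: list[tuple[str, str]] = [        # (module name, class name)
    ("embed_tokens", "Embedding"),
    ("layers.0", "DecoderLayer"),
    ("layers.1", "DecoderLayer"),
    ("layers.2", "DecoderLayer"),
    ("norm", "RMSNorm"),
    ("lm_head", "Linear"),
]


def fsdp_units(model: list[tuple[str, str]], policy: str,
               layer_cls: set[str]) -> list[list[str]]:
    """Return FSDP units. The last unit is the root, which owns every module
    not claimed by a wrapped layer."""
    if policy == "NO_WRAP":
        # One unit: the whole model is gathered at once.
        return [[name for name, _ in model]]
    if policy != "TRANSFORMER_BASED_WRAP":
        raise ValueError(f"toy only supports two policies, got {policy!r}")
    # Each matching layer becomes its own unit, gathered and freed separately.
    layer_units = [[name] for name, cls in model if cls in layer_cls]
    root = [name for name, cls in model if cls not in layer_cls]
    return layer_units + [root]


for policy in ("TRANSFORMER_BASED_WRAP", "NO_WRAP"):
    print(policy, fsdp_units(MODEL, policy, {"DecoderLayer"}))
```

Under TRANSFORMER_BASED_WRAP, each decoder layer is its own unit and the leftover modules (embeddings, final norm, output head) stay in the root unit. Under NO_WRAP, the single unit means every parameter is gathered at the same time.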
Configure FSDP
These fields control how FSDP2 wraps, shards, and loads the model. reshard_after_forward and auto_wrap_policy are covered in Sharding strategies.
- cpu_offload offloads parameters and gradients to CPU when they aren't in use to save GPU memory.
- transformer_layer_cls_to_wrap defines the transformer layer to wrap into an FSDP unit when auto_wrap_policy is "TRANSFORMER_BASED_WRAP". Each unit manages its own gather and scatter ops. During the forward pass, only the current unit's parameters are gathered, and the previous units' parameters are released to save memory. Wrapping only the top-level model yields no GPU memory savings, while wrapping every individual Linear layer makes inter-unit communication very expensive. If you leave this field empty, FSDP reads the value from the model definition.
- min_num_params sets the minimum number of parameters per module for size-based wrapping. It is only used when auto_wrap_policy is "SIZE_BASED_WRAP".
- state_dict_type controls the checkpoint format. It defaults to "FULL_STATE_DICT", which saves a single Transformers-compatible checkpoint. "SHARDED_STATE_DICT" saves one checkpoint file per rank, which is faster for large models. Sharded checkpoints only load back into FSDP, so save a "FULL_STATE_DICT" for the final checkpoint you want to share or load outside FSDP.
- cpu_ram_efficient_loading loads the checkpoint from disk on rank 0 only. Other GPUs initialize an empty model and receive the weights by broadcast, so multiple processes don't each load a large model into CPU RAM.
- activation_checkpointing recomputes activations during the backward pass instead of storing them. Use it instead of gradient checkpointing in TrainingArguments; setting both raises an error.
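Putting these fields together, here is a sketch of an FSDP config file that could be passed to fsdp_config as a path. The key names follow this section, but the value types (a list for transformer_layer_cls_to_wrap, a boolean for cpu_offload) and the LlamaDecoderLayer class name are assumptions; check them against your Transformers version and your model's layer class.

```python
# Sketch of an FSDP config file for TrainingArguments.fsdp_config.
# Key names come from the fields above; value types and the layer class
# name are assumptions, not a verified schema.
import json
import tempfile
from pathlib import Path

fsdp_config = {
    "reshard_after_forward": True,            # lower peak memory, extra all-gather in backward
    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],  # placeholder: your model's layer class
    "cpu_offload": False,                     # assumed boolean
    "state_dict_type": "SHARDED_STATE_DICT",  # save FULL_STATE_DICT for the final checkpoint
    "cpu_ram_efficient_loading": True,
    "activation_checkpointing": True,         # do not also set gradient_checkpointing
}

path = Path(tempfile.gettempdir()) / "fsdp_config.json"
path.write_text(json.dumps(fsdp_config, indent=2))
print(path.read_text())

# In a Trainer script (requires transformers; not executed here):
# args = TrainingArguments(output_dir="out", fsdp=True, fsdp_config=str(path))
```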
Configure FSDP training with either an Accelerate config file or an FSDP config file passed to fsdp_config.
Run the accelerate config command and answer questions about your hardware and training setup. This creates a default_config.yaml file in your cache.
Run accelerate launch with a Trainer-based script. You don't need to set fsdp_config because the Accelerate config file covers the same settings.
accelerate launch train.py

Next steps
- See DDP for data-parallel training when your model fits on one GPU.
- See DeepSpeed for ZeRO optimization and NVMe offloading.
- For FSDP on TPUs with PyTorch/XLA, set xla, xla_fsdp_settings, and xla_fsdp_grad_ckpt in TrainingArguments.fsdp_config.
- Read the FSDP chapter from The Ultra-Scale Playbook for more information about how FSDP works.