FSDP via SPMD (FSDP v2)
Introduction
FSDP, fully sharded data parallel, is a well-known distributed training algorithm in the PyTorch world. SPMD is PyTorch/XLA’s API that allows users to annotate a single device PyTorch model and then let XLA’s GSPMD feature turn it into a distributed model. This design doc focuses on how to utilize SPMD to express FSDP, and make this new implementation performant and easy to use. Since PyTorch/XLA already has a native implementation of FSDP here, this new implementation is also referred to as FSDPv2.
Background
A lot of past SPMD training experiments have demonstrated that FSDP, i.e., 1D sharding, has better performance than 2D sharding as long as the model can fit into the training fleet.
To express FSDP using the vanilla SPMD API, one currently needs to accomplish the following 5 steps. Examples are taken from our HF Llama 2 fork.
1. Define mesh
# Place DCN on an independent axis in the mesh. Model parameters should be
# replicated along the DCN axis, and inputs and activations should have
# the batch dimension sharded along the combined DCN and data axes.
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_runtime_device_count()
model_axis = max(model_args.spmd_2d_sharding, 1)  # spmd_2d_sharding is set to 1 here.
dcn_axis = model_args.spmd_dcn_parallelism
data_axis = num_devices // model_axis // dcn_axis
ici_mesh_shape = (1, data_axis, model_axis)
dcn_mesh_shape = (dcn_axis, 1, 1)
spmd_mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape,
                          dcn_mesh_shape=dcn_mesh_shape,
                          axis_names=('dcn', 'data', 'model'))
2. Shard data loader
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.parallel_loader as pl
sharding_spec = xs.ShardingSpec(self.args.spmd_mesh, (('dcn', 'data'), None))
# TODO(jonbolin): Once integrated with Accelerate, we can use the Accelerate-prepared
# MpDeviceLoader instead of manually adding sharding and adding a dataset attribute.
loader = pl.MpDeviceLoader(dataloader, self.args.device,
                           input_sharding=sharding_spec,
                           loader_prefetch_size=self.args.train_batch_size,
                           device_prefetch_size=4)
3. Shard weights
for name, param in model.named_parameters():
    if model_args.spmd_fsdp_sharding:
        print('> [FSDP] Sharding tensor', name, param.shape, param.dtype)
        # We don't care about layernorm's weights, and
        # LLaMA doesn't use biases.
        if len(param.shape) == 1:
            continue
        assert len(param.shape) == 2
        # Shard the largest dimension
        if param.shape[0] > param.shape[1]:
            partition_spec = ('data', None)
        else:
            partition_spec = (None, 'data')
        xs.mark_sharding(param, spmd_mesh, partition_spec)
4. Shard activations
# Apply 2D sharding:
#   hidden_states (batch, length, hidden)
#   mesh (data, None, model)
if self.spmd_debug:
    print('> Sharding hidden_states', hidden_states.shape,
          self.spmd_mesh.get_logical_mesh().shape)
xs.mark_sharding(hidden_states, self.spmd_mesh, (('dcn', 'data'), None, 'model'))
if self.spmd_debug:
    print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
5. Apply backward optimization barrier
for i, block in enumerate(model.model.layers):
    # LLaMA-specific
    xs.apply_backward_optimization_barrier(model.model.layers[i])
Even though FSDP requires far fewer sharding annotations than 2D sharding, which for LLMs needs many more activation shardings plus additional sharding on the attention layers, it is still complex. For comparison, here is how our native FSDP implementation, which we will refer to as FSDPv1 in this design, is generally used:
import functools

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = functools.partial(transformer_auto_wrap_policy,
                                     transformer_layer_cls=transformer_cls_to_wrap)
self.model = model = FSDP(model, auto_wrap_policy=auto_wrap_policy, **fsdp_kwargs)
That is far less boilerplate to write, and thus a much better user experience! Here comes the problem statement: can we recreate the same user experience while keeping the same performance as the vanilla SPMD FSDP annotations?
Goals
- Easy to use UX
- Competitive out of box performance
- Seamless integration with HF and Lightning
- Major way of using SPMD
Non-goals
Feature Requirements
Before talking about the design candidates, let's detail the features and characteristics of this new system. This will help us navigate through the different design candidates.
P0: Shard on weights
This is the basic concept of FSDP where weights are sharded and distributed among the training fleet.
P0: Shard on activations
This is not needed in FSDPv1. However, if we omit it in the vanilla SPMD example in the Background section, we get much worse performance. The following performance benchmarks were taken on a v4-8 with Llama 2 2B at 1K seq_len.
Shard activations
- Xprof: Missing
- Hardware FLOPS utilization: 65.0%
- Peak memory allocation: 16392.05 MiB
Don’t shard activations
- Xprof: Missing
- Hardware FLOPS utilization: 50.3%
- Peak memory allocation: 16673.46 MiB
It turns out that in the second case, the compiler decides to do some wild all-to-all and all-reduce in the attention layer.
Therefore, this feature is a must to instruct the compiler to follow the FSDP algorithm. Fortunately, we only need to shard either the input or the output hidden_states of the decoder layer in the case of LLM, and we don’t need to shard every activation.
P0: Backward optimization barrier
This is needed to prevent gigantic fusions on syncing the gradients. The only remaining question is whether it's compatible with gradient checkpointing, since both of them modify the backward pass in some fashion. Theoretically speaking, they should be compatible and the application order shouldn't matter.
P0: Manual wrapping
Most of the features here should be packaged together and can be applied separately to the root module and the children modules. Let’s take FSDPv1 as an example. Typically the wrapper will be applied to two modules:
- The root module.
- The decoder layer.
Even though FSDPv1 by default shards all the parameters in the wrapped module, including those of its children, the rebuilding, memory-freeing, and gradient-synchronizing logic only applies to the wrapped module itself. If only the root module is wrapped, then all parameters will be rebuilt in full during the outermost forward and there is no memory saving. If every child module is wrapped, the overhead will be too high. That's why usually only the above two types of modules are wrapped.
P1: MultiSlice support
The implementation should be flexible to support 1) data parallel over MultiSlice and 2) FSDP over MultiSlice.
P1: Defer Parameter initialization
This is needed when the total model size is larger than the host memory. In TPU v5e, the host memory is extremely limited and this feature becomes a must. Basically, what we need to do is to initialize the model layer by layer, and transfer the layer to the device immediately.
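The layer-by-layer initialize-and-transfer idea can be sketched as follows. This is a hedged illustration of the control flow only, not a real torch_xla API: `layer_builders` and `to_device` are hypothetical stand-ins for per-layer constructors and the host-to-device transfer.

```python
def init_model_deferred(layer_builders, to_device):
    """Build each layer on host, transfer it to device, free the host copy.

    Host memory only ever holds one layer at a time, which is the point of
    deferred initialization on memory-limited hosts such as TPU v5e.
    """
    device_layers = []
    for build_layer in layer_builders:
        layer = build_layer()               # host memory holds only this layer
        device_layers.append(to_device(layer))
        del layer                           # release the host copy immediately
    return device_layers
```

In a real implementation, `build_layer` would run the module's weight initialization and `to_device` would be the XLA device transfer.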
P1: Auto wrapping
This refers to the ability to apply the same set of rules, e.g., sharding/opt-barrier/etc, automatically to children modules from the root module.
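A minimal sketch of that traversal, mirroring FSDPv1's transformer_auto_wrap_policy but without any torch dependency: `Node`, `DecoderLayer`, `auto_wrap`, and `apply_rules` are all hypothetical stand-ins for the module tree and for the combined sharding/opt-barrier rule set.

```python
class Node:
    """Minimal stand-in for nn.Module with a children list."""
    def __init__(self, *children):
        self.children = list(children)

class DecoderLayer(Node):
    """Stand-in for the transformer decoder layer class to auto-wrap."""
    pass

def auto_wrap(root, layer_cls_to_wrap, apply_rules):
    """Apply `apply_rules` to every matching descendant, then to the root."""
    def visit(module):
        for child in module.children:
            if isinstance(child, tuple(layer_cls_to_wrap)):
                apply_rules(child)   # wrapped; don't descend further
            else:
                visit(child)
    visit(root)
    apply_rules(root)                # the root module is always wrapped
    return root
```

Matched children are not descended into, so rules (and in particular the optimization barrier) are applied once per backbone layer rather than once per linear.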
P1: Distributed checkpointing support
Two use cases here: one is exception handling during the training job, and the other is consolidating checkpoints for future inference. For the different design candidates, this feature requirement might have different implications. For example, an nn.Module wrapper approach will introduce additional naming prefixes in the state_dicts.
P1: HuggingFace and Lightning integrations
FSDPv2 should be designed with these integrations in mind so that it can easily replace the current FSDPv1 integrations in HuggingFace and Lightning, and thus become the default distributed algorithm in those two high-level frameworks for PyTorch/XLA.
P2: Mixed precision support
FSDPv1 offers manual mixed precision support where the weights are always kept in FP32 but compute can be performed in BF16. Mixed precision support is definitely needed, but whether it is provided via torch.amp or via this wrapper is under discussion.
P2: Gradient Averaging
FSDPv1 offers a convenient way of averaging the gradients by world_size to avoid overflows during all_reduce. It is unclear whether this is necessary for this design.
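The FSDPv1 behavior referenced here, pre-scaling each replica's gradients by 1/world_size so the subsequent sum yields a mean instead of an overflow-prone total, can be sketched with plain Python lists standing in for gradient tensors (`average_gradients` is an illustrative name, not an FSDPv1 API):

```python
def average_gradients(per_replica_grads):
    """Pre-scale each replica's gradients by 1/world_size, then sum.

    Summing the scaled values is equivalent to an all_reduce of the mean
    and keeps magnitudes small, which matters in low-precision dtypes.
    """
    world_size = len(per_replica_grads)
    scaled = [[g / world_size for g in grads] for grads in per_replica_grads]
    return [sum(col) for col in zip(*scaled)]
```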
N.A.: Shard on attentions
This is needed in the case of 2D sharding, where we need to pick the num_attention_heads dim to shard on the model axis. Since we only shard on the batch dim, theoretically speaking we shouldn't need it. Experiments also validate the theory.
N.A.: Replace nn.Linear
This is needed for 2D sharding as we don’t want PyTorch to collide the two dims of a tensor where both of them are sharded during a matmul operation. However, in FSDP, at most one dim of the tensor will be sharded, and therefore this is not needed. Here is the xprof that drops XLAPatchedLinear. No performance degradation is observed.
N.A.: Shard optimizer states
This is proven to be unnecessary during the 2D sharding exercise.
Design
In this section, two approaches will be discussed. Each of them will have a PoC implementation that includes all the P0 features to demonstrate the feasibility and pros & cons. Then one of them will be selected as the final design, and more P1 and P2 features will be added on top of it.
As an nn.Module: SpmdFullyShardedDataParallel
This is a traditional wrapper approach like FSDPv1. Here we have the following major components:
init
It will take care of the following P0s:
- Shard on weights
- Apply backward optimization barrier
Below is the pseudo code:
def __init__(self,
             module: nn.Module,
             mesh: spmd.Mesh,
             shard_output: Optional[Callable] = None):
    # Check the parameters
    ...
    super().__init__()
    self._orig_module = module
    self._mesh = mesh
    self._shard_output = shard_output  # will explain in the next section
    # Shard the weights.
    for param in module.parameters():
        spmd.mark_sharding(param, mesh, _prepare_spmd_partition_spec(param))
    # Apply the backward optimization barrier.
    spmd.xla_sharding.apply_backward_optimization_barrier(module)
forward
This is the most debatable part. The forward function is used to shard on activations, specifically the output of the original module. As discussed in the above section, this is required to maintain the high performance. However, the output of the forward function can be anything, and therefore it’s really hard to shard.
Here is the proposed solution. Conventionally, the output usually will be:
- A tensor
- A tuple of tensors
For the 1st case, we can safely shard it. For the 2nd case, we can shard the 0th element and warn the user to provide an output sharding function if that is not the intended element. For all other cases, we will raise a runtime error asking users to provide an output sharding function. That's what shard_output in the above section represents.
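The proposed dispatch can be sketched in isolation. This is an illustration of the control flow only: `shard_fn` stands in for the mark_sharding call, and anything with a `.shape` attribute is treated as tensor-like so the logic can be shown without a torch dependency.

```python
def shard_module_output(output, shard_fn):
    """Apply `shard_fn` following the convention described above.

    - tensor               -> shard it directly
    - tuple (tensor first) -> shard the 0th element, pass the rest through
    - anything else        -> require an explicit shard_output callable
    """
    if hasattr(output, 'shape'):
        shard_fn(output)
        return output
    if isinstance(output, tuple) and output and hasattr(output[0], 'shape'):
        # The real wrapper would also warn the user here.
        shard_fn(output[0])
        return output
    raise RuntimeError(
        'Unsupported output type; please provide a shard_output callable.')
```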
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    output = self.module(*args, **kwargs)
    self._shard_output(output, self._mesh)
    return output
Code boilerplate
Some code boilerplate is copied from FSDPv1.
def module(self) -> nn.Module:
    """Make model.module accessible, just like DDP."""
    ...

def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
    """Forward missing attributes to wrapped module."""
    ...

def __getitem__(self, key: int) -> nn.Module:
    """Forward indexing calls in case the module is an nn.Sequential."""
    ...
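How this forwarding boilerplate behaves can be shown with a minimal stand-in wrapper, no torch required (the `Wrapper`/`Inner` names are illustrative, not the real class):

```python
class Wrapper:
    """Minimal attribute-forwarding wrapper, the same trick DDP/FSDPv1 use."""

    def __init__(self, module):
        self._orig_module = module

    @property
    def module(self):
        """Make wrapper.module accessible, just like DDP."""
        return self._orig_module

    def __getattr__(self, name):
        # Only consulted for attributes NOT found on the wrapper itself,
        # so _orig_module and module resolve normally above.
        return getattr(self._orig_module, name)

    def __getitem__(self, key):
        # Forward indexing in case the wrapped module is sequential.
        return self._orig_module[key]
```

Because `__getattr__` is only invoked after normal attribute lookup fails, the wrapper's own fields never shadow or recurse into the forwarding path.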
Manual wrapping
Here it follows the same logic as FSDPv1 where only the root module and important backbone children modules should be wrapped, like the decoder layers in LLMs. This is very important as we don’t want to register the backward optimization barrier for every linear. Otherwise, the compiler will lose a ton of optimization opportunities.
(WIP) As a function: spmd_fully_sharded_data_parallel
This will be very similar to distribute_module, and in fact can probably be a high-level wrapper on top of it. Basically, it will take care of everything within the function.
Test Plan
- Unit tests that examine all the basic functionalities.
- Performance tests that demonstrate the parity of FSDPv2 against the vanilla SPMD implementation.
(WIP) PoCs Evaluation
SpmdFullyShardedDataParallel
Performance
- Xprof: Missing
- Hardware FLOPS utilization: 63.5% vs. 65.0%
- Peak memory allocation: 16666.45 MiB vs. 16392.05 MiB
Project Management
Milestone 1:
- Wrap up initial design
- Provide PoC of SpmdFullyShardedDataParallel: 1. PyTorch/XLA PR, 2. HF integration.
- Provide PoC of spmd_fully_sharded_data_parallel.
- Evaluate the two PoCs and determine the right approach.
Milestone 2:
TBD, implement the rest of the feature set.
Open Questions
- Distributed checkpointing compatibility, naming conventions
cc @JackCaoG @yeounoh @jonb377 @wconstab @wanchaol