
[RFC] Model Sharding for distributed training #55207

@pritamdamania87

🚀 Feature

Provide a set of building blocks and APIs for PyTorch users to shard models easily for distributed training.

Motivation

There is a need to provide a standardized sharding mechanism in PyTorch. There are several types of model parallelism paradigms which require sharding (ex: pipeline parallelism, intra-layer parallelism etc.). The motivation of this feature is to provide a standardized sharding spec and a few building blocks for users to deal with sharding PyTorch models.

Pitch

Concrete goals for this proposal:

  1. Provide a standardized sharding specification in PyTorch which can be used to express how a model needs to be sharded across a set of nodes.
  2. Provide a simple barebones ShardedTensor abstraction for Tensors that might be sharded across different devices.
  3. Provide a way to save/load such sharded models for checkpointing.
  4. Provide a way to automatically shard a model, given a sharding specification.

Placement/Sharding Specification

The Placement/Sharding specification is a specification about how Tensors/Modules need to be placed on a set of devices.

SingleSplitShardingSpec

This is a straightforward sharding spec that covers most sharding requirements and provides a simple, easy-to-understand specification. It describes how a Tensor can be sharded along a single dimension.

from dataclasses import dataclass, field
from typing import List, Optional, Union

# Interface defining how the Tensor should be placed on a set of devices.
class PlacementSpec(object):
    pass

@dataclass
class DevicePlacement(PlacementSpec):
    # Device where a shard should be placed, in
    # RemoteDevice format: https://github.com/pytorch/pytorch/issues/46554
    device: str

@dataclass
class SingleSplitShardingSpec(PlacementSpec):
    # List of placements for the shards of the tensor. Each entry in the
    # list refers to one shard and defines where that shard should be placed.
    shard_placement: List[PlacementSpec] = field(default_factory=list)

    # A single integer indicating the dimension along which the tensor should be
    # split. Could be a dimension name as well, following:
    # https://pytorch.org/docs/stable/named_tensor.html
    split_dimension: Union[int, str] = 0

    # Optional parameter to set the size of each shard (for uneven splits).
    # If not specified, the split is even. Must be the same length as
    # self.shard_placement.
    shard_sizes: Optional[List[int]] = None

# Let's say we have a large tensor and we'd like to shard it row-wise across
# 4 GPUs on a single host. The sharding spec would be as follows:
spec = SingleSplitShardingSpec()
spec.shard_placement = [
    DevicePlacement("rank:0/cuda:0"), DevicePlacement("rank:1/cuda:1"),
    DevicePlacement("rank:2/cuda:2"), DevicePlacement("rank:3/cuda:3"),
]
spec.split_dimension = 0
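The spec above only records where each shard lives. As a sketch of how the range each shard covers along the split dimension could be derived, the hypothetical helper single_split_ranges below returns an (offset, length) pair per shard. It assumes that an even split which does not divide exactly uses the same chunk size as torch.chunk (ceil division, with trailing shards taking whatever remains); the proposal does not specify this case.

```python
from typing import List, Optional, Tuple


def single_split_ranges(
    dim_size: int,
    num_shards: int,
    shard_sizes: Optional[List[int]] = None,
) -> List[Tuple[int, int]]:
    """Return an (offset, length) pair along the split dimension for each shard."""
    if shard_sizes is None:
        # Even split: every shard gets ceil(dim_size / num_shards) entries,
        # except trailing shards, which get whatever remains (possibly zero).
        chunk = (dim_size + num_shards - 1) // num_shards
        shard_sizes = [
            max(0, min(chunk, dim_size - i * chunk)) for i in range(num_shards)
        ]
    if len(shard_sizes) != num_shards:
        raise ValueError("shard_sizes needs one entry per shard placement")
    if sum(shard_sizes) != dim_size:
        raise ValueError("shard_sizes must add up to the size of the split dimension")

    ranges, offset = [], 0
    for size in shard_sizes:
        ranges.append((offset, size))
        offset += size
    return ranges


# Row-wise split of a 10-row tensor across the 4 GPUs from the example above.
print(single_split_ranges(10, 4))  # [(0, 3), (3, 3), (6, 3), (9, 1)]
```

Passing explicit shard_sizes (for example [1, 4] for a 5-row tensor on two devices) bypasses the even-split rule entirely, which is what the optional shard_sizes field is for.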

MeshShardingSpec

SingleSplitShardingSpec covers most sharding requirements; however, it doesn't allow for arbitrary sharding specifications. MeshShardingSpec aims to address this by providing a generic sharding specification that can shard a tensor arbitrarily.

class MeshShardingSpec(PlacementSpec):
    # Multidimensional array (nested lists) with the same number of dimensions
    # as the tensor to be sharded. Each element of the array is a PlacementSpec
    # describing where a particular partition of the tensor should be placed.
    # The shape of the array determines how the tensor is partitioned across
    # the placement options, and the number of elements in the array is the
    # total number of partitions.
    shard_placement: list
    
# Let's say we have a tensor of shape [4 x 8] as follows:
# 1 1 2 2 3 3 4 4
# 1 1 2 2 3 3 4 4
# 5 5 6 6 7 7 8 8
# 5 5 6 6 7 7 8 8
# And our shard_placement array is [2 x 4] as follows:
shard_placement = [
    ["rank:0/cuda:0", "rank:0/cuda:1", "rank:0/cuda:2", "rank:0/cuda:3"],
    ["rank:1/cuda:4", "rank:1/cuda:5", "rank:1/cuda:6", "rank:1/cuda:7"],
]

# Then each device gets a [2 x 2] slice of the tensor. cuda:0 gets:
# 1 1
# 1 1
# cuda:1 gets:
# 2 2
# 2 2
# cuda:4 gets:
# 5 5
# 5 5
# and so on.
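The mapping in this example can be checked with a small framework-free sketch. mesh_shards is a hypothetical helper that handles the 2-D case only and assumes every tensor dimension divides evenly by the matching mesh dimension; the proposal does not say how uneven meshes are handled.

```python
from typing import Dict, List

Matrix = List[List[int]]


def mesh_shards(tensor: Matrix, mesh: List[List[str]]) -> Dict[str, Matrix]:
    """Split a 2-D tensor into one block per mesh entry (even splits only)."""
    rows, cols = len(tensor), len(tensor[0])
    mesh_rows, mesh_cols = len(mesh), len(mesh[0])
    if rows % mesh_rows or cols % mesh_cols:
        raise ValueError("this sketch only handles evenly divisible shapes")
    block_rows, block_cols = rows // mesh_rows, cols // mesh_cols
    shards = {}
    for i, mesh_row in enumerate(mesh):
        for j, device in enumerate(mesh_row):
            # Mesh position (i, j) owns rows [i*block_rows, (i+1)*block_rows)
            # and columns [j*block_cols, (j+1)*block_cols) of the tensor.
            shards[device] = [
                row[j * block_cols:(j + 1) * block_cols]
                for row in tensor[i * block_rows:(i + 1) * block_rows]
            ]
    return shards


tensor = [
    [1, 1, 2, 2, 3, 3, 4, 4],
    [1, 1, 2, 2, 3, 3, 4, 4],
    [5, 5, 6, 6, 7, 7, 8, 8],
    [5, 5, 6, 6, 7, 7, 8, 8],
]
mesh = [
    ["rank:0/cuda:0", "rank:0/cuda:1", "rank:0/cuda:2", "rank:0/cuda:3"],
    ["rank:1/cuda:4", "rank:1/cuda:5", "rank:1/cuda:6", "rank:1/cuda:7"],
]
shards = mesh_shards(tensor, mesh)
print(shards["rank:0/cuda:1"])  # [[2, 2], [2, 2]]
```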

Helper APIs

This section defines a few helper APIs to easily instantiate the sharding specs:

torch.distributed.sharding_spec.single_split(
    split_dimension: Union[int, str],
    shard_placement: List[PlacementSpec],
    shard_sizes: Optional[List[int]] = None,
) -> SingleSplitShardingSpec

# Example:
torch.distributed.sharding_spec.single_split(
    0, ["rank:0/cuda:0", "rank:1/cuda:1"])

torch.distributed.sharding_spec.mesh(
    shard_placement: Array[str]
) -> MeshShardingSpec

# Example:
torch.distributed.sharding_spec.mesh(
    [
        ["rank:0/cuda:0", "rank:0/cuda:1", "rank:0/cuda:2", "rank:0/cuda:3"],
        ["rank:1/cuda:4", "rank:1/cuda:5", "rank:1/cuda:6", "rank:1/cuda:7"],
    ]
)
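Beyond constructing spec objects, a natural job for these helpers is argument validation. The sketch below is one possible plain-Python version. The dataclasses mirror the spec fields above (with placements as plain strings), while the specific checks (shard_sizes length, a non-empty rectangular mesh) and the extra shape field on MeshShardingSpec are assumptions, not part of the proposal.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass
class SingleSplitShardingSpec:
    split_dimension: Union[int, str]
    shard_placement: List[str]
    shard_sizes: Optional[List[int]] = None


@dataclass
class MeshShardingSpec:
    shard_placement: list   # N-dimensional nested list of placement strings
    shape: Tuple[int, ...]  # shape of the mesh, e.g. (2, 4)


def single_split(split_dimension, shard_placement, shard_sizes=None):
    if not shard_placement:
        raise ValueError("at least one shard placement is required")
    if shard_sizes is not None and len(shard_sizes) != len(shard_placement):
        raise ValueError("shard_sizes must have the same length as shard_placement")
    return SingleSplitShardingSpec(split_dimension, list(shard_placement), shard_sizes)


def _mesh_shape(level) -> Tuple[int, ...]:
    # A placement string is a leaf; a list adds one mesh dimension whose
    # children must all have the same shape.
    if isinstance(level, str):
        return ()
    child_shapes = {_mesh_shape(child) for child in level}
    if len(child_shapes) != 1:
        raise ValueError("mesh must be non-empty and rectangular")
    return (len(level),) + child_shapes.pop()


def mesh(shard_placement):
    return MeshShardingSpec(shard_placement, _mesh_shape(shard_placement))


spec = mesh([
    ["rank:0/cuda:0", "rank:0/cuda:1", "rank:0/cuda:2", "rank:0/cuda:3"],
    ["rank:1/cuda:4", "rank:1/cuda:5", "rank:1/cuda:6", "rank:1/cuda:7"],
])
print(spec.shape)  # (2, 4)
```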

Composite ShardingSpec

SingleSplitShardingSpec and MeshShardingSpec provide specifications for sharding individual Tensors. These specifications can also be applied to an entire module or module hierarchy by applying the specification recursively to all parameters of the module. However, there will be cases where users would like to specify different placement/sharding specs for different parameters/submodules of a module. This is where composite sharding comes into play.

class CompositeShardingSpec(PlacementSpec):
    # Provides a mapping from "parameter name" to the appropriate
    # ShardingSpec/PlacementSpec.
    param_placement: Dict[str, PlacementSpec]

# Helper API
# Accepts **kwargs of parameter name to PlacementSpec
torch.distributed.sharding_spec.composite(**kwargs)

# Example 1: Shard a module by placing parameters on different devices

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(20, 20)
        self.net2 = nn.Linear(20, 20)
        self.net3 = nn.Linear(20, 20)
        
torch.distributed.sharding_spec.composite(
    net1="cuda:0", net2="cuda:1", net3="cuda:0",
)

# Example 2: Apply sharding spec to submodules

import torch.nn as nn
from torch.distributed.sharding_spec import composite, single_split

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(20, 20)
        self.net2 = nn.Linear(20, 20)
        self.net3 = nn.Linear(20, 20)
        
# Spec for net1 and net2
composite_spec = composite(weight=single_split(0, ["cuda:0", "cuda:1"]))

# Spec for entire module
composite(
    net1=composite_spec, 
    net2=composite_spec, 
    net3=single_split(0, ["cuda:2", "cuda:3"])
)
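One way to resolve a composite spec against a module's parameters is a most-specific-prefix lookup: walk each dotted parameter name through the nested mapping until a leaf spec is reached, and let that leaf apply to everything beneath it. The sketch below models composite specs as plain dicts and leaf specs as placeholder strings. resolve_specs is hypothetical, and leaving parameters with no matching entry unsharded (None) is an assumption the RFC doesn't spell out.

```python
from typing import Dict, List, Optional


def composite(**kwargs) -> Dict[str, object]:
    # Hypothetical stand-in: a composite spec is modelled as a plain dict.
    return dict(kwargs)


def resolve_specs(param_names: List[str],
                  spec: Dict[str, object]) -> Dict[str, Optional[object]]:
    """Map each fully qualified parameter name to the spec that applies to it."""
    resolved = {}
    for name in param_names:
        current: Optional[object] = spec
        for part in name.split("."):
            if not isinstance(current, dict):
                break  # a leaf spec applies to everything below its submodule
            current = current.get(part)  # None if this name has no entry
        # A dict left over means the spec only covers deeper names: unsharded.
        resolved[name] = None if isinstance(current, dict) else current
    return resolved


# Example 2 from above, with placeholder strings standing in for single_split(...).
linear_spec = composite(weight="single_split(0, [cuda:0, cuda:1])")
spec = composite(net1=linear_spec, net2=linear_spec,
                 net3="single_split(0, [cuda:2, cuda:3])")
names = [f"net{i}.{p}" for i in (1, 2, 3) for p in ("weight", "bias")]
for name, leaf in resolve_specs(names, spec).items():
    print(name, "->", leaf)
```

With this rule, net1.bias and net2.bias resolve to None because the per-Linear composite only names weight, while both parameters of net3 inherit the single_split spec given for the whole submodule.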

ShardedTensor

ShardedTensor is an abstraction that simply describes how a Tensor is sharded across multiple devices. ShardedTensor is not designed to behave like a Tensor and support all torch operations; instead, it exposes the sharding metadata to users so that they can manipulate and work with the shards themselves. Note that this puts a burden on the user in terms of deciding how sharded computations should work, but it is important for flexibility for power users who want tight control over sharded computations. The section below on "Sharding with torch.fx" goes over additional details of how we can automate some of the sharded computations using torch.fx for PyTorch users who do not want to worry about sharded computations.

ShardedTensor is initialized and used in an SPMD-like fashion: each node has an instance of a ShardedTensor which holds the local shard as well as the global information for the entire Tensor (ex: all shards and their remote devices).

Creating a ShardedTensor

# Similar to torch.empty, if needed we can have additional creation ops like 
# torch.ones/torch.zeros etc.
# This is done in SPMD fashion and needs to be called on all ranks. Each rank 
# will instantiate its local shard based on the ShardingSpec given.
torch.distributed.sharded_tensor.empty(
    *size, sharding_spec: ShardingSpec, dtype=None,
    requires_grad=False, pin_memory=False, names=None) -> ShardedTensor

Shard torch.Tensor

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        
model = MyModel()

# SPMD: called on all ranks. Rank 0 is used as the authoritative source of the
# data, and the appropriate shard is sent to each rank.
# model.fc1.weight would be replaced with a ShardedTensor on each rank.
sharding_spec = torch.distributed.sharding_spec.single_split(
    0, ["rank:0/cuda:0", "rank:1/cuda:1"])
torch.distributed.shard_tensor(model.fc1.weight, sharding_spec)
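The effect of this call on each rank can be simulated in plain Python, assuming a spec equivalent to single_split(0, [...]) with one shard per rank and torch.chunk-style sizing. Each simulated rank ends up with its own rows plus the same global shard layout, matching the SPMD model described above. shard_rows_from_rank0 is a hypothetical helper, and the actual data movement is only indicated in a comment.

```python
from typing import Dict, List

Matrix = List[List[int]]


def shard_rows_from_rank0(full: Matrix, world_size: int) -> List[Dict[str, object]]:
    """Simulate what each rank holds after rank 0 shards `full` along dim 0."""
    chunk = (len(full) + world_size - 1) // world_size  # ceil division
    # Global metadata, identical on every rank: (offset, length) per shard.
    shards = [
        (r * chunk, max(0, min(chunk, len(full) - r * chunk)))
        for r in range(world_size)
    ]
    # In a real run rank 0 would send rows [offset, offset + length) to rank r,
    # e.g. with torch.distributed.scatter; here each rank is a list entry.
    return [
        {"local_shard": full[offset:offset + length], "shards": shards}
        for offset, length in shards
    ]


# model.fc1.weight of nn.Linear(10, 10) is 10 x 10; shard it across 4 ranks.
weight = [[10 * r + c for c in range(10)] for r in range(10)]
per_rank = shard_rows_from_rank0(weight, world_size=4)
print([len(state["local_shard"]) for state in per_rank])  # [3, 3, 3, 1]
```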

Resharding ShardedTensor

# Needs to be called in SPMD fashion on all ranks. Reshards the ShardedTensor
# on which this method is called.
sharded_tensor = torch.distributed.sharded_tensor.empty(...)

sharded_tensor.reshard(
    sharding_spec: ShardingSpec
) -> ShardedTensor
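The RFC leaves the resharding algorithm open. One simple way to plan it, assuming both the old and the new spec are single splits along the same dimension, is to intersect the ranges each shard covers: every non-empty intersection becomes one transfer (a send/recv pair between the owning ranks, or a local copy when the owner is unchanged). reshard_plan is a hypothetical helper that only computes the plan.

```python
from typing import List, Tuple

Range = Tuple[int, int]  # (offset, length) along the split dimension


def reshard_plan(old: List[Range], new: List[Range]) -> List[Tuple[int, int, int, int]]:
    """Return (new_shard, old_shard, start, end) transfers for a 1-D reshard."""
    plan = []
    for n, (n_off, n_len) in enumerate(new):
        for o, (o_off, o_len) in enumerate(old):
            start = max(n_off, o_off)
            end = min(n_off + n_len, o_off + o_len)
            if start < end:  # new shard n needs [start, end) from old shard o
                plan.append((n, o, start, end))
    return plan


# 10 rows: 2 even shards resharded into 3 uneven ones.
print(reshard_plan([(0, 5), (5, 5)], [(0, 4), (4, 3), (7, 3)]))
# [(0, 0, 0, 4), (1, 0, 4, 5), (1, 1, 5, 7), (2, 1, 7, 10)]
```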

Saving/Loading ShardedTensor

ShardedTensors can be registered as parameters of an nn.Module such that module.state_dict() has an entry of the form <param_name>: ShardedTensor. Using the sharding metadata, users can then save and load sharded tensors however they wish when checkpointing.

model = my_sharded_model()

# On all ranks in SPMD fashion:

state_dict = model.state_dict()
# Save the state_dict on all ranks.
my_custom_save_model(state_dict, storage)

# Load the model from a checkpoint
state_dict = custom_load_state_dict(storage)
model = custom_init_model(state_dict)

# Now, if we'd like to modify sharding before training:
model.sharded_tensor1.reshard(sharding_spec)

# Using torch.save/torch.load:

model = my_sharded_model()

# On all ranks in SPMD fashion:

state_dict = model.state_dict()
# Save the state_dict on all ranks.
torch.save(state_dict, file)

# Load the model from a checkpoint
model.load_state_dict(torch.load(file))

The state_dict for a ShardedTensor would consist only of its local shard and the sharding metadata (basically the contents of the local ShardedTensor object). If users would like to collect everything on a single node and save the model in a single file, we could have dedicated APIs for that:

# Returns a consolidated state_dict on rank 0 after collecting all shards of 
# all ShardedTensors.
torch.distributed.collect_state_dict(model.state_dict()) -> state_dict

# The state dict can then be loaded on a single rank: non-sharded weights can
# be replicated, and a custom sharding spec can be applied to Tensors to shard
# them using "torch.distributed.shard_tensor".
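On rank 0, collect_state_dict would need to stitch each ShardedTensor's gathered shards back together using their offsets. A minimal sketch for a tensor sharded along dimension 0, assuming the gather itself (for example with torch.distributed.gather_object) has already happened; consolidate_rows is hypothetical and represents rows as plain lists.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def consolidate_rows(shards: List[Tuple[int, Matrix]]) -> Matrix:
    """Stitch (row_offset, rows) shards of a dim-0-sharded tensor back together."""
    total_rows = max(offset + len(rows) for offset, rows in shards)
    full: list = [None] * total_rows
    for offset, rows in shards:  # shards may arrive in any order
        full[offset:offset + len(rows)] = rows
    if any(row is None for row in full):
        raise ValueError("the gathered shards do not cover the whole tensor")
    return full


# Shards as rank 0 might receive them, out of rank order.
gathered = [(2, [[5, 5], [6, 6]]), (0, [[1, 1], [2, 2]])]
print(consolidate_rows(gathered))  # [[1, 1], [2, 2], [5, 5], [6, 6]]
```

Checking for gaps matters here: a missing rank would otherwise silently produce a checkpoint with holes in it.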

Retrieving ShardedTensor metadata

We could support certain native ops on ShardedTensors, however it’s probably more important to have APIs to expose the sharding information to the end user. With this information, users can implement their own custom logic to deal with sharded tensors using collectives.

# Retrieves the Tensor representing the local shard of the ShardedTensor
sharded_tensor.local_shard() -> Tensor

# Retrieves the sharding metadata for the Tensor, which outlines how the 
# Tensor is sharded across all devices.
sharded_tensor.sharding_metadata() -> ShardedTensorMetadata

from dataclasses import dataclass
from typing import List

# Metadata for a single shard.
@dataclass
class ShardMetadata:
    # Size of each dimension of the tensor representing this shard.
    dims: List[int]
    # Device that this shard is placed on (in RemoteDevice format: https://github.com/pytorch/pytorch/issues/46554).
    device: str
    # Starting offsets for each dimension of this shard in the
    # ShardedTensor. Should be the same size as self.dims. self.offsets
    # combined with self.dims accurately identifies which partition of the
    # Tensor this shard represents.
    offsets: List[int]

@dataclass
class ShardedTensorMetadata:
    shards: List[ShardMetadata]
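As an example of the custom logic this metadata enables, the sketch below finds which shard, and which position inside it, holds a given element of the global tensor. ShardMetadata mirrors the fields above as a dataclass; locate is a hypothetical helper, and the shard list plays the role of ShardedTensorMetadata.shards.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ShardMetadata:
    dims: List[int]     # size of this shard in each dimension
    device: str         # RemoteDevice string, e.g. "rank:0/cuda:0"
    offsets: List[int]  # where this shard starts in the global tensor


def locate(shards: List[ShardMetadata],
           index: List[int]) -> Optional[Tuple[str, List[int]]]:
    """Return (device, local index) for the shard holding a global index."""
    for shard in shards:
        inside = all(off <= i < off + size
                     for i, off, size in zip(index, shard.offsets, shard.dims))
        if inside:
            return shard.device, [i - off for i, off in zip(index, shard.offsets)]
    return None  # index is out of bounds


# Metadata for the [4 x 8] tensor sharded over the [2 x 4] mesh shown earlier.
shards = [
    ShardMetadata(dims=[2, 2], device=f"rank:{r}/cuda:{4 * r + c}",
                  offsets=[2 * r, 2 * c])
    for r in range(2) for c in range(4)
]
print(locate(shards, [3, 5]))  # ('rank:1/cuda:6', [1, 1])
```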

Sharding with torch.fx

Virtual Device

A key challenge of training large models in PyTorch today is that modules must first be allocated on some device (typically the CPU) before they are moved to the appropriate devices for model parallelism. First, this is wasteful, since memory has to be allocated on the CPU and all of that data then transferred to the target devices. Second, for very large models there might not even be enough CPU memory to host the entire model.

To address this limitation, the key idea is to introduce a virtual/meta device concept in PyTorch: a device that can be used to initialize a PyTorch module without allocating any memory. The module can then be moved to the intended device and initialized accordingly. One idea is to introduce a new nn.Module contract via two parameters (another option under discussion is a context manager for the device):

  • reset_parameters: If this is False (defaults to True), the nn.Module should not initialize its parameters, but can allocate them.
  • device: If this device is specified, all parameters for the module need to be placed on this device. The device can be "meta", which creates all Tensors as MetaTensors (which don't allocate any memory).

As a result, by specifying device="meta" and reset_parameters=False, we can initialize nn.Modules without allocating memory for any tensors or initializing them.
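The contract can be illustrated in plain Python so that it runs without PyTorch. ToyLinear, its materialize method, and the uniform init bound are hypothetical: a "meta" instance records only the weight's shape, and storage is allocated and initialized later on the real device.

```python
import random
from typing import List, Optional


class ToyLinear:
    """Plain-Python model of the proposed device/reset_parameters contract."""

    def __init__(self, in_features: int, out_features: int,
                 device: str = "cpu", reset_parameters: bool = True):
        self.shape = (out_features, in_features)
        self.device = device
        self.weight: Optional[List[List[float]]] = None
        if device == "meta":
            return  # only the shape is recorded; nothing is allocated
        # Allocate storage (zeros stand in for uninitialized memory).
        self.weight = [[0.0] * in_features for _ in range(out_features)]
        if reset_parameters:
            self.reset_parameters()

    def reset_parameters(self) -> None:
        # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        out_features, in_features = self.shape
        bound = 1.0 / in_features ** 0.5
        self.weight = [[random.uniform(-bound, bound) for _ in range(in_features)]
                       for _ in range(out_features)]

    def materialize(self, device: str) -> None:
        # Leave "meta": pick the real device, then allocate and initialize there.
        self.device = device
        self.reset_parameters()


layer = ToyLinear(4, 3, device="meta", reset_parameters=False)
print(layer.weight, layer.shape)  # None (3, 4)
layer.materialize("rank:0/cuda:0")
```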

torch.fx transformations

Once we have a module initialized on a "meta" device, we can apply one of the sharding specifications mentioned above to shard the model and produce ShardedTensors. To ensure users don't have to worry about dealing with ShardedTensors and how to write their own sharded computations, the plan is to use torch.fx to trace the module and inject appropriate communication operations to transparently shard PyTorch modules. A very simple prototype demonstrating this idea for nn.Linear can be found here: https://gist.github.com/pritamdamania87/368d249abe835cd0a2eac201efbe373a
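To illustrate the kind of rewrite such a pass would produce, here is a framework-free sketch of a linear layer whose weight is sharded along dimension 0 (output features). Each simulated rank computes its slice of the output from the replicated input, and an all-gather concatenates the slices; in a real trace that step would be a collective such as torch.distributed.all_gather. This is not taken from the linked prototype and only assumes the row-wise layout; both function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def linear(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """Unsharded reference: y = x @ weight.T + bias."""
    return [[sum(a * w for a, w in zip(row, w_row)) + b
             for w_row, b in zip(weight, bias)] for row in x]


def rowwise_sharded_linear(x: Matrix, weight: Matrix, bias: List[float],
                           world_size: int) -> Matrix:
    chunk = (len(weight) + world_size - 1) // world_size  # output features per rank
    partial_outputs = []
    for rank in range(world_size):
        # Each rank holds rows [rank*chunk, (rank+1)*chunk) of the weight
        # (dim 0 = output features) plus the matching bias slice, and
        # computes only those output columns.
        rows = slice(rank * chunk, (rank + 1) * chunk)
        partial_outputs.append(linear(x, weight[rows], bias[rows]))
    # Communication step the tracer would inject: an all-gather that
    # concatenates every rank's output columns, row by row.
    return [[v for part in partial_outputs for v in part[i]] for i in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.5, -1.0, 2.0]
print(rowwise_sharded_linear(x, weight, bias, world_size=2))
# [[1.5, 1.0, 5.0], [3.5, 3.0, 9.0]]
```

Splitting along dimension 1 (input features) instead would leave each rank with a partial sum of every output, so the injected collective would be an all-reduce rather than an all-gather.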

As a result, the long-term vision is to provide an API like this, which can automatically shard a model for PyTorch users without them having to worry about how to deal with sharded computations:

torch.distributed.shard_model(model: nn.Module, spec: PlacementSpec)

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu
