
[RFC] Use shard_as to improve sharding and avoid OOM #8883

@tengyifei

Description


🚀 Use shard_as to improve sharding and avoid OOM

Summary

2D sharding propagation is harder than 1D sharding propagation because the
propagated shardings can end up incompatible with one another. This problem is
worse inside a scan / XLA While op, and the shard_as GSPMD feature seems to help.

Motivation

This proposal is primarily about improving the sharding propagation of
torch_xla.experimental.scan.

When the decoder layer is wrapped in an XLA While op through
torch_xla.experimental.scan, Llama 3 8B trains fine with a global batch size of
16 on a v6e-8 TPU, but we still get an OOM when scaling to Llama 3.1 405B on
v6e-256 with 2D (FSDP + TP) sharding.

By inspecting the memory profiles, we can infer the following:

  • The OOM occurs during the scan in the backward pass (judging from the
    referenced body computation)
  • The OOM occurs because the compiler emits a convolution (convolution.171)
    whose output shape is [1, 4K, 16K].
  • That output tensor is then all-reduced over the FSDP axis (judging from the
    replica groups), keeping the shape unchanged.
  • The all-reduced tensor gets written to a [126, 4K, 16K] stacked output
    tensor. This tensor is too large to materialize on a single chip, so
    compilation fails. Note that 126 is the number of layers in Llama 3.1 405B.

We deduced that the convolution before the all-reduce is computing the gradient
for the weight tensor of the
o_proj operation in self attention:

    # Code snippet of Llama self attention
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)  # <--- here
    return attn_output

During the backward pass, we compute grad_o_proj, which is a matmul of a
2D-sharded input with the 2D-sharded attn_output. Based on the profile, this
gradient tensor is only 1D sharded: its per-chip shape is [1, 4K, 16K], where
16K is the size of the embedding dim. We expect a per-chip shape of [1, 4K, 4K].
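
To make the discussion concrete, here is a rough, standalone sketch of the
logical computation behind grad_o_proj (plain PyTorch, no sharding; the shapes
are small illustrative stand-ins, not the real Llama 3.1 405B dimensions):

    # Sketch only: the weight gradient of o_proj is a matmul that contracts
    # over the batch and sequence dims of two activation-shaped tensors.
    import torch

    bsz, q_len, hidden = 2, 8, 32                   # illustrative stand-in sizes
    attn_output = torch.randn(bsz, q_len, hidden)   # input to o_proj
    grad_out = torch.randn(bsz, q_len, hidden)      # gradient flowing into o_proj's output

    # grad for o_proj.weight ([out_features, in_features]): contract over (bsz, q_len).
    grad_w = torch.einsum('bso,bsi->oi', grad_out, attn_output)
    print(grad_w.shape)  # torch.Size([32, 32]), i.e. [hidden, hidden]

In the real run both operands are 2D sharded, and the question is what sharding
GSPMD propagates to grad_w, which is the subject of the next section.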

Breakdown of the problem

When GSPMD propagates 2D sharding annotations over a matmul, and the contraction
dim has matching sharding annotations:

$$A[M_X, N_Y] \cdot B[N_Y , M_X] = C[M_?, M_?]$$

(using the Scaling Book sharding notation, where a subscript names the mesh
axis a dimension is sharded over)

Dimension $N$ is contracted away, and the mesh axis $Y$ disappears with it (it
shows up as the all-reduce over that axis). Both remaining dimensions of $C$
carry the $X$ annotation, but a single mesh axis cannot shard two dimensions of
the same tensor. Based on my understanding of the GSPMD paper, $C$ will
therefore only be 1D sharded, barring influence from any other operations.
Since $C$ is a gradient tensor, and scan outputs a stacked array of the
gradients of all 126 Llama 3.1 405B layers during the backward pass, this 1D
sharding goes on to "infect" the stacked array with a leading dim of size 126,
resulting in an array of shape [126, 4K, 16K], which is too large to fit in HBM.
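
The following standalone sketch (not from the PR) reproduces this pattern with
PyTorch/XLA's existing SPMD APIs so the propagated output sharding can be
inspected. It assumes an SPMD-enabled runtime with an even number of devices;
the shapes and the mesh split are illustrative:

    # Illustrative only: mark A and B with the ('x', 'y') / ('y', 'x') shardings from
    # the equation above and let GSPMD decide the sharding of the matmul output.
    import numpy as np
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()
    n = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(n), (n // 2, 2), ('x', 'y'))  # assumes n is even

    device = xm.xla_device()
    a = torch.randn(128, 64, device=device)   # A[M, N]
    b = torch.randn(64, 128, device=device)   # B[N, M]
    xs.mark_sharding(a, mesh, ('x', 'y'))     # A[M_X, N_Y]
    xs.mark_sharding(b, mesh, ('y', 'x'))     # B[N_Y, M_X]

    c = a @ b       # C's sharding is left entirely to propagation
    xm.mark_step()  # compile and run, which performs sharding propagation
    # Internal helper, used here only to inspect the propagated sharding of C.
    # Both output dims "want" axis x, but one mesh axis can shard only one dim,
    # so we expect C to come back only 1D sharded.
    print(torch_xla._XLAC._get_xla_sharding_spec(c))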

Pitch

I followed the HLO spec and the JAX implementation to add a shard_as function
to PyTorch/XLA and use it in scan during the backward pass. PR here. shard_as
ensures that its inputs have the same sharding after GSPMD sharding
propagation. Specifically, instead of scanning over the decoder layer's
backward pass during the backward of scan, we scan over a wrapper that adds
additional sharding constraints so that the gradients are sharded the same way
as their corresponding inputs:

    # This backward-pass wrapper calls the original backward pass of a layer, then uses `shard_as` to ensure
    # that grad_carry is sharded the same way as the initial carry `init`, and that grad_x (the gradient for
    # the input) is sharded the same way as the first element of the stacked input array `xs`.
    def _backward_shard_alike(carry, x, backward, init, xs):
      grad_carry, grad_x = backward(carry, x)
      # Propagate sharding between forward inputs and backward outputs.
      _, grad_carry = shard_as(init, grad_carry)
      _, grad_x = shard_as(tree_map(lambda v: v[0], xs), grad_x)
      return grad_carry, grad_x

The PR also has a unit test that checks the result of sharding propagation and
fails if we remove the shard_as usage from scan.

Alternatives

Rather than using shard_as, we could expose a keyword argument on scan that
takes in the intended sharding annotation of all the weights during the backward
pass of a layer. Potentially, the user may specify that the gradient for the
o_proj weight should be sharded a certain way. There are some drawbacks:

  • Since scan lowers the combine function into a functional graph using
    AOTAutograd, we can't tell the tensors apart. We don't even know which
    variable name corresponds to a specific output of the FX graph extracted by
    AOTAutograd.
  • SPMD and scan are orthogonal concerns and it's a code smell to expose both
    APIs in one function.

In contrast, shard_as doesn't require telling tensors apart. It simply
constrains the sharding of the N gradient tensors to match that of the N
corresponding input tensors.

Additional context

GSPMD sharding merging behavior

JAX shard_alike
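
For context, a minimal sketch of the JAX analog linked above
(jax.experimental.shard_alike); the single-axis mesh and array size are
illustrative and assume the array divides evenly across the local devices:

    # Inside jit, shard_alike constrains two values to end up with the same sharding
    # after propagation -- the same semantics this RFC proposes for shard_as.
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental.shard_alike import shard_alike

    mesh = Mesh(jax.devices(), ('x',))
    inp = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P('x')))

    @jax.jit
    def f(x):
      y = jnp.ones(x.shape)     # y starts with no sharding constraint of its own
      x, y = shard_alike(x, y)  # constrain x and y to share the same sharding
      return y

    print(f(inp).sharding)      # expect the same NamedSharding over 'x' as `inp`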

Labels

distributed (SPMD and other distributed things), enhancement (New feature or request)
