Skip to content

Fsdp2 fully_shard embedding and norm #4015

Merged
SunMarc merged 4 commits into
mainfrom
fsdp2-smart-root-reshard
Apr 24, 2026
Merged

Fsdp2 fully_shard embedding and norm #4015
SunMarc merged 4 commits into
mainfrom
fsdp2-smart-root-reshard

Conversation

@SunMarc

@SunMarc SunMarc commented Apr 22, 2026

Copy link
Copy Markdown
Member

What does this PR do?

This PR updates how we fully shard the model. We create two new units:

  • one with the embedding (with reshard_after_forward=value passed by the user which is usually True)
  • one with the final norm + output_embedding (with reshard_after_forward=False hardcoded because we need to gather again for the backward right after)

SunMarc added 2 commits April 22, 2026 14:56
…rward

The FSDP2 path in Accelerate currently passes the plugin's
`reshard_after_forward` (default `True`) uniformly to every
`fully_shard()` call, including the final call on the whole model that
creates the root unit. PyTorch's default (`None`) is smarter: it
resolves to `True` for non-root units and `False` for the root, which
avoids an unhideable pre-backward all-gather of the root's leftover
params (embeddings, final norm, lm_head) that would otherwise be
resharded and immediately re-gathered.

Stripping the kwarg from the root call lets PyTorch apply its heuristic.
Peak memory is unchanged on typical workloads (dominated by backward,
where the root is gathered in either case); we save one all-gather per
step on the root unit. This also matches torchtitan's default wrapping,
which similarly lets the root call fall back to `None`.
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec

Copy link
Copy Markdown
Member

lgtm! for context, it was flagged in huggingface/trl#5575 where we measured a significant slowdown because of reshard_after_forward=True being true by default

@SunMarc SunMarc changed the title Fsdp2 smart root reshard Fsdp2 fully_shard embedding and norm Apr 23, 2026
@SunMarc SunMarc merged commit 60b5c25 into main Apr 24, 2026
27 of 29 checks passed
@SunMarc SunMarc deleted the fsdp2-smart-root-reshard branch April 24, 2026 14:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants