Fix OOM regression for FSDP2 + cpu_ram_efficient_loading on large models#45649
Fix OOM regression for FSDP2 + cpu_ram_efficient_loading on large models#45649AmineDiro wants to merge 4 commits into
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
cc @Cyrilvallez I think |
albertvillanova
left a comment
There was a problem hiding this comment.
Thanks a lot for the clear diagnosis and the fix: Skip CPU param materialization on non-rank-0 FSDP ranks to avoid OOM
The OOM regression in #45050 is real: zeros_like forces an immediate physical-memory commit (page fault on every zero write), whereas empty_like relies on overcommit/lazy allocation. Note this was already commented by @ArthurZucker: https://github.com/huggingface/transformers/pull/45050/changes#r3029107360
the reason I don't want this is because its costly!
Let me trace through the full flow after the change to confirm:
- On non-rank-0 FSDP ranks:
- Parameters stay on meta device: zero physical memory committed
- Buffers (both persistent and non-persistent) get real CPU
zeros_likeplaceholders
- Then
_initialize_missing_keys(PR #44473) marks state-dict parameters (now meta tensors) as_is_hf_initialized = True.initialize_weights()then runs: for RotaryEmbedding, inv_freq and original_inv_freq are non-persistent buffers, so they are not in state_dict(), not marked, and _init_weights correctly computes and copies their values into the real CPU zero tensors - Accelerate's
fsdp2_prepare_modelthen:- Saves non-persistent buffers (now correctly initialized by _init_weights) from each rank
- Moves the model to meta; parameters that were already on meta: no-op
- Applies fully_shard
- fsdp2_load_full_state_dict broadcasts from rank-0 into all ranks: parameters receive correct values
- Restores non-persistent buffers from each rank's saved copy
The original NaN bug is still fixed: parameters that _init_weights skips (marked as initialized) are subsequently overwritten by the broadcast with rank-0's values. The difference from #45050 is that we never pay the cost of materializing them on non-rank-0 in the first place.
The fix is correct, targeted, and eliminates the OOM without reintroducing the NaN regression (I have confirmed this). 🤗
|
@albertvillanova @AmineDiro I just pushed what is to me the correct fix - basically only move non-persistent buffers. Could you double-check? I'm not 100% familiar of how fsdp2 works |
|
Thanks for addressing this as well, @Cyrilvallez. May I recommend to run one test with FSDP2 + cpu_ram_efficient_loading on a model that has at least one persistent buffer? Maybe something like DeepSeek V3 or creating a toy model with If the forward pass succeeds and the persistent buffer has the correct values from rank-0, the fix is confirmed. If it crashes with AttributeError or produces wrong values, the |
|
@Cyrilvallez I think 5623608 should also work , this was a regression so I just fixed it back to the previous non zeroed memory. But it should work fine 👍🏼 |
|
@albertvillanova Did you have one test in mind by any chance? I don't have an env setup with fsdp rn, so would appreciate it if you can quickly try it out by any chance 🙏🤗 |
What does this PR do?
PR #45050 replaces
torch.empty_likewithtorch.zeros_likein_move_missing_keys_from_meta_to_device. While this fixes a real issue (NaN garbage in uninitialized memory), it forces a physical-memory commit of the entire model on every non-rank-0 FSDP rank.With 8 ranks per node loading a 30B model, peak cpu mem jumps from ~60 GB to ~480 GB :/
The regression was identified by bisecting transformers commits between 2026-04-10 (working) and 2026-04-22 (failing) using a 2-node FSDP2 control config:
a001f34439(pre-#45050)ff49f7c4cb(PR #45050)Test config:
Qwen/Qwen3-30B-A3B, FSDP2, 2 nodes × 8 H100, DP=16, sdpa, max_steps=5,fsdp_cpu_ram_efficient_loading=true.The placeholder values on non-rank-0 ranks for state-dict params are immediately overwritte by
fsdp2_load_full_state_dictduring accelerate's FSDP2 prepare.acceleratemoves the entire model tometadevice before sharding inaccelerate.utils.fsdp_utils.fsdp2_prepare_modelSo allocating CPU placeholders for parameters on non-rank-0 ranks is unnecessary work. The parameters can stay on meta. Btw, from what I can understand buffers (RoPE caches, attention masks, etc.) are per-rank and not part of the broadcast, so they still need real allocations.
Fixes # (issue)
Code Agent Policy
Before submitting
Who can review?
@albertvillanova @ArthurZucker