[DSD] Fix to remove non_persistent buffer in distributed state dict#125337
fegin wants to merge 3 commits into gh/fegin/234/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125337

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit bab5c42 with merge base 746da87:

NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
| "dont_save_me", torch.rand(100, device="cuda"), persistent=False | ||
| ) | ||
| ddp_model = DDP(copy.deepcopy(model)) | ||
| set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) |
IIUC, set_model_state_dict(module, get_model_state_dict(module)) should be a no-op. Is this just testing that set_model_state_dict() does not error?
Yes, this just ensures that set_model_state_dict() does not error when there is a non-persistent buffer. The actual value comparison against the single-rank model is done below.
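For context, the mismatch behind the error is easy to see without any distributed wrapping. A minimal sketch (the module and buffer names here are made up for illustration):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))
        # Non-persistent buffers are intentionally excluded from state_dict().
        self.register_buffer("dont_save_me", torch.zeros(4), persistent=False)

m = M()
print("dont_save_me" in m.state_dict())            # False
print("dont_save_me" in dict(m.named_buffers()))   # True
```

Because the state dict contains only persistent entries, any code path that walks named_buffers() and expects a matching state-dict key has to skip names in _non_persistent_buffers_set, otherwise a round trip like the one in the test above can fail on the missing key.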
for name, obj in chain(
    module.named_buffers(recurse=False), module.named_parameters(recurse=False)
):
    if name in module._non_persistent_buffers_set:
I might have missed some discussion. Could you remind me why we use named_buffers() rather than some logic that relies only on the keys in the state dict itself?
Relying on the state dict keys would trigger an all_gather for FSDP. Since many users still use FSDP rather than FSDP2, we have to ensure there is no performance penalty for this API.
I see. The issue is that both full and sharded state dict all-gather?
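For readers skimming the thread, here is a rough sketch of the pattern the diff above follows: walk each submodule's own parameters and buffers and skip the non-persistent ones, so the key names can be derived without materializing the state dict. The helper name and the FQN construction below are illustrative, not the actual DSD implementation:

```python
from itertools import chain

import torch.nn as nn

def iter_persistent_model_state(model: nn.Module):
    # Walk every submodule once; recurse=False keeps names local to each module.
    for prefix, module in model.named_modules():
        for name, obj in chain(
            module.named_buffers(recurse=False), module.named_parameters(recurse=False)
        ):
            if name in module._non_persistent_buffers_set:
                # Non-persistent buffers never appear in state_dict(), so skip them.
                continue
            fqn = f"{prefix}.{name}" if prefix else name
            yield fqn, obj
```

A key-based approach would instead need model.state_dict(), which under FSDP's full or sharded state dict configurations all-gathers the shards; iterating the module tree avoids that cost.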
@pytorchbot merge -f "The failing tests are not related."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[DSD] Fix to remove non_persistent buffer in distributed state dict (pytorch#125337)

Summary:
Fixes pytorch#122792

state_dict includes only persistent buffers, while named_buffers() would include non_persistent buffers.

Pull Request resolved: pytorch#125337
Approved by: https://github.com/awgu
ghstack dependencies: pytorch#125333, pytorch#125501, pytorch#125334, pytorch#125335, pytorch#125336
[DSD] Fix to remove non_persistent buffer in distributed state dict (#125337) (#127219)

* [DSD] Fix to remove non_persistent buffer in distributed state dict (#125337)

Summary:
Fixes #122792

state_dict includes only persistent buffers, while named_buffers() would include non_persistent buffers.

Pull Request resolved: #125337
Approved by: https://github.com/awgu
ghstack dependencies: #125333, #125501, #125334, #125335, #125336

* lintrunner
* lint

---------

Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
Stack from ghstack (oldest at bottom):
Summary:
Fixes #122792
state_dict() includes only persistent buffers, while named_buffers() would also include non-persistent buffers.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC