Basic Validation for FSDP `state_dict` transformations of modules with persistent buffers #93396

speediedan wants to merge 4 commits into pytorch:master

Conversation
…tions of modules with persistent buffers failed with mixed precision enabled
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93396
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure as of commit 0fd6bc3: NEW FAILURES - the following jobs have failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
rohan-varma left a comment:
Looks great, thanks! Just some minor questions / comments after which we can merge
```python
        assert_fn(state_base, state_new)
```

```python
    def _compare_models(
        self, model, model_new, assert_fn, check_fp16=False, check_buffers=False
```
qq: is there an issue with setting check_buffers=True by default to enable it for all tests that use this helper?
Great idea! I was being a little overcautious in my attempt to minimize the changes I introduced. As long as we add a guard to the buffer comparison in _compare_models to ensure there are buffers to compare (many of the current test models don't have any), we can make check_buffers=True the default.
Note, I'm only checking persistent buffers given this is a state_dict inspection, but there is at least one line of code that checks for non-persistent buffers that won't be covered (maybe best for a different PR if we want to tackle that).
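For illustration, a minimal sketch of the guarded buffer comparison being discussed (the helper name and signature are hypothetical, not the PR's actual `_compare_models` code; plain `(name, value)` pairs stand in for `module.named_buffers()`):

```python
# Hypothetical sketch of a guarded buffer comparison: models without
# persistent buffers (like many of the existing test models) pass through
# untouched, so check_buffers=True can safely be the default.
def compare_buffers(named_buffers, named_buffers_new, assert_fn):
    buffers = dict(named_buffers)
    buffers_new = dict(named_buffers_new)
    if not buffers:  # guard: nothing to compare for buffer-less models
        return
    for name, buf in buffers.items():
        assert_fn(buf, buffers_new[name])

# Usage: two "models" with matching buffers, then a buffer-less model.
checked = []
compare_buffers(
    [("bn.running_mean", 0.0), ("bn.running_var", 1.0)],
    [("bn.running_mean", 0.0), ("bn.running_var", 1.0)],
    lambda a, b: checked.append(a == b),
)
compare_buffers([], [], lambda a, b: checked.append(False))  # guard short-circuits
```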
```python
        in the context of non-default mixed precision, different ``state_dict_type`` s and CPU offloading.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return  # not supported
```
no need to fix in this PR, but this sort of pattern is not ideal from a TTS (time-to-signal) perspective, i.e. we still spawn N processes, initialize NCCL, etc. only to return and tear everything down, paying unnecessary overhead.
Agreed! Note, this (anti)pattern is used in test_state_dict_with_manual_ac_wrapper and test_basic_save_and_load_state_dict as well but those similar TTS issues can be addressed in the context of the refactoring options @awgu outlines below (either of which seem useful and worth doing!).
One easy way to fix this anti-pattern is to consolidate the usage of StateDictType and StateDictConfig. This is worth another PR to fix. Thanks for pointing out the issue.
```python
            if mixed_precision
            else None
        )
        model_call = partial(
```
what's the need for partial if we're directly using this result after?
I am guessing this partial application is being saved for the additional model_new = model_call() below.
Yep, I was using it in model_new, but I can switch to passing the model function directly instead of using partial if you guys prefer that.
```python
        model_call = partial(
            self._get_multibuffer_nested_model,
            cpu_offload=cpu_offload,
            use_orig_params=False,
```
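As a rough illustration of why `partial` is handy here (`get_model` below is a stand-in for `self._get_multibuffer_nested_model`, not the actual test helper):

```python
from functools import partial

def get_model(cpu_offload=False, use_orig_params=False):
    # stand-in for the real model constructor used in the test
    return {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}

# Freeze the configuration once...
model_call = partial(get_model, cpu_offload=True, use_orig_params=False)

# ...then construct both models without repeating the kwargs.
model = model_call()
model_new = model_call()
```

Passing the constructor directly would work just as well; the partial merely avoids duplicating the keyword arguments at each call site.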
@awgu Is it worth testing use_orig_params=True path as well, i.e. is the checkpointing implementation different enough to justify the TTS increase?
I think we should test use_orig_params=True even if it doubles the unit test time for this added unit test since this is the code path we want to converge to going forward.
I think if we are to address the high time-to-signal (TTS) issue, we need to do it from the ground up rather than being cautious about testing useful configs:
1. Refactor existing unit tests to avoid redundancy.
2. Use subtests as much as possible while actually ensuring that failing subtests report the config that failed.

1 requires some effort from our side, which I was thinking would be justifiable as we invest in fully_shard. The latter part of 2 has not been done yet. The main tricky part I need to figure out is how to handle the `self.assertRaises...` cases vs. regular non-erroring cases.
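As a rough sketch of the subtest idea (using plain `unittest` rather than the FSDP test infrastructure, with illustrative config names), `subTest` tags any failure with the parameters that produced it, so the offending config shows up in the test output:

```python
import unittest

class TestStateDictConfigs(unittest.TestCase):
    def test_configs(self):
        for state_dict_type in ("state_dict", "sharded_state_dict", "local_state_dict"):
            for rank0_and_offload in (False, True):
                # A failing subtest is reported along with these keyword
                # args, identifying which config combination failed.
                with self.subTest(state_dict_type=state_dict_type,
                                  rank0_and_offload=rank0_and_offload):
                    if rank0_and_offload and state_dict_type != "state_dict":
                        continue  # unsupported combo: skip inside the loop
                    self.assertIsInstance(state_dict_type, str)  # placeholder check

result = unittest.TestResult()
unittest.defaultTestLoader.loadTestsFromTestCase(TestStateDictConfigs).run(result)
```

Skipping inside the loop (rather than returning early from the whole test) also avoids the spawn-then-bail TTS anti-pattern mentioned earlier in the thread.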
Excellent point @awgu. I've gone ahead and extended test_buffers_save_and_load_state_dict to test use_orig_params as recommended and will push that commit now.
```python
            nn.BatchNorm1d(10).cuda(),
            nn.Linear(10, 10, bias=False).cuda(),
        )
        return model
```
just for my understanding, where do the multi buffers come from? Is it the multiple BN units?
The "multi-buffers" just refers to the multiple buffers in the BN module (e.g. 'running_var', 'running_mean', etc.). I thought it was worth emphasizing/testing particularly because the bug here does not rear its ugly head for modules with a single buffer.
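For reference, the buffers in question can be listed directly; BatchNorm registers several persistent buffers, while a module like Linear registers none (CPU modules used here for simplicity, unlike the CUDA ones in the test):

```python
import torch.nn as nn

# BatchNorm1d tracks multiple buffers: running statistics plus a
# batch counter, which is exactly the multi-buffer case the bug hit.
bn = nn.BatchNorm1d(10)
print(sorted(name for name, _ in bn.named_buffers()))

# Linear has parameters (weight) but no buffers at all.
lin = nn.Linear(10, 10, bias=False)
print(list(lin.named_buffers()))
```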
```diff
             buffers, buffer_dtypes, fsdp_state.compute_device
         )
-        for buffers, clean_fqn in zip(buffers, buffer_clean_fqns):
+        for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
```
Oof, this was some nasty variable shadowing combined with an unfortunate typo. Thanks for the catch!
It was gnarly indeed! Its stochastic nature in the context I was testing ultimately helped narrow down the issue, but also made it initially all the more mysterious. Again, thank you to you both for your impressive contributions (both in volume and quality!); it's a real pleasure to get to collaborate with and learn from you.
awgu left a comment:
This looks good to me! Feel free to merge after pushing your most recent local changes.
Thanks for the thorough investigation and high-quality contribution :)
…ms` scenario. Make new `check_buffers` option in `_compare_models` default to `True`.
```python
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_buffers_save_and_load_state_dict(
        self,
        state_dict_type: StateDictType,
```
The type is str, not StateDictType. This is a legacy issue: multiple ways to declare state_dict_type exist in the same test.
Good catch! I've now updated the annotation for both this test and test_basic_save_and_load_state_dict.
```python
        in the context of non-default mixed precision, different ``state_dict_type`` s and CPU offloading.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return  # not supported
```
One easy way to fix this anti-pattern is to consolidate the usage of StateDictType and StateDictConfig. This is worth another PR to fix. Thanks for pointing out the issue.
My pleasure! That means a lot coming from you. I look forward to seeing how FSDP evolves and to contributing again at some point.
…ate_dict` and `test_buffers_save_and_load_state_dict` tests.
Looks like there is 1 (seemingly) unrelated test failure. If you guys consent, I'm happy to try the force merge. I most likely don't have permission to execute the force merge though, so if any of you deem it appropriate, please feel free to execute it on my behalf. Thanks!
Sounds good. I added the ciflow/trunk tag (which normally gets triggered when we try to merge). We can force merge the PR on your behalf after those additional tests finish running, if all failures are unrelated.
linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge) failure looks unrelated.
@pytorchbot merge -f "unrelated xla failures"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…n-dev-setup * origin: (898 commits, including "Basic Validation for FSDP `state_dict` transformations of modules with persistent buffers (pytorch#93396)", …)
Fixes #93391
Thank you to the PyTorch Distributed team for your invaluable contributions to the PyTorch ecosystem, your work is immensely impressive and inspiring!
As mentioned in #93391, in preparing the downstream package I maintain (finetuning-scheduler) to support PyTorch 2.0's version of FSDP, I noticed modules that include multiple persistent buffers were not having their state properly transformed during saving of `state_dict`s.

The issue was that the post-state_dict hook codepath shared by the `FULL_STATE_DICT` and `SHARDED_STATE_DICT` `_state_dict_type`s (`_common_unshard_post_state_dict_hook`) was inadvertently referencing a local variable (`buffer`) that was used in a prior transformation, instead of the `buffers` variable that should have been referenced in the iteration context:

pytorch/torch/distributed/fsdp/_state_dict_utils.py
Lines 251 to 253 in 332d55d

In this case, modules with a single persistent buffer or without mixed precision enabled would be unaffected. With multiple buffers and mixed precision enabled, however, the issue may appear stochastically in proportion to the ratio of persistent buffers that have compatible dimensions (since the value of the last buffer visited in the `buffer_names` `Set` is copied to all buffers, and the `Set` iteration order will of course vary).

To both address this issue and enhance coverage to avoid similar issues, this PR fixes the aforementioned typo and adds an additional set of basic tests that validate `state_dict` saving and loading for modules with persistent buffers in various contexts.

I found that adding another model, along with additional buffer-specific logic, to adapt `test_basic_save_and_load_state_dict` for the purposes of this coverage seemed to increase the complexity of that test to an undesirable degree. Instead of adding that complexity to the existing test, I've added a new test `test_buffers_save_and_load_state_dict` that does basic validation of `state_dict` saving and loading with mixed precision, `state_dict_type` and CPU offloading parameterization. Certainly let me know if you'd prefer I extend the logic of/add the persistent buffers model into the existing basic `state_dict` test; I'm happy to do so, I just thought it was cleaner this way. Also, I thought doubling the number of tests with a `use_orig_params` parameterization, or by testing additional non-default buffer mixed precision data types, was computationally imprudent, but let me know if you'd like me to add those tests as well.

The only other notable test change is that I've refactored `TestFSDPStateDict._compare_models` to accommodate both `buffers` and `parameters` comparisons without code duplication.

Thanks again to the PyTorch Distributed team for your exceptional contributions. I've got some more to do adapting my package for 2.0's FSDP, but it's been a delight so far thanks to your superlative work!
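The stale-variable effect described above can be reproduced in miniature without FSDP at all (the names below are illustrative stand-ins, not the actual `_state_dict_utils.py` code):

```python
# Stale-variable repro: an earlier loop leaves `buffer` bound to its last
# item; the later (buggy) loop binds `buffers` as its loop variable but the
# body reads the stale `buffer`, so every entry receives the last value.
values = {"running_mean": 1.0, "running_var": 2.0}

for buffer in values.values():      # e.g. an earlier mixed-precision cast pass
    pass                            # leaves buffer == 2.0 afterwards

state = {}
for buffers, name in zip(values.values(), values.keys()):  # buggy loop var name
    state[name] = buffer            # stale reference: should use the loop variable

print(state)  # {'running_mean': 2.0, 'running_var': 2.0}

# The one-word fix: rename the loop variable so the body reads fresh values.
fixed = {}
for buffer, name in zip(values.values(), values.keys()):
    fixed[name] = buffer
print(fixed)  # {'running_mean': 1.0, 'running_var': 2.0}
```

Because the real code iterated a `Set`, which values get clobbered depends on iteration order, matching the stochastic behavior observed in testing.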