
Basic Validation for FSDP state_dict transformations of modules with persistent buffers #93396

Closed
speediedan wants to merge 4 commits into pytorch:master from speediedan:fix_mp_buffers_fsdp_state_dict

Conversation

@speediedan
Contributor

Fixes #93391

Thank you to the PyTorch Distributed team for your invaluable contributions to the PyTorch ecosystem; your work is immensely impressive and inspiring!

As mentioned in #93391, in preparing the downstream package I maintain (finetuning-scheduler) to support PyTorch 2.0's version of FSDP, I noticed modules that include multiple persistent buffers were not having their state properly transformed during saving of state_dicts.

The issue was that the post-state_dict hook codepath shared by the FULL_STATE_DICT and SHARDED_STATE_DICT _state_dict_types (_common_unshard_post_state_dict_hook) was inadvertently referencing a local variable (buffer) that was used in a prior transformation, instead of the buffers variable that should have been referenced in the iteration context:

```python
for buffers, clean_fqn in zip(buffers, buffer_clean_fqns):
    fqn = f"{prefix}{clean_fqn}"
    state_dict[fqn] = buffer.clone()
```

In this case, modules with a single persistent buffer, or modules without mixed precision enabled, were unaffected. With multiple buffers and mixed precision enabled, however, the issue may appear stochastically, in proportion to the ratio of persistent buffers with compatible dimensions (the value of the last buffer visited in the buffer_names set is copied to all buffers, and the set iteration order will of course vary).
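To make the failure mode concrete, here is a minimal, torch-free sketch (with made-up stand-in values, not the actual FSDP code) of the shadowing described above:

```python
# An earlier loop leaves a stale `buffer` binding; the buggy iteration then
# rebinds `buffers` while its body still reads the stale `buffer`.
buffers = [[0.1], [1.0, 1.0], [0]]  # stand-ins for BN buffer tensors
buffer_clean_fqns = ["running_mean", "running_var", "num_batches_tracked"]

for buffer in buffers:  # a prior transformation pass...
    pass                # ...leaves `buffer` bound to the last item, [0]

state_dict = {}
for buffers, clean_fqn in zip(buffers, buffer_clean_fqns):  # bug: shadows `buffers`
    state_dict[clean_fqn] = list(buffer)  # bug: stale `buffer`, not the loop item

# Every entry silently received the value of the last buffer visited:
assert all(v == [0] for v in state_dict.values())
```

When the stale buffer's shape is incompatible with a given FQN's expected shape, loading fails loudly (as in the size-mismatch error below); when shapes happen to be compatible, the corruption is silent.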

```
File ".../pytorch/torch/nn/modules/module.py", line 2028, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FullyShardedDataParallel:
    size mismatch for _fsdp_wrapped_module.1._fsdp_wrapped_module.running_mean: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([10]).
```

To both address this issue and enhance coverage to avoid similar issues, this PR fixes the aforementioned typo and adds an additional set of basic tests that validate state_dict saving and loading for modules with persistent buffers in various contexts.

I found that adapting test_basic_save_and_load_state_dict for this coverage (adding another model along with additional buffer-specific logic) would increase that test's complexity to an undesirable degree.

Instead of adding that complexity to the existing test, I've added a new test, test_buffers_save_and_load_state_dict, that does basic validation of state_dict saving and loading with mixed precision, state_dict_type, and CPU-offloading parameterization. Certainly let me know if you'd prefer that I extend the logic of the existing basic state_dict test and add the persistent-buffers model there; I'm happy to do so, I just thought it was cleaner this way. Also, I thought doubling the number of tests with a use_orig_params parameterization, or by testing additional non-default buffer mixed-precision data types, was computationally imprudent, but let me know if you'd like me to add those tests as well.

The only other notable test change is that I've refactored TestFSDPStateDict._compare_models to accommodate both buffers and parameters comparisons without code duplication.

Thanks again to the PyTorch Distributed team for your exceptional contributions. I've got some more work to do adapting my package to 2.0's FSDP, but it's been a delight so far thanks to your superlative work!

…tions of modules with persistent buffers failed with mixed precision enabled
@pytorch-bot

pytorch-bot bot commented Jan 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93396

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 0fd6bc3:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jan 31, 2023
@speediedan speediedan marked this pull request as ready for review February 1, 2023 01:53
@awgu awgu requested a review from fegin February 1, 2023 02:07
@drisspg drisspg added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 1, 2023
Contributor

@rohan-varma rohan-varma left a comment

Looks great, thanks! Just some minor questions / comments after which we can merge

```python
        assert_fn(state_base, state_new)

    def _compare_models(
        self, model, model_new, assert_fn, check_fp16=False, check_buffers=False
```
Contributor

qq: is there an issue with setting check_buffers=True by default to enable it for all tests that use this helper?

Contributor Author

Great idea! I was being a little overcautious in my attempt to minimize the changes I introduced. As long as we add a guard to the buffer comparison in _compare_models to ensure there are buffers to compare (many of the current test models don't have any), we can make check_buffers=True the default.
Note, I'm only checking persistent buffers, given this is a state_dict inspection, but there is at least one line of code that checks for non-persistent buffers that won't be covered (maybe best for a different PR if we want to tackle that).
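The guard in question can be sketched roughly like this (hypothetical helper with dummy stand-in objects, not the actual test code):

```python
class DummyModel:
    """Stand-in for an nn.Module exposing only buffers()."""
    def __init__(self, buffers):
        self._buffers = buffers

    def buffers(self):
        return iter(self._buffers)

def compare_models(model, model_new, assert_fn, check_buffers=True):
    # check_buffers=True is now safe as a default: the inner guard skips
    # models that register no buffers at all.
    if check_buffers:
        bufs, bufs_new = list(model.buffers()), list(model_new.buffers())
        if bufs:  # guard: many existing test models have no buffers to compare
            for b, b_new in zip(bufs, bufs_new):
                assert_fn(b, b_new)

def assert_equal(a, b):
    assert a == b

compare_models(DummyModel([]), DummyModel([]), assert_equal)          # guard skips
compare_models(DummyModel([[1.0]]), DummyModel([[1.0]]), assert_equal)  # buffers match
```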

```python
        in the context of non-default mixed precision, different ``state_dict_type`` s and CPU offloading.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return  # not supported
```
Contributor

no need to fix in this PR, but this sort of pattern is not ideal from a time-to-signal (TTS) perspective, i.e. we still spawn N processes, initialize NCCL, etc., only to return and tear everything down, paying unnecessary overhead.

Contributor Author

Agreed! Note, this (anti)pattern is used in test_state_dict_with_manual_ac_wrapper and test_basic_save_and_load_state_dict as well but those similar TTS issues can be addressed in the context of the refactoring options @awgu outlines below (either of which seem useful and worth doing!).

Contributor

One easy way to fix this anti-pattern is to consolidate the usage of StateDictType and StateDictConfig. This is worth another PR to fix. Thanks for pointing out the issue.
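One possible shape for that consolidation, as a rough sketch (hypothetical, not the actual FSDP API; a local StateDictType enum stands in for torch.distributed.fsdp.StateDictType): keep a single mapping as the source of truth and filter unsupported combinations before any processes spawn.

```python
from enum import Enum, auto

class StateDictType(Enum):  # stand-in for the real FSDP enum
    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()

STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}

def supported_configs():
    # mirrors the guard in the test: rank0_and_offload only pairs with "state_dict",
    # so the unsupported combos never reach process spawning
    for name, sd_type in STATE_DICT_MAPPING.items():
        for rank0_and_offload in (False, True):
            if rank0_and_offload and name != "state_dict":
                continue
            yield name, sd_type, rank0_and_offload

configs = list(supported_configs())
```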

```python
            if mixed_precision
            else None
        )
        model_call = partial(
```
Contributor

what's the need for partial if we're directly using this result after?

Collaborator

I am guessing this partial application is being saved for the additional model_new = model_call() below.

Contributor Author

Yep, I was using it in model_new, but I can switch to passing the model function directly and avoid partial if you prefer.
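The motivation for partial here can be illustrated in miniature (get_model is a hypothetical stand-in for self._get_multibuffer_nested_model):

```python
from functools import partial

def get_model(cpu_offload=False, use_orig_params=False):
    # stand-in factory; the real helper constructs an FSDP-wrapped model
    return {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}

# Bind the constructor arguments once...
model_call = partial(get_model, cpu_offload=True, use_orig_params=False)

model = model_call()      # ...used once for the model being saved
model_new = model_call()  # ...and again for the model loading the state_dict
assert model == model_new
```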

```python
        model_call = partial(
            self._get_multibuffer_nested_model,
            cpu_offload=cpu_offload,
            use_orig_params=False,
```
Contributor

@awgu Is it worth testing use_orig_params=True path as well, i.e. is the checkpointing implementation different enough to justify the TTS increase?

Collaborator

I think we should test use_orig_params=True even if it doubles the unit test time for this added unit test since this is the code path we want to converge to going forward.

I think if we are to address the high time-to-signal (TTS) issue, we need to do it from the ground up rather than being cautious about testing useful configs:

  1. Refactor existing unit tests to avoid redundancy.
  2. Use subtests as much as possible while actually ensuring that failing subtests report the config that failed.

Point 1 requires some effort from our side, which I was thinking would be justifiable as we invest in fully_shard. The latter part of point 2 has not been done yet; the main tricky part I need to figure out is how to handle the self.assertRaises... cases vs. regular non-erroring cases.
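For point 2, one possible shape (a sketch, not PyTorch's actual test harness) is unittest's subTest context manager, which makes a failing subtest's report name the exact config that failed:

```python
import unittest

class ExampleConfigTest(unittest.TestCase):
    def test_configs(self):
        for use_orig_params in (False, True):
            for state_dict_type in ("state_dict", "sharded_state_dict"):
                with self.subTest(use_orig_params=use_orig_params,
                                  state_dict_type=state_dict_type):
                    # a failure here is reported with the subTest kwargs,
                    # e.g. (use_orig_params=True, state_dict_type='sharded_state_dict'),
                    # rather than just the enclosing test name
                    self.assertIsInstance(state_dict_type, str)
```

The open question noted above remains: configs that are expected to raise (self.assertRaises...) need different handling inside the subtest loop than the regular non-erroring configs.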

Contributor Author

Excellent point @awgu. I've gone ahead and extended test_buffers_save_and_load_state_dict to test use_orig_params as recommended and will push that commit now.

```python
            nn.BatchNorm1d(10).cuda(),
            nn.Linear(10, 10, bias=False).cuda(),
        )
        return model
```
Contributor

just for my understanding, where do the multi buffers come from? Is it the multiple BN units?

Contributor Author

The "multi-buffers" just refers to the multiple buffers within the BN module (e.g. running_var, running_mean, etc.). I thought it was worth emphasizing/testing particularly because the bug here does not rear its ugly head for modules with a single buffer.
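For reference, a single BatchNorm module already registers multiple buffers (assuming a standard PyTorch install with default track_running_stats=True):

```python
import torch.nn as nn

# One BatchNorm1d module registers three buffers; it is these intra-module
# buffers (not multiple BN units) that exercise the multi-buffer code path.
bn = nn.BatchNorm1d(10)
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```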

```diff
             buffers, buffer_dtypes, fsdp_state.compute_device
         )
-        for buffers, clean_fqn in zip(buffers, buffer_clean_fqns):
+        for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
```
Contributor

Oof, this was some nasty variable shadowing combined with an unfortunate typo. Thanks for the catch!

Contributor Author

It was gnarly indeed! Its stochastic nature in the context I was testing ultimately helped narrow down the issue, but it also initially made it all the more mysterious. Again, thank you both for your impressive contributions (in both volume and quality!); it's a real pleasure to get to collaborate with and learn from you.

Collaborator

@awgu awgu left a comment


This looks good to me! Feel free to merge after pushing your most recent local changes.

Thanks for the thorough investigation and high-quality contribution :)

…ms` scenario. Make new `check_buffers` option in `_compare_models` default to `True`.
Contributor

@fegin fegin left a comment


Thanks for the fix!

```python
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_buffers_save_and_load_state_dict(
        self,
        state_dict_type: StateDictType,
```
Contributor

The type here is str, not StateDictType. This is a legacy issue: multiple ways to declare state_dict_type exist in the same test.

Contributor Author

Good catch! I've now updated the annotation for both this test and test_basic_save_and_load_state_dict.

@speediedan
Contributor Author

> This looks good to me! Feel free to merge after pushing your most recent local changes.
>
> Thanks for the thorough investigation and high-quality contribution :)

My pleasure! That means a lot coming from you. I look forward to seeing how FSDP evolves and to contributing again at some point.

…ate_dict` and `test_buffers_save_and_load_state_dict` tests.
@speediedan
Contributor Author

speediedan commented Feb 1, 2023

Looks like there is one (seemingly) unrelated test failure, but before attempting a pytorchbot merge -f ... I wanted to check with the distributed team.

If you consent, I'm happy to try the force merge. I most likely don't have permission to execute it myself, though, so if any of you deem it appropriate, please feel free to execute it on my behalf. Thanks!
@rohan-varma @fegin @awgu

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 2, 2023
@awgu
Collaborator

awgu commented Feb 2, 2023

> Looks like there is 1 (seemingly) unrelated test failure but before attempting a pytorchbot merge -f ... I wanted to check with the distributed team.
>
> If you guys consent I'm happy to try the force merge. I most likely don't have permission to execute the force merge though so if any of you deem appropriate, please feel free to execute it on my behalf. Thanks! @rohan-varma @fegin @awgu

Sounds good. I added the ciflow/trunk tag (which normally gets triggered when we try to merge). We can force merge the PR on your behalf after those additional tests finish running, if all failures are unrelated.

@awgu
Collaborator

awgu commented Feb 2, 2023

linux-bionic-py3_8-clang8-xla / test (xla, 1, 1, linux.4xlarge) failure looks unrelated:

```
======================================================================
ERROR [0.010s]: test_view_copy_out_xla (__main__.TestViewOpsXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_device_type.py", line 414, in instantiated_test
    raise rte
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_device_type.py", line 401, in instantiated_test
    result = test(self, **param_kwargs)
  File "/var/lib/jenkins/workspace/xla/test/../../test/test_view_ops.py", line 949, in test_view_copy_out
    torch.split_copy(a, 2, out=(out1, out2))
RuntimeError: Expected out tensor to have device cpu, but got xla:0 instead

----------------------------------------------------------------------
Ran 166 tests in 1.178s

FAILED (errors=1, skipped=124, expected failures=3)
```

@awgu
Collaborator

awgu commented Feb 2, 2023

@pytorchbot merge -f "unrelated xla failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).


ragulpr added a commit to ragulpr/pytorch that referenced this pull request Feb 2, 2023
…n-dev-setup

* origin: (898 commits)
  ...
  Basic Validation for FSDP `state_dict` transformations of modules with persistent buffers (pytorch#93396)
  ...

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: distributed (fsdp) release notes category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


Development

Successfully merging this pull request may close these issues.

FSDP state_dict transformations of modules with persistent buffers fail with mixed precision enabled

7 participants