
fully_shard load state_dict#90945

Closed
rohan-varma wants to merge 9 commits into gh/rohan-varma/627/base from gh/rohan-varma/627/head

Conversation


@rohan-varma rohan-varma commented Dec 15, 2022

Stack from ghstack (oldest at bottom):

Ensures that load_state_dict for fully_shard works:

  • Don't add back FSDP prefix
  • Small fix to ensure mixed precision check for buffers work

Follow ups:

  • state_dict_type is not yet supported, which blocks rank0_only and CPU offload as well as the other state dict implementations
  • No testing yet for wrapping with activation checkpointing (AC), use of mixed precision, integration with distributed checkpoint, etc.
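The first bullet (don't add back the FSDP prefix) can be sketched in plain Python. The prefix string and helper name below are assumptions for illustration; the real logic lives inside torch.distributed.fsdp's state-dict hooks.

```python
# Assumed wrapper-module prefix that FSDP inserts into parameter names.
FSDP_PREFIX = "_fsdp_wrapped_module."

def clean_state_dict_keys(state_dict):
    """Strip the wrapper prefix so keys match the unwrapped module's names.

    For fully_shard (the composable API), the module is not wrapped, so the
    state dict it produces and consumes must not carry this prefix.
    """
    return {key.replace(FSDP_PREFIX, ""): value for key, value in state_dict.items()}
```

With keys cleaned this way, a checkpoint saved from a fully_shard-ed model can be loaded back into the original (unwrapped) module.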


pytorch-bot bot commented Dec 15, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90945

Note: Links to docs will display an error until the doc builds have completed.

✅ No Failures

As of commit 298681f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Dec 15, 2022
rohan-varma added a commit that referenced this pull request Dec 15, 2022
ghstack-source-id: 1aa5381
Pull Request resolved: #90945
@rohan-varma rohan-varma changed the title from [WIP] fully_shard load state_dict to fully_shard load state_dict Dec 15, 2022

def _broadcast_state_dict(rank, state_dict):
@rohan-varma (Contributor, Author):

not strictly needed right now but will be used in composable tests.
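A single-process sketch of what such a helper's contract looks like. This is an assumption about the real `_broadcast_state_dict` (which would presumably use a collective such as `torch.distributed.broadcast_object_list` with `src=0`); here a plain dict stands in for the communication layer.

```python
import copy

def broadcast_state_dict(rank, state_dict, world):
    # Rank 0 publishes its state_dict; every other rank returns a copy of
    # what rank 0 sent, so all ranks end up holding rank 0's checkpoint
    # before calling load_state_dict. `world` is an illustrative stand-in
    # for the process group / collective.
    if rank == 0:
        world["src"] = state_dict
        return state_dict
    return copy.deepcopy(world["src"])
```

In a composable test, each rank would call this once so that loading is verified against a single, consistent checkpoint.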

@awgu awgu (Collaborator) left a comment:

LGTM! There are a few to-dos you left for yourself. Feel free to address those before landing.

@skip_if_lt_x_gpu(2)
def test_state_dict_save_load_flow(self):
"""
E2E test of save + load with rank0_only + CPU offload for TransformerWithSharedParams
Collaborator:

In the future, will this test include different state dict types and subtest the different configs?
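One way to sketch that follow-up is a config grid to iterate over in subtests. The config names and values below are hypothetical, chosen to match the options mentioned in this PR (state dict types, rank0_only, CPU offload).

```python
import itertools

def iter_state_dict_configs():
    # Hypothetical grid of configurations a future version of the test
    # could subtest: every combination of state-dict type and the
    # rank0_only / CPU-offload flags.
    state_dict_types = ["full", "sharded", "local"]
    rank0_only = [False, True]
    cpu_offload = [False, True]
    keys = ("state_dict_type", "rank0_only", "cpu_offload")
    for combo in itertools.product(state_dict_types, rank0_only, cpu_offload):
        yield dict(zip(keys, combo))
```

Each yielded dict could then drive one `subTest` (or `run_subtests`) case in the E2E save/load flow.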

buffers, buffer_dtypes, fsdp_state.compute_device
if buffers:
mixed_precision_enabled_for_buffers = (
fsdp_state._mixed_precision_enabled_for_buffers() if not _is_composable(fsdp_state)
Collaborator:

To-do: We can make _mixed_precision_enabled_for_buffers() not be a method of FullyShardedDataParallel so that we don't have to if/else here. We would be able to just check fsdp_state.mixed_precision.buffer_dtype is not None -- flexible whether that lives in its own function or is written inline every time.

rohan-varma added a commit that referenced this pull request Dec 16, 2022
ghstack-source-id: 2e6a824
Pull Request resolved: #90945
rohan-varma added a commit that referenced this pull request Dec 16, 2022
ghstack-source-id: 72ce2ef
Pull Request resolved: #90945
rohan-varma added a commit that referenced this pull request Dec 19, 2022
ghstack-source-id: 980b4f2
Pull Request resolved: #90945
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 19, 2022
param.zero_()
if zero_buffers:
for buffer in model.buffers():
ctx = FSDP.summon_full_params(model) if summon_full else suppress()
Collaborator:
Should we include any to-do or issue for following up on this? Or, could you remind me what the current status on this is?

@rohan-varma (Contributor, Author):
yeah, let me file an issue summarizing it.
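The helper quoted above zeros parameters (and optionally buffers) so that a subsequent load_state_dict can be verified to actually restore values. A plain-data sketch of that zero-then-load pattern, with dicts standing in for tensors (the real code zeros in place via param.zero_() under FSDP.summon_full_params):

```python
def zero_and_load(params, checkpoint):
    # Zero every "parameter", then copy in checkpoint entries whose names
    # match, mirroring the zero-then-load verification in the test helper.
    restored = {name: 0.0 for name in params}
    for name, value in checkpoint.items():
        if name in restored:
            restored[name] = value
    return restored
```

A test would then assert that every checkpointed value is restored and everything else stayed zeroed.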

@rohan-varma (Contributor, Author):

@pytorchbot merge -f "CI passed"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).


awgu pushed a commit to awgu/pytorch that referenced this pull request Dec 20, 2022
Pull Request resolved: pytorch#90945
Approved by: https://github.com/awgu
awgu pushed a commit to awgu/pytorch that referenced this pull request Dec 20, 2022
Pull Request resolved: pytorch#90945
Approved by: https://github.com/awgu
ghstack-source-id: 1fd8b50


3 participants