fully_shard load state_dict #90945
rohan-varma wants to merge 9 commits into gh/rohan-varma/627/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90945
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 298681f. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    [test_name_mapping[str(s)] if s is not None else "none" for s in args]
)

def _broadcast_state_dict(rank, state_dict):
not strictly needed right now but will be used in composable tests.
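A minimal single-process sketch of what a helper like `_broadcast_state_dict` accomplishes: every rank ends up with an identical copy of the source rank's state dict. Real code would use torch.distributed collectives (e.g. `broadcast_object_list`); the function and variable names here (`broadcast_state_dict`, `rank_dicts`) are hypothetical, chosen only to simulate the effect without a process group:

```python
import copy

def broadcast_state_dict(rank_dicts, src_rank=0):
    """Simulate broadcasting src_rank's state_dict to every rank.

    rank_dicts maps rank -> local state_dict. Each rank receives an
    independent deep copy of the source dict, mirroring what a real
    distributed broadcast would leave behind on each process.
    """
    src = rank_dicts[src_rank]
    return {rank: copy.deepcopy(src) for rank in rank_dicts}

# Rank 1 starts out-of-sync; after the "broadcast" it matches rank 0.
ranks = {0: {"w": [1.0, 2.0]}, 1: {"w": [0.0, 0.0]}}
synced = broadcast_state_dict(ranks)
```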
awgu left a comment:
LGTM! There are a few to-dos you left for yourself. Feel free to address those before landing.
@skip_if_lt_x_gpu(2)
def test_state_dict_save_load_flow(self):
    """
    E2E test of save + load with rank0_only + CPU offload for TransformerWithSharedParams
In the future, will this test include different state dict types and subtest the different configs?
    buffers, buffer_dtypes, fsdp_state.compute_device
if buffers:
    mixed_precision_enabled_for_buffers = (
        fsdp_state._mixed_precision_enabled_for_buffers() if not _is_composable(fsdp_state)
To-do: We can make _mixed_precision_enabled_for_buffers() not be a method of FullyShardedDataParallel so this doesn't need an if/else here. We would be able to just check fsdp_state.mixed_precision.buffer_dtype is not None -- flexible whether that lives in its own function or is written inline every time.
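The suggested check can be sketched as follows. `MixedPrecision` and `FSDPState` here are simplified stand-ins for the real FSDP config/state objects, just to show why a plain attribute check works for both wrapped and composable paths without an if/else:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MixedPrecision:
    # In real FSDP this would be a torch.dtype; a string suffices here.
    buffer_dtype: Optional[str] = None

@dataclass
class FSDPState:
    mixed_precision: MixedPrecision

def mixed_precision_enabled_for_buffers(fsdp_state: FSDPState) -> bool:
    # The reviewer's suggestion: instead of a method on the wrapper class,
    # inspect the shared state directly, so composable and wrapped FSDP
    # take the same code path.
    return fsdp_state.mixed_precision.buffer_dtype is not None
```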
Ensures that load_state_dict for fully_shard works:
- Don't add back the FSDP prefix
- Small fix to ensure the mixed precision check for buffers works

Follow ups:
- state_dict_type does not work, blocking rank0_only and CPU offload as well as other state dict implementations
- No testing when wrapped with AC, using mixed precision, integration with distributed checkpoint, etc.

[ghstack-poisoned]
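The "don't add back the FSDP prefix" point can be illustrated with a small sketch: keys saved from wrapper-style FSDP carry a module prefix that composable fully_shard keys must not have. The helper name `strip_fsdp_prefix` is hypothetical, and the prefix string is an assumption for illustration rather than the PR's actual implementation:

```python
# Assumed prefix for illustration; wrapper-style FSDP inserts a wrapper
# module into the hierarchy, so its parameter keys carry an extra segment.
FSDP_PREFIX = "_fsdp_wrapped_module."

def strip_fsdp_prefix(state_dict, prefix=FSDP_PREFIX):
    """Drop the wrapper prefix from any key that carries it.

    Composable fully_shard keeps the module tree unchanged, so its
    state_dict keys should match the plain nn.Module keys -- neither
    saving nor loading should add the prefix back.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```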
        param.zero_()
if zero_buffers:
    for buffer in model.buffers():

ctx = FSDP.summon_full_params(model) if summon_full else suppress()
Should we include any to-do or issue for following up on this? Or, could you remind me what the current status on this is?
yeah, let me file an issue summarizing it.
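The `ctx = ... if summon_full else suppress()` line in the test above is an instance of a general pattern: conditionally entering a real context manager, falling back to a no-op otherwise. A sketch of that pattern, where `summon_full_params` is a hypothetical stand-in for `FSDP.summon_full_params` and `nullcontext` is used instead of the test's `suppress()` to make the no-op intent explicit:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def summon_full_params(model):
    # Hypothetical stand-in: real FSDP would all-gather full parameters
    # on entry and reshard them on exit.
    yield model

def maybe_summon(model, summon_full):
    # Pick a real context manager when summon_full is set, otherwise a
    # no-op context that simply passes the model through.
    return summon_full_params(model) if summon_full else nullcontext(model)
```

`contextlib.suppress()` (with no exception types) works as a no-op too, which is what the test used; `nullcontext` additionally lets the `with ... as` target carry a value.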
@pytorchbot merge -f "CI passed"

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#90945
Approved by: https://github.com/awgu
ghstack-source-id: 1fd8b50