This repository was archived by the owner on Aug 7, 2024. It is now read-only.
[5/x] make FSDP2 with float8 all-gather work for Float8Linear #296
Closed
vkuzo wants to merge 2 commits into gh/vkuzo/15/base
Summary:
Adds test coverage for `Float8Linear` with all dynamic scaling and FSDP2 with float8 all-gather. To make the tests pass, fixes a bug with initialization ordering in `Float8Linear.from_float`: the right forward config must be set before it is stashed on the weight wrapper.

Test Plan:
```
python test/test_fsdp2/test_fsdp2_eager.py
/test/test_everything.sh
```

Reviewers:
Subscribers:
Tasks:
Tags:
[ghstack-poisoned]
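For readers unfamiliar with this kind of ordering bug, here is a minimal toy sketch of the pattern the summary describes. It does not use the real `Float8Linear` code: `ForwardConfig`, `WeightWrapper`, `ToyLinear`, and the `use_fast_accum` field are hypothetical stand-ins chosen for illustration. The only assumption carried over from the summary is that the weight wrapper captures the forward config at the moment it is created.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ForwardConfig:
    # Hypothetical stand-in for the forward config mentioned in the summary.
    use_fast_accum: bool = False


class WeightWrapper:
    """Toy weight wrapper that snapshots the forward config when it is built."""

    def __init__(self, data, forward_config):
        self.data = data
        self.forward_config = forward_config  # snapshot taken here


class ToyLinear:
    def __init__(self):
        self.forward_config = ForwardConfig()  # placeholder default
        self.weight = None

    @classmethod
    def from_float_buggy(cls, data, use_fast_accum):
        mod = cls()
        # Bug: the wrapper stashes the placeholder config...
        mod.weight = WeightWrapper(data, mod.forward_config)
        # ...and the intended config is only set afterwards.
        mod.forward_config = ForwardConfig(use_fast_accum=use_fast_accum)
        return mod

    @classmethod
    def from_float_fixed(cls, data, use_fast_accum):
        mod = cls()
        # Fix: set the intended config first, then stash it on the wrapper.
        mod.forward_config = ForwardConfig(use_fast_accum=use_fast_accum)
        mod.weight = WeightWrapper(data, mod.forward_config)
        return mod


buggy = ToyLinear.from_float_buggy([1.0, 2.0], use_fast_accum=True)
fixed = ToyLinear.from_float_fixed([1.0, 2.0], use_fast_accum=True)
print(buggy.weight.forward_config.use_fast_accum)  # False (stale config)
print(fixed.weight.forward_config.use_fast_accum)  # True
```

In the buggy variant the module and its weight wrapper disagree about the config, which is the kind of mismatch that only surfaces once a code path (here, the all-gather path) actually reads the stashed copy.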
This was referenced Jul 1, 2024
vkuzo added a commit that referenced this pull request on Jul 2, 2024 (ghstack-source-id: b6d6525, Pull Request resolved: #296)
vkuzo added a commit that referenced this pull request on Jul 2, 2024 (ghstack-source-id: 26b7138, Pull Request resolved: #296)
drisspg reviewed on Jul 2, 2024
            )
            new_mod.weight = mod.weight
        else:
            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
Contributor
Nit: maybe a more helpful assert message
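One possible shape for a more descriptive message, as a sketch only. `ToyConfig` and `check_fp8_all_gather_supported` are hypothetical names, and the wording is an assumption based on this PR's summary (float8 all-gather is exercised only with all dynamic scaling), not text from the merged code.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the real config object; only the flag
    # named in the diff above is modeled.
    enable_fsdp_fp8_all_gather: bool = False


def check_fp8_all_gather_supported(config: ToyConfig, all_dynamic: bool) -> None:
    """Raise with an actionable message when the combination is unsupported."""
    if not all_dynamic:
        assert not config.enable_fsdp_fp8_all_gather, (
            "enable_fsdp_fp8_all_gather=True is only supported when all "
            "scaling is dynamic; disable the flag or switch every scaling "
            "type to dynamic"
        )


# Supported combinations pass silently.
check_fp8_all_gather_supported(ToyConfig(True), all_dynamic=True)
check_fp8_all_gather_supported(ToyConfig(False), all_dynamic=False)

# The unsupported combination reports what to change.
try:
    check_fp8_all_gather_supported(ToyConfig(True), all_dynamic=False)
except AssertionError as err:
    message = str(err)
    print(message)
```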
drisspg reviewed on Jul 2, 2024
    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
    def swap_linear_with_dynamic(
Contributor
Maybe I'm losing some context, but is there a reason why the existing swap function doesn't work?
Contributor
Author
If the question is why we need `swap_linear_with_dynamic`, we probably don't. Removing it is unrelated to this PR, though, so I left it for a future person.
drisspg reviewed on Jul 2, 2024
        self._test_transformer_memory(enable_fsdp_fp8_all_gather)

    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
        # for enable_fsdp_fp8_all_gather in [False, True]:
drisspg approved these changes on Jul 2, 2024
Contributor
drisspg left a comment
Looks good. Maybe add a dummy test that `Float8Linear` with not-all-dynamic scaling errors when trying to use fp8 all-gather.
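A rough sketch of what such a dummy test could look like. It runs against a toy stand-in rather than the real `Float8Linear.from_float`: `ToyConfig` and `toy_from_float` are hypothetical, and the return strings exist only so the test has something to compare.

```python
import unittest
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for the real config; only this flag is modeled.
    enable_fsdp_fp8_all_gather: bool = False


def toy_from_float(config: ToyConfig, all_dynamic: bool) -> str:
    """Hypothetical stand-in for the conversion path touched by this PR."""
    if all_dynamic:
        return "fp8-all-gather" if config.enable_fsdp_fp8_all_gather else "regular"
    # Mirrors the `else` branch in the diff: the flag must be off here.
    assert not config.enable_fsdp_fp8_all_gather, (
        "fp8 all-gather requires all dynamic scaling"
    )
    return "regular"


class TestFp8AllGatherRequiresAllDynamic(unittest.TestCase):
    def test_not_all_dynamic_errors(self):
        config = ToyConfig(enable_fsdp_fp8_all_gather=True)
        with self.assertRaises(AssertionError):
            toy_from_float(config, all_dynamic=False)

    def test_all_dynamic_is_accepted(self):
        config = ToyConfig(enable_fsdp_fp8_all_gather=True)
        self.assertEqual(toy_from_float(config, all_dynamic=True), "fp8-all-gather")


# Run the suite programmatically so the snippet also works outside a test runner.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(
    TestFp8AllGatherRequiresAllDynamic
)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

In the real test suite the same two cases would call the actual conversion function with a not-all-dynamic scaling configuration instead of the toy helper.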
This was referenced Jul 2, 2024
Contributor
Author
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Contributor
This pull request has been merged in 412222b.
Differential Revision: D59305793