float8 with delayed scaling: fix autocast handling#1306
Merged
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1306
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 4402195 with merge base 6234116 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Summary:
Fixes a bug with delayed scaling + autocast.
Before, the last input dtype when in autocast was queried from the input
to `torch._scaled_mm`:
```
x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm
```
This is incorrect because the dtype was saved from before the place
where autocast could change it. This happened to work if `x_hp` was
already of the correct dtype, but did not work in cases such as the new
test case added in this PR, or real models such as the repro from
#1297. The reason we haven't caught
this for so long is we've been using FSDP's mixed precision and not
single-GPU autocast.
The fix I'm taking here is to query the original post-autocast dtype based
on the output of `torch._scaled_mm`. Since this dtype is based on the
dtype of the input to `torch._scaled_mm`, this will properly capture
autocasting:
```
x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
```
Test Plan:
```
// first, test the updated test case - it passes
// second - test a modified version of the repro in
// #1297:
// code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
// logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
// we now see a speedup with float8
```
Reviewers:
Subscribers:
Tasks:
Tags:
76586ad to
4402195
Compare
drisspg
approved these changes
Nov 19, 2024
vkuzo
added a commit
that referenced
this pull request
Nov 22, 2024
Summary: In #1306 I accidentally broke torchtitan + float8 + AC + compile. I don't have a non-torchtitan repro now, putting up the fix first to ensure torchtitan still works, and we should follow-up later with adding test coverage to torchao to prevent similar breakages in the future. What broke: * in the forward of `Float8Linear`, we were setting an attribute on the module * ^ is not supported with compile + something how torchtitan specifically calls AC The fix: remove this attribute setting altogether. Unfortunately this breaks an edge case feature for ensuring scales are reprensentable in `float16`. Since `float16` training is not commonly used with `float8` and this feature was added during very early testing, removing this for now is fine. If we need to add this feature back in the future, I'd advocate for doing it via explicit configuration such as `config.set_scale_upper_bound` and avoiding the stateful hacks, which are usually not compiler friendly. Test Plan: ``` // this repo ./test/float8/test_everything.sh // torchtitan - broken before this PR, works after this PR with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile ``` Reviewers: Subscribers: Tasks: Tags:
vkuzo
added a commit
that referenced
this pull request
Nov 22, 2024
Summary: In #1306 I accidentally broke torchtitan + float8 + AC + compile. I don't have a non-torchtitan repro now, putting up the fix first to ensure torchtitan still works, and we should follow-up later with adding test coverage to torchao to prevent similar breakages in the future. What broke: * in the forward of `Float8Linear`, we were setting an attribute on the module * ^ is not supported with compile + something how torchtitan specifically calls AC The fix: remove this attribute setting altogether. Unfortunately this breaks an edge case feature for ensuring scales are reprensentable in `float16`. Since `float16` training is not commonly used with `float8` and this feature was added during very early testing, removing this for now is fine. If we need to add this feature back in the future, I'd advocate for doing it via explicit configuration such as `config.set_scale_upper_bound` and avoiding the stateful hacks, which are usually not compiler friendly. Test Plan: ``` // this repo ./test/float8/test_everything.sh // torchtitan - broken before this PR, works after this PR with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile ``` Reviewers: Subscribers: Tasks: Tags:
yanbing-j
pushed a commit
to yanbing-j/ao
that referenced
this pull request
Dec 9, 2024
…nd (pytorch#1306) Bunch of minor code quality things.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Fixes a bug with delayed scaling + autocast.
Before, the last input dtype when in autocast was queried from the input to
torch._scaled_mm:This is incorrect because the dtype was saved from before the place where autocast could change it. This happened to work if
x_hpwas already of the correct dtype, but did not work in cases such as the new test case added in this PR, or real models such as the repro from #1297. The reason we haven't caught this for so long is we've been using FSDP's mixed precision and not single-GPU autocast.The fix I'm taking here is to query the original post-autocast dtype based on the output of
torch._scaled_mm. Since this dtype is based on the dtype of the input totorch._scaled_mm, this will properly capture autocasting:Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: