float8 with delayed scaling: fix autocast handling #1306

Merged: vkuzo merged 1 commit into main from 20241118_delayed_scaling_autocast_fix on Nov 19, 2024

@vkuzo vkuzo commented Nov 18, 2024

Summary:

Fixes a bug with delayed scaling + autocast.

Before, the last input dtype when in autocast was queried from the input to `torch._scaled_mm`:

x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm

This is incorrect because the dtype was saved before the point where autocast could change it. This happened to work if `x_hp` was already of the correct dtype, but did not work in cases such as the new test case added in this PR, or in real models such as the repro from #1297. The reason we haven't caught this for so long is that we've been using FSDP's mixed precision and not single-GPU autocast.

The fix I'm taking here is to query the original post-autocast dtype from the output of `torch._scaled_mm`. Since the output dtype is based on the dtype of the input to `torch._scaled_mm`, this properly captures autocasting:

x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
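
To make the difference between the two query points concrete, here is a toy, PyTorch-free sketch of both pipelines. Everything in it (`ToyTensor`, `to_autocast`, `toy_scaled_mm`, and the two `query_*` functions) is a hypothetical stand-in that only tracks dtype labels; it is not the torchao implementation and does not call `torch._scaled_mm`. It assumes, as described above, that the matmul output dtype follows the post-autocast input dtype.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a tensor; only the dtype label matters."""
    dtype: str


def to_autocast(x: ToyTensor, autocast_dtype: Optional[str]) -> ToyTensor:
    """Mimic autocast: cast to autocast_dtype if autocast is on (not None)."""
    return x if autocast_dtype is None else ToyTensor(autocast_dtype)


def toy_scaled_mm(x_fp8: ToyTensor, out_dtype: str) -> ToyTensor:
    """Toy matmul: the output comes back in the requested high-precision dtype."""
    return ToyTensor(out_dtype)


def query_before_autocast(x_hp: ToyTensor, autocast_dtype: Optional[str]) -> Tuple[str, str]:
    """Old behavior: record the dtype before autocast had a chance to change it."""
    recorded = x_hp.dtype
    x = to_autocast(x_hp, autocast_dtype)
    out = toy_scaled_mm(ToyTensor("float8_e4m3fn"), out_dtype=x.dtype)
    return recorded, out.dtype


def query_after_mm(x_hp: ToyTensor, autocast_dtype: Optional[str]) -> Tuple[str, str]:
    """New behavior: record the dtype from the matmul output."""
    x = to_autocast(x_hp, autocast_dtype)
    out = toy_scaled_mm(ToyTensor("float8_e4m3fn"), out_dtype=x.dtype)
    recorded = out.dtype
    return recorded, out.dtype


x_hp = ToyTensor("float32")
# Under autocast to bfloat16, the old query point records a stale dtype.
print(query_before_autocast(x_hp, "bfloat16"))
# The new query point records the dtype the matmul actually ran with.
print(query_after_mm(x_hp, "bfloat16"))
```

With autocast disabled (`autocast_dtype=None`), both functions record the same dtype, which matches the observation that the old code happened to work when `x_hp` was already of the correct dtype.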

Test Plan:

// first, test the updated test case - it passes

// second - test a modified version of the repro in
// https://github.com/pytorch/ao/issues/1297:
// code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
// logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
// we now see a speedup with float8

pytorch-bot Bot commented Nov 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1306

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4402195 with merge base 6234116:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Nov 18, 2024
@vkuzo vkuzo added the topic: bug fix label Nov 18, 2024
@vkuzo vkuzo force-pushed the 20241118_delayed_scaling_autocast_fix branch from 76586ad to 4402195 Compare November 18, 2024 23:41
@vkuzo vkuzo requested review from drisspg and weifengpy November 18, 2024 23:42
@vkuzo vkuzo merged commit b714026 into main Nov 19, 2024
vkuzo added a commit that referenced this pull request Nov 22, 2024
Summary:

In #1306 I accidentally broke
torchtitan + float8 + AC + compile.

I don't have a non-torchtitan repro now, so I'm putting up the fix first
to ensure torchtitan still works; we should follow up later
by adding test coverage to torchao to prevent similar breakages in the
future.

What broke:
* in the forward of `Float8Linear`, we were setting an attribute on
  the module
* ^ is not supported with compile + something about how torchtitan
  specifically calls AC

The fix: remove this attribute setting altogether. Unfortunately this
breaks an edge case feature for ensuring scales are representable in
`float16`.  Since `float16` training is not commonly used with `float8`
and this feature was added during very early testing, removing this for
now is fine.

If we need to add this feature back in the future, I'd advocate for
doing it via explicit configuration such as `config.set_scale_upper_bound`
and avoiding the stateful hacks, which are usually not compiler
friendly.
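
As a rough illustration of the explicit-configuration direction, here is a minimal, PyTorch-free sketch. The `ToyScalingConfig` class, its `scale_upper_bound` field, and the `amax_to_scale` helper below are hypothetical names chosen for this example (the commit message only suggests something like `config.set_scale_upper_bound`); this is not an existing torchao API. The sketch assumes the scale is computed as the float8 max value divided by the observed amax, and uses the known maxima of `float8_e4m3fn` (448) and `float16` (65504).

```python
from dataclasses import dataclass
from typing import Optional

FLOAT16_MAX = 65504.0  # largest finite float16 value
E4M3_MAX = 448.0       # largest finite float8_e4m3fn value
EPS = 1e-12            # guards against division by zero


@dataclass(frozen=True)
class ToyScalingConfig:
    """Hypothetical config: the bound is explicit, immutable, and set up front."""
    scale_upper_bound: Optional[float] = None


def amax_to_scale(amax: float, config: ToyScalingConfig) -> float:
    """Compute a float8 scale from amax, clamping only if the config asks for it."""
    scale = E4M3_MAX / max(amax, EPS)
    if config.scale_upper_bound is not None:
        scale = min(scale, config.scale_upper_bound)
    return scale


unbounded = ToyScalingConfig()
fp16_safe = ToyScalingConfig(scale_upper_bound=FLOAT16_MAX)

# A tiny amax yields a scale that float16 cannot represent unless clamped.
print(amax_to_scale(1e-6, unbounded) > FLOAT16_MAX)
print(amax_to_scale(1e-6, fp16_safe))
```

Because the bound lives in a config object that is fixed before the forward pass, nothing needs to be written onto the module at runtime, which is the property that makes this approach friendlier to compilation than the stateful version.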

Test Plan:

```
// this repo
./test/float8/test_everything.sh

// torchtitan - broken before this PR, works after this PR
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
```
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024