fix float8 rowwise inference perf with torch.compile#2672

Merged: vkuzo merged 1 commit into main from 20250804_fix_float8_rowwise_compile on Aug 4, 2025
@vkuzo vkuzo commented Aug 4, 2025

In #2379, logic was added which prevented torchinductor from fusing the activation quantization for float8 inference. Here are some logs which show the extra kernels being added by that PR to float8 inference on NVIDIA GPUs: https://www.internalfb.com/phabricator/paste/view/P1891592748 .
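
For readers unfamiliar with the two granularities, here is a minimal plain-Python sketch of how a tensorwise scale differs from rowwise scales for an activation. It is not torchao's implementation: the function names are hypothetical, the `amax / max_representable` scale convention is an assumption (some code paths use the reciprocal), and the actual cast to float8 is omitted. The only hard fact used is that the largest finite float8 e4m3fn value is 448. The reduction and the elementwise scaling shown here are the kind of work that ideally gets fused into a small number of GPU kernels ahead of the scaled matmul.

```python
# Illustrative only: plain-Python scale computation, not torchao's code.
E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn


def tensorwise_scale(x, eps=1e-12):
    """Return one scale for the whole 2D activation (a list of rows)."""
    amax = max(abs(v) for row in x for v in row)
    return max(amax, eps) / E4M3_MAX


def rowwise_scales(x, eps=1e-12):
    """Return one scale per row of the activation."""
    return [max(max(abs(v) for v in row), eps) / E4M3_MAX for row in x]


def scale_and_clamp(x, scales):
    """Divide each row by its scale and clamp to the e4m3 range.

    The cast to an actual float8 dtype is intentionally left out.
    """
    return [
        [min(max(v / s, -E4M3_MAX), E4M3_MAX) for v in row]
        for row, s in zip(x, scales)
    ]


x = [[1.0, -2.0, 0.5], [100.0, 4.0, -448.0]]
per_tensor = tensorwise_scale(x)          # a single float
per_row = rowwise_scales(x)               # one float per row
scaled = scale_and_clamp(x, per_row)      # each row now spans the full range
```

With rowwise scales, the small-magnitude first row is stretched to the full representable range instead of sharing a scale dominated by the second row.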

This PR reverts most of #2379 and adds a test to ensure we see the expected number of GPU kernels for float8 tensorwise and rowwise quantization. We'll have to redo #2379 without breaking this test.
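
As a rough illustration of what a kernel-count check can look like, the sketch below counts Triton kernel definitions in a string of generated code. It is an assumption-laden stand-in, not the test added in this PR: `count_triton_kernels` is a hypothetical helper, the `def triton_` naming convention is assumed and may vary across PyTorch versions, and `SAMPLE_CODE` is hand-written rather than real compiler output. In an actual test the string would more likely be obtained from a compiled module, for example via `torch._inductor.utils.run_and_get_code`, which is not exercised here.

```python
import re


def count_triton_kernels(generated_code: str) -> int:
    """Count functions whose name starts with `triton_` in generated code.

    Assumption: each generated Triton kernel appears as `def triton_...(`.
    """
    return len(
        re.findall(r"^\s*def triton_\w+\(", generated_code, flags=re.MULTILINE)
    )


# Hand-written stand-in for generated code, NOT real Inductor output.
SAMPLE_CODE = '''
@triton.jit
def triton_example_reduce(in_ptr0, out_ptr0, xnumel, rnumel):
    pass

@triton.jit
def triton_example_cast(in_ptr0, out_ptr0, xnumel):
    pass

def call(args):
    pass
'''

num_kernels = count_triton_kernels(SAMPLE_CODE)

# A regression test would pin the expected count so that an extra,
# unfused kernel makes the assertion fail.
EXPECTED_KERNELS = 2
assert num_kernels == EXPECTED_KERNELS
```

Pinning an exact count is deliberately strict: any change that stops the compiler from fusing the quantization ops shows up as a test failure rather than as a silent slowdown.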

Perf impact of this PR on MKN == 1024, 2048, 4096 on an NVIDIA H100 for float8 rowwise inference:

Note that I added a benchmark to benchmarks/inference/bench_float8_inference.py to reproduce the numbers above, but I ran this benchmark out-of-tree to get the actual numbers, for easier comparison of before-this-PR vs after-this-PR.

Summary:

Test Plan:

TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 pytest test/dtypes/test_affine_quantized_float.py -s -k expected_kernels_on_gpu

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot Bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2672

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Pending

As of commit 8c23d32 with merge base 7dbc816:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the "CLA Signed" label Aug 4, 2025
@vkuzo vkuzo added the "topic: bug fix" label Aug 4, 2025
@vkuzo vkuzo force-pushed the 20250804_fix_float8_rowwise_compile branch from 3e9752c to 8c23d32 Compare August 4, 2025 15:14
@vkuzo vkuzo requested review from drisspg and jerryzh168 August 4, 2025 15:15

@jerryzh168 jerryzh168 left a comment

Thanks for the fix. I suspect this is related to the fbgemm benchmark regression as well.

@vkuzo vkuzo merged commit 6a74e34 into main Aug 4, 2025
17 of 20 checks passed
liangel-02 pushed a commit that referenced this pull request Aug 25, 2025