meta registration for torch._scaled_mm with mxfp8#148461

Closed
vkuzo wants to merge 5 commits into gh/vkuzo/7/base from gh/vkuzo/7/head

@vkuzo vkuzo commented Mar 4, 2025

Contributor

Stack from ghstack (oldest at bottom):

Summary:

Adds the meta registration logic for torch.compile to work with
torch._scaled_mm with mxfp8. Thanks to @eellison for the pointer to make inductor work with this.
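
For readers unfamiliar with meta registrations: a meta function lets the compiler derive the output shape from input metadata alone, without running the kernel. The sketch below illustrates that idea in plain Python and is not the code added in this PR. The name `scaled_mm_meta`, the block size of 32 along K, and the unpadded one-scale-per-block count are assumptions made for illustration; the real kernel may require padded or swizzled scale layouts.

```python
MX_BLOCK_SIZE = 32  # assumption: one shared scale per 32 elements along K


def scaled_mm_meta(a_shape, b_shape, scale_a_numel, scale_b_numel):
    """Hypothetical shape-only check for a blockwise-scaled matmul.

    Takes shapes and scale element counts, never tensor data, and returns
    the output shape (M, N) or raises ValueError.
    """
    m, k = a_shape
    k_b, n = b_shape
    if k != k_b:
        raise ValueError(f"inner dimensions differ: {k} vs {k_b}")
    if k % MX_BLOCK_SIZE != 0:
        raise ValueError(f"K={k} must be a multiple of {MX_BLOCK_SIZE}")

    k_blocks = k // MX_BLOCK_SIZE
    # Assumption: unpadded layout, one scale per (row, K-block) pair.
    if scale_a_numel != m * k_blocks:
        raise ValueError(f"scale_a needs {m * k_blocks} elements, got {scale_a_numel}")
    if scale_b_numel != n * k_blocks:
        raise ValueError(f"scale_b needs {n * k_blocks} elements, got {scale_b_numel}")

    return (m, n)
```

Under these assumptions, `scaled_mm_meta((128, 256), (256, 64), 1024, 512)` returns `(128, 64)`, which is all the tracer needs to keep propagating shapes through the graph.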

Test Plan:

pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s

Reviewers:

Subscribers:

Tasks:

Tags:

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

[ghstack-poisoned]

pytorch-bot Bot commented Mar 4, 2025


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148461

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 5e38d0b with merge base 23183fe:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Mar 4, 2025
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8, with `aot_eager` backend.

Note that we need #147873 for
inductor to work.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 248830a
Pull Request resolved: #148461
@vkuzo vkuzo added the topic: not user facing topic category label Mar 4, 2025
vkuzo added a commit to pytorch/ao that referenced this pull request Mar 5, 2025
Summary:

This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.

The current approach is a short term workaround until
pytorch/pytorch#148461 is done. Since we can't
use e8m0 in torchinductor or triton yet, we create a custom op wrapper
around `torch._scaled_mm` which takes `uint8` scales and does the cast to
e8m0 inside the wrapper, where torchinductor can't see it.
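
To make the `uint8`-to-e8m0 step concrete: e8m0 is an exponent-only format, so each scale byte stands for a power of two. The pure-Python sketch below illustrates the encoding and is not the torchao wrapper. It assumes the OCP MX convention of an exponent bias of 127 with `0xFF` reserved for NaN; the names `e8m0_to_float` and `float_to_e8m0` are hypothetical.

```python
import math

E8M0_BIAS = 127   # assumption: OCP MX exponent bias
E8M0_NAN = 0xFF   # assumption: the all-ones byte encodes NaN


def e8m0_to_float(byte: int) -> float:
    """Interpret a uint8 scale byte as the power of two it encodes."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"not a uint8 value: {byte}")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)  # 2 ** (byte - bias)


def float_to_e8m0(scale: float) -> int:
    """Encode an exact positive power of two as a scale byte."""
    if not scale > 0:
        raise ValueError("scale must be positive")
    # frexp returns scale == mantissa * 2**exponent with mantissa in [0.5, 1).
    mantissa, exponent = math.frexp(scale)
    if mantissa != 0.5:
        raise ValueError("scale must be an exact power of two")
    byte = exponent - 1 + E8M0_BIAS
    if not 0 <= byte < E8M0_NAN:
        raise ValueError("scale is outside the representable range")
    return byte
```

Because the bit pattern is the same either way, a wrapper can accept `uint8` and reinterpret it as e8m0 without changing the stored bytes.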

Test Plan:

```
# this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

# we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 033d817
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
Comment thread test/test_matmul_cuda.py Outdated
C_ref = A_ref @ B_ref.t()

# TODO(#147873): switch to inductor backend after e8m0 is supported there
compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="aot_eager")

Contributor Author

When I rebase past https://github.com/pytorch/pytorch/pull/148722/files and then change the backend in this code to inductor, I see https://www.internalfb.com/phabricator/paste/view/P1753483593. cc @eellison

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 13, 2025
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8, with `aot_eager` backend.

Note that we need #147873 for
inductor to work.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 67791ed
Pull Request resolved: #148461
@vkuzo vkuzo requested review from drisspg and eellison March 13, 2025 14:19
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 13, 2025
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8, with `aot_eager` backend.

Note that we need #147873 for
inductor to work.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ea7eb2f
Pull Request resolved: #148461
[ghstack-poisoned]
vkuzo added a commit to pytorch/ao that referenced this pull request Mar 26, 2025
Summary:

After pytorch/pytorch#148461 lands, we
can use `torch.float8_e8m0fnu` throughout our codebase and compile
will still work, removing the workarounds.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 278117b
ghstack-comment-id: 2755728114
Pull Request resolved: #1966
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 26, 2025
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8, with `aot_eager` backend.

Note that we need #147873 for
inductor to work.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 2c22e77
Pull Request resolved: #148461

vkuzo commented Mar 26, 2025

Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 26, 2025
@pytorchmergebot

Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8. Thanks to @eellison for the pointer to make inductor work with this.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: pytorch#148461
Approved by: https://github.com/drisspg, https://github.com/eellison
@github-actions github-actions Bot deleted the gh/vkuzo/7/head branch May 2, 2025 02:16

4 participants