
enable float32 and float16 in torch._grouped_mm fallback #162059

Closed

vkuzo wants to merge 3 commits into gh/vkuzo/6/base from gh/vkuzo/6/head

Conversation

@vkuzo
Contributor

@vkuzo vkuzo commented Sep 3, 2025

Stack from ghstack (oldest at bottom):

Summary:

Enables `torch.float32` and `torch.float16` options in
`torch._grouped_mm`. Note that the fast path is only enabled if `mat_a`,
`mat_b`, and `out_dtype` are `torch.bfloat16`.

Saving for future PRs:

  1. enabling testing on more platforms
  2. supporting out_dtype != mat_a.dtype
  3. opinfo
  4. better compile support

Test Plan:

```bash
# on A100 and H100
pytest test/test_matmul_cuda.py -s -k test_grouped_gemm -x
# on H100
pytest test/test_matmul_cuda.py -s -k test_scaled_grouped_gemm -x
```

Reviewers:

Subscribers:

Tasks:

Tags:
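The fallback semantics are straightforward: outside the bfloat16 fast path, the op reduces to one matmul per group. Below is a minimal NumPy sketch of the 2D-input / 3D-weight case; the function name and the `offs` convention here are illustrative, not the exact ATen signature.

```python
import numpy as np

def grouped_mm_fallback(mat_a, mat_b, offs):
    # Illustrative sketch: mat_a is (M, K), mat_b is (G, K, N), and offs
    # holds the cumulative row boundaries that split mat_a into G groups
    # along dim 0. Each group is multiplied by its own weight matrix,
    # which is what a loop-of-mms fallback computes when the fast
    # bfloat16 kernel does not apply.
    out = np.empty((mat_a.shape[0], mat_b.shape[2]), dtype=mat_a.dtype)
    start = 0
    for g, end in enumerate(offs):
        out[start:end] = mat_a[start:end] @ mat_b[g]
        start = end
    return out
```

Since each group is an ordinary matmul, this path works in any dtype that plain matmul supports, which is why float32 and float16 can be enabled without new kernels.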

@pytorch-bot

pytorch-bot bot commented Sep 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162059

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8932199 with merge base aed33a8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Sep 3, 2025
ghstack-source-id: fdc346e
Pull Request resolved: #162059
@vkuzo vkuzo requested review from drisspg and ngimel September 3, 2025 13:35
vkuzo added a commit that referenced this pull request Sep 3, 2025
ghstack-source-id: 6893b58
Pull Request resolved: #162059
@vkuzo vkuzo added the topic: not user facing topic category label Sep 3, 2025
@ngimel ngimel added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 3, 2025
vkuzo added a commit that referenced this pull request Sep 4, 2025
ghstack-source-id: 23e9fd6
Pull Request resolved: #162059
@vkuzo vkuzo requested a review from ngimel September 4, 2025 12:20
@eqy eqy added the ciflow/h100 label Sep 4, 2025
@eqy
Collaborator

eqy commented Sep 4, 2025

Should the compute-capability of tests be gated if #161407 is only for sm80+?

@vkuzo
Contributor Author

vkuzo commented Sep 4, 2025

Should the compute-capability of tests be gated if #161407 is only for sm80+?

Currently, the high-precision grouped_gemm tests are gated with `@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")` (added earlier in this stack). The fallback should work on earlier GPUs as well, but I currently only have an A100 to test on. Would be interested in thoughts on whether there are additional GPU cards in CI we can enable these tests for - the fallback should be supported anywhere torch.mm is supported.
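For reference, the gating pattern described above can be sketched as follows; `SM80OrLater` is a stand-in for the flag from `torch.testing._internal.common_cuda`, hard-coded here so the sketch runs anywhere:

```python
import io
import unittest

# Stand-in for torch.testing._internal.common_cuda.SM80OrLater; in CI this
# reflects the detected compute capability of the device under test.
SM80OrLater = False

class TestGroupedGemm(unittest.TestCase):
    @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
    def test_grouped_gemm(self):
        # Would exercise torch._grouped_mm on a real SM80+ device.
        pass

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestGroupedGemm)
result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
# With SM80OrLater == False, the test is reported as skipped, not failed.
```

On pre-SM80 machines the test shows up as skipped rather than failing, which is the behavior the gating is meant to guarantee.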

@vkuzo
Contributor Author

vkuzo commented Sep 4, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…62059)

Summary:

Enables `torch.float32` and `torch.float16` options in
`torch._grouped_mm`. Note that the fast path is only enabled if `mat_a`,
`mat_b`, and `out_dtype` are `torch.bfloat16`.

Saving for future PRs:
1. enabling testing on more platforms
2. supporting out_dtype != mat_a.dtype
3. opinfo
4. better compile support

Test Plan:

```bash
# on A100 and H100
pytest test/test_matmul_cuda.py -s -k test_grouped_gemm -x
# on H100
pytest test/test_matmul_cuda.py -s -k test_scaled_grouped_gemm -x
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#162059
Approved by: https://github.com/ngimel, https://github.com/eqy
ghstack dependencies: pytorch#161407, pytorch#161717
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
@github-actions github-actions bot deleted the gh/vkuzo/6/head branch October 5, 2025 02:17
pytorchmergebot pushed a commit that referenced this pull request Oct 24, 2025
…k was added (#165378)

#162059 means we get unexpected successes now on e.g., SM 12.0

Pull Request resolved: #165378
Approved by: https://github.com/Skylion007