enable float32 and float16 in torch._grouped_mm fallback#162059
vkuzo wants to merge 3 commits into gh/vkuzo/6/base

Conversation
Summary:

Enables `torch.float32` and `torch.float16` options in `torch._grouped_mm`. Note that the fast path is only enabled if `mat_a`, `mat_b`, and `out_dtype` are `torch.bfloat16`.

Saving for future PRs:

1. enabling testing on more platforms
2. supporting out_dtype != mat_a.dtype
3. opinfo
4. better compile support

Test Plan:

```bash
pytest test/test_matmul_cuda.py -s -k test_grouped_gemm -x
```

[ghstack-poisoned]
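The dispatch condition described above can be sketched as a small predicate. This is an illustrative reconstruction, not the actual dispatch code; the function name `uses_fast_path` and the string representation of dtypes are assumptions for the sketch:

```python
def uses_fast_path(mat_a_dtype: str, mat_b_dtype: str, out_dtype: str) -> bool:
    # Hypothetical sketch of the condition from the summary: the fast path
    # is taken only when mat_a, mat_b, and out_dtype are all bfloat16;
    # the newly enabled float32/float16 inputs go through the fallback.
    return mat_a_dtype == mat_b_dtype == out_dtype == "bfloat16"
```

For example, `uses_fast_path("float32", "float32", "float32")` would route to the fallback, while all-bfloat16 inputs would take the fast path.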
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162059

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 8932199 with merge base aed33a8.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Should the compute-capability of tests be gated if #161407 is only for sm80+?
Currently the high precision grouped_gemm tests are gated with
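One plausible shape for such a gate is a predicate on the device's compute capability. This is illustrative only; the helper name and the SM80 threshold are assumptions based on the discussion above, not the actual test decorator used in the suite:

```python
def high_precision_grouped_gemm_supported(capability):
    """Decide whether to run the high-precision grouped_gemm tests.

    `capability` is a (major, minor) compute-capability tuple of the
    form returned by torch.cuda.get_device_capability(), e.g. (8, 0)
    for A100. Assumes the sm80+ requirement mentioned for #161407.
    """
    major, _minor = capability
    return major >= 8
```

Note that a `major >= 8` gate also admits newer architectures such as SM 12.0, which is consistent with the later report of unexpected successes there.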
@pytorchbot merge |
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
enable float32 and float16 in torch._grouped_mm fallback (#162059)

Summary:

Enables `torch.float32` and `torch.float16` options in `torch._grouped_mm`. Note that the fast path is only enabled if `mat_a`, `mat_b`, and `out_dtype` are `torch.bfloat16`.

Saving for future PRs:

1. enabling testing on more platforms
2. supporting out_dtype != mat_a.dtype
3. opinfo
4. better compile support

Test Plan:

```bash
# on A100 and H100
pytest test/test_matmul_cuda.py -s -k test_grouped_gemm -x
# on H100
pytest test/test_matmul_cuda.py -s -k test_scaled_grouped_gemm -x
```

Pull Request resolved: pytorch#162059
Approved by: https://github.com/ngimel, https://github.com/eqy
ghstack dependencies: pytorch#161407, pytorch#161717
…k was added (#165378)

#162059 means we get unexpected successes now on e.g., SM 12.0

Pull Request resolved: #165378
Approved by: https://github.com/Skylion007
Stack from ghstack (oldest at bottom):
- enable float32 and float16 in torch._grouped_mm fallback #162059
- _grouped_mm fallback to composite explicit autograd #161717