Update auto-tuning support for _scaled_grouped_mm #150944

alexsamardzic wants to merge 39 commits into gh/alexsamardzic/1/base from ...
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150944
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit 016b438 with merge base f34ab16.
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Validation script:

from enum import Enum
from itertools import product

import torch

f_ref = torch._scaled_grouped_mm
f = torch.compile(
    f_ref,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)


class MMType(Enum):
    MM_2D_2D = 1
    MM_2D_3D = 2
    MM_3D_2D = 3
    MM_3D_3D = 4


def generate_data(
    mm_type, group_size, M, N, K, device, dtype_AB, dtype_scale, dtype_offset, strided
):
    if mm_type == MMType.MM_2D_2D:
        A = torch.randn(M, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        B = torch.randn(N, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(K, group_size * K + 1, K, device=device, dtype=dtype_offset)
    if mm_type == MMType.MM_2D_3D:
        A = torch.randn(M * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(
            group_size, N * (1 + strided), device=device, dtype=dtype_scale
        )[:, :N]
        offs = torch.arange(M, group_size * M + 1, M, device=device, dtype=dtype_offset)
    if mm_type == MMType.MM_3D_2D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(N * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        A_scale = torch.rand(
            group_size, M * (1 + strided), device=device, dtype=dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(N, group_size * N + 1, N, device=device, dtype=dtype_offset)
    if mm_type == MMType.MM_3D_3D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size, M * (1 + strided), device=device).to(
            dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size, N * (1 + strided), device=device).to(
            dtype_scale
        )[:, :N]
        offs = None
    if offs is not None:
        # Perturb group boundaries so group sizes are uneven.  Check the larger
        # threshold first: testing `>= 32` before `>= 64` would make the second
        # branch unreachable, since any value >= 64 is also >= 32.
        if offs[0] >= 64:
            offs[0] -= 16
            offs[1] += 16
            offs[2] -= 32
        elif offs[0] >= 32:
            offs[0] -= 16
            offs[2] += 16
    return A, B, A_scale, B_scale, offs


def validate():
    def validate_helper(
        mm_type,
        group_size,
        M,
        N,
        K,
        device,
        dtype_AB,
        dtype_scale,
        dtype_offset,
        dtype_C,
        use_fast_accum,
        strided,
        atol,
        rtol,
    ):
        torch._dynamo.reset()
        A, B, A_scale, B_scale, offs = generate_data(
            mm_type,
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_offset,
            strided,
        )
        C_ref = f_ref(
            A,
            B.transpose(-2, -1),
            A_scale,
            B_scale,
            offs,
            out_dtype=dtype_C,
            use_fast_accum=use_fast_accum,
        )
        C = f(
            A,
            B.transpose(-2, -1),
            A_scale,
            B_scale,
            offs,
            out_dtype=dtype_C,
            use_fast_accum=use_fast_accum,
        )
        assert torch.allclose(C, C_ref, atol=atol, rtol=rtol)

    device = "cuda"
    group_size = 4
    M_range = [2**i for i in range(4, 6)]
    N_range = [2**i for i in range(5, 8)]
    K_range = [2**i for i in range(6, 9)]
    dtype_AB = torch.float8_e4m3fn
    dtype_scale = torch.float32
    dtype_offset = torch.int32
    dtype_C = torch.bfloat16
    use_fast_accum_range = [False, True]
    strided_range = [False, True]
    atol = 1e-2
    rtol = 1e-2
    for mm_type, M, N, K, use_fast_accum, strided in product(
        MMType, M_range, N_range, K_range, use_fast_accum_range, strided_range
    ):
        validate_helper(
            mm_type,
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_offset,
            dtype_C,
            use_fast_accum,
            strided,
            atol,
            rtol,
        )


validate()

(Note: to validate the non-TMA load variant, change ...)

Todo: handle use_fast_accum case like CUTLASS does it...
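One detail worth spelling out from the script above: every `offs` tensor built by `generate_data` is cumulative, so consecutive entries delimit one group along the dimension being partitioned. A minimal, torch-free sketch of that convention (the `group_slices` helper is illustrative, not part of the PR):

```python
def group_slices(offs):
    """Convert cumulative group offsets into (start, end) index pairs.

    With group_size=4 and K=64, the script builds offs = [64, 128, 192, 256];
    group g then covers indices [offs[g-1], offs[g]) of the shared dimension.
    """
    slices = []
    start = 0
    for end in offs:
        slices.append((start, end))
        start = end
    return slices


# Mirrors torch.arange(K, group_size * K + 1, K) with K=64, group_size=4.
offs = list(range(64, 4 * 64 + 1, 64))
print(group_slices(offs))  # [(0, 64), (64, 128), (128, 192), (192, 256)]
```

The boundary perturbation in `generate_data` (shifting some entries by 16) produces uneven group sizes while keeping the slices contiguous, which is exactly the case where K-groups stop being multiples of the block size.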
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K

ghstack-source-id: e6016b7
Pull Request resolved: #150944

ghstack-source-id: 63a6271
Pull Request resolved: #150944
@davidberard98: I've made an update. I was not able to immediately utilize ...
Pushed some minor changes, tested with Triton 3.2.0, 3.3.1, and latest main; works fine in all cases. Merging this if the CI passes. @davidberard98: If it gets merged, no need to change anything in torch/utils/_triton.py.

torch/utils/_triton.py (Outdated)
from triton.language import (  # noqa: F401
    _experimental_make_tensor_descriptor,
This change causes has_triton_tma_device() to always return False with Triton 3.2, because Triton 3.2 doesn't have _experimental_make_tensor_descriptor.
To test this, you can build Triton from source on the release/3.2.x branch. AFAIK we don't actually have any tests right now that fail due to this change, but I think it's still something we care about (we should probably add a test at some point...)
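The failure mode described here comes down to how the capability check probes for optional symbols: a plain `from ... import ...` makes the whole check fail on any Triton that lacks one symbol. A stdlib-only sketch of the more robust attribute-probing pattern (a generic helper, not PyTorch's actual has_triton_tma_device implementation; `math.prod` stands in for the Triton symbol):

```python
import importlib


def has_api(module_name, attr):
    """Return True only if module_name imports cleanly and exposes attr."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


# With Triton installed, the analogous probe would be
#   has_api("triton.language", "_experimental_make_tensor_descriptor")
# which is False on Triton 3.2 — so a check that imports the symbol
# unconditionally disables TMA there, while attribute probing degrades
# gracefully.
print(has_api("math", "prod"))         # True on Python >= 3.8
print(has_api("math", "no_such_fn"))   # False
print(has_api("no_such_module", "x"))  # False
```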
Reverted - this change is not needed any more in this PR.
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 0d0b772
Pull Request resolved: #150944
davidberard98 left a comment:

Thanks for the fixes!
_AMD_CONFIGS = [
    Config(
        {
            "BLOCK_M": block_size_m,
            "BLOCK_N": block_size_n,
            "BLOCK_K": block_size_k,
            "waves_per_eu": waves_per_cu,
curious, why did you remove the AMD configs?
The grouped MM ATen operators are written in CUTLASS; they won't work on AMD, so there is nothing to auto-tune, and it's not possible to keep these configs tested. The configs are a relic of the Triton grouped MM FBGEMM kernel that this kernel started from.
Edit: So nothing in the current updates specifically required removing them; it's just that I noticed now that it's better without them, for now.
@alexsamardzic just wanted to double-check with you: are you planning to add striding support for the mm / flex_attention kernels? In this PR, your approach is to skip TMA if we only have the 3.2 TMA APIs (experimental_device_tensormap_create2d), because experimental_device_tensormap_create2d doesn't support striding. I think we're okay with skipping TMA for Triton 3.2 for scaled_grouped_mm, but not for the other templates (persistent matmul & flex_attention). For persistent matmul & flex_attention, we'd probably need to either (a) wait until we can remove our 3.2 support (I'd estimate on the order of 1-2 months) or (b) add more complex handling to enable striding only when a new enough Triton version is available.
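Option (b) above would amount to a version gate in the template-selection logic. A minimal sketch of what such a gate could look like (the 3.3 cutoff and both function names are assumptions for illustration, not the actual PyTorch or Triton logic):

```python
def parse_version(version_string):
    """Parse a 'major.minor[.patch][+local]' string into a comparable tuple."""
    parts = version_string.split("+")[0].split(".")
    return tuple(int(p) for p in parts[:3] if p.isdigit())


def supports_strided_tma(triton_version):
    # Hypothetical cutoff: assume the newer tensor-descriptor API (and with
    # it, strided TMA) is available from Triton 3.3 onward; older versions
    # fall back to non-TMA loads.
    return parse_version(triton_version) >= (3, 3)


print(supports_strided_tma("3.2.0"))        # False -> use non-TMA loads
print(supports_strided_tma("3.3.1"))        # True
print(supports_strided_tma("3.4.0+git1a2"))  # True, local suffix ignored
```

Tuple comparison handles the version ordering correctly without string tricks, which is why this shape of gate is common for optional-backend features.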
I have no intention of touching other kernels. For grouped MM, it's just that the 3.2 TMA APIs are inadequate, and fortunately the non-TMA version is already there to support 3.2.
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…py templates"

Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs). For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of Triton that don't have the new API). For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)

[ghstack-poisoned]
#155723)

Pull Request resolved: #155723
Approved by: https://github.com/NikhilAPatel
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov