
Update auto-tuning support for _scaled_grouped_mm#150944

Closed
alexsamardzic wants to merge 39 commits into gh/alexsamardzic/1/base from gh/alexsamardzic/1/head

Conversation

@alexsamardzic
Collaborator

@alexsamardzic alexsamardzic commented Apr 9, 2025

Stack from ghstack (oldest at bottom):

  1. Enable strided inputs
  2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
  3. Fix non-TMA load variant
  4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
  5. Fix cases when group size along K dimension is not multiple of block size along K
  6. Update meta registration
  7. Update synthetic offsets creation
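To make the input combinations in item 2 concrete, here is an illustrative plain-NumPy sketch of the "2d/3d" grouped-MM contract, where `offs` holds the cumulative end row of each group. The shapes and `offs` semantics are inferred from the validation script below; this is not the scaled fp8 kernel itself, and the function name is hypothetical.

```python
import numpy as np

def grouped_mm_2d_3d(A, B, offs):
    # A: (M_total, K) stack of variable-size row groups.
    # B: (G, K, N) one weight matrix per group.
    # offs[g]: cumulative end row of group g within A.
    M_total, K = A.shape
    G, K_b, N = B.shape
    assert K == K_b and len(offs) == G
    out = np.empty((M_total, N), dtype=A.dtype)
    start = 0
    for g, end in enumerate(offs):
        out[start:end] = A[start:end] @ B[g]
        start = end
    return out

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))
B = rng.standard_normal((2, 4, 3))
offs = [2, 6]  # group 0 is rows 0..2 of A, group 1 is rows 2..6
C = grouped_mm_2d_3d(A, B, offs)
```

The "3d/2d" and "3d/3d" cases follow the same pattern with the grouping on the other (or both) operands.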

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Apr 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150944

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 016b438 with merge base f34ab16:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@alexsamardzic
Collaborator Author

alexsamardzic commented Apr 9, 2025

Validation script
from enum import Enum
from itertools import product

import torch


f_ref = torch._scaled_grouped_mm
f = torch.compile(
    f_ref,
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)


class MMType(Enum):
    MM_2D_2D = 1
    MM_2D_3D = 2
    MM_3D_2D = 3
    MM_3D_3D = 4


def generate_data(
    mm_type, group_size, M, N, K, device, dtype_AB, dtype_scale, dtype_offset, strided
):
    if mm_type == MMType.MM_2D_2D:
        A = torch.randn(M, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        B = torch.randn(N, K * (group_size + strided), device=device).to(dtype_AB)[
            :, : K * group_size
        ]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(K, group_size * K + 1, K, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_2D_3D:
        A = torch.randn(M * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size * M, device=device, dtype=dtype_scale)
        B_scale = torch.rand(
            group_size, N * (1 + strided), device=device, dtype=dtype_scale
        )[:, :N]
        offs = torch.arange(M, group_size * M + 1, M, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_3D_2D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(N * group_size, K * (1 + strided), device=device).to(dtype_AB)[
            :, :K
        ]
        A_scale = torch.rand(
            group_size, M * (1 + strided), device=device, dtype=dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size * N, device=device, dtype=dtype_scale)
        offs = torch.arange(N, group_size * N + 1, N, device=device, dtype=dtype_offset)

    if mm_type == MMType.MM_3D_3D:
        A = torch.randn(
            group_size * (1 + strided), M, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        B = torch.randn(
            group_size * (1 + strided), N, K * (1 + strided), device=device
        ).to(dtype_AB)[:: (1 + strided), :, :K]
        A_scale = torch.rand(group_size, M * (1 + strided), device=device).to(
            dtype_scale
        )[:, :M]
        B_scale = torch.rand(group_size, N * (1 + strided), device=device).to(
            dtype_scale
        )[:, :N]
        offs = None

    if offs is not None:
        # Perturb a few group boundaries so the synthetic groups are not all
        # the same size; check the larger threshold first so that both
        # perturbation patterns are reachable.
        if offs[0] >= 64:
            offs[0] -= 16
            offs[1] += 16
            offs[2] -= 32
        elif offs[0] >= 32:
            offs[0] -= 16
            offs[2] += 16

    return A, B, A_scale, B_scale, offs


def validate():
    def validate_helper(
        mm_type,
        group_size,
        M,
        N,
        K,
        device,
        dtype_AB,
        dtype_scale,
        dtype_offset,
        dtype_C,
        use_fast_accum,
        strided,
        atol,
        rtol,
    ):
        torch._dynamo.reset()

        A, B, A_scale, B_scale, offs = generate_data(
            mm_type,
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_offset,
            strided,
        )

        C_ref = f_ref(
            A,
            B.transpose(-2, -1),
            A_scale,
            B_scale,
            offs,
            out_dtype=dtype_C,
            use_fast_accum=use_fast_accum,
        )
        C = f(
            A,
            B.transpose(-2, -1),
            A_scale,
            B_scale,
            offs,
            out_dtype=dtype_C,
            use_fast_accum=use_fast_accum,
        )
        assert torch.allclose(C, C_ref, atol=atol, rtol=rtol)

    device = "cuda"
    group_size = 4
    M_range = [2**i for i in range(4, 6)]
    N_range = [2**i for i in range(5, 8)]
    K_range = [2**i for i in range(6, 9)]
    dtype_AB = torch.float8_e4m3fn
    dtype_scale = torch.float32
    dtype_offset = torch.int32
    dtype_C = torch.bfloat16
    use_fast_accum_range = [False, True]
    strided_range = [False, True]
    atol = 1e-2
    rtol = 1e-2
    for mm_type, M, N, K, use_fast_accum, strided in product(
        MMType, M_range, N_range, K_range, use_fast_accum_range, strided_range
    ):
        validate_helper(
            mm_type,
            group_size,
            M,
            N,
            K,
            device,
            dtype_AB,
            dtype_scale,
            dtype_offset,
            dtype_C,
            use_fast_accum,
            strided,
            atol,
            rtol,
        )


validate()

(Note: to validate non-TMA load variant, change "USE_TMA_LOAD": True in mm_scaled_grouped.py to False.)


Todo: handle the use_fast_accum case the way CUTLASS does it...

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 9, 2025
1. Enable strided inputs
2. Implement "3d/3d" and "3d/2d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: e775b83
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 10, 2025
1. Enable strided inputs
2. Implement "3d/3d" and "3d/2d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 4be59f3
Pull Request resolved: #150944
[ghstack-poisoned]
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 19, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 5847dae
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 19, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant

ghstack-source-id: 2b0248e
Pull Request resolved: #150944
[ghstack-poisoned]
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 21, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor

ghstack-source-id: 7a64328
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 21, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K

ghstack-source-id: e6016b7
Pull Request resolved: #150944
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 22, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K

ghstack-source-id: 63a6271
Pull Request resolved: #150944
@alexsamardzic
Collaborator Author

@davidberard98: I've made an update. I was not able to immediately use the make_tensor_descriptor wrapper from triton_helpers.py, because the reason for switching away from tl.extra.cuda.experimental_device_tensormap_create2d in this PR was precisely to use striding, for better performance and simpler code (note also that this PR deals with 3D tensors too). However, I've fixed it for Triton 3.2 along the same lines as the new wrapper: the code checks for the tl attributes _experimental_make_tensor_descriptor and make_tensor_descriptor, and if neither is available, it falls back to the non-TMA load. I hope this is OK for now; I'll keep an eye on the helpers and will update once arbitrary striding and 3D tensor support are provided there (I'm willing to add support for these myself too, but at the moment merging this PR, and the follow-up PRs in the stack, is the higher priority).
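The attribute check described above can be sketched roughly as follows. The helper name and the explicit module parameter are illustrative, not the actual code in this PR; the fake modules stand in for different Triton versions so the sketch runs without Triton installed.

```python
import types

def tma_descriptor_api(tl):
    # Return a tensor-descriptor constructor if this triton.language module
    # exposes one (stable name first, then the pre-3.4 experimental name);
    # return None to signal that the non-TMA load path should be used.
    for name in ("make_tensor_descriptor", "_experimental_make_tensor_descriptor"):
        fn = getattr(tl, name, None)
        if fn is not None:
            return fn
    return None

# Fake triton.language modules standing in for different Triton versions.
tl_new = types.SimpleNamespace(make_tensor_descriptor=lambda *a, **k: "desc")
tl_exp = types.SimpleNamespace(_experimental_make_tensor_descriptor=lambda *a, **k: "desc")
tl_old = types.SimpleNamespace()  # neither API: fall back to non-TMA load
```
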

[ghstack-poisoned]
@alexsamardzic
Collaborator Author

Pushed some minor changes; tested with Triton 3.2.0, 3.3.1, and latest main, and it works fine in all cases. I'll merge this if the CI passes.

@davidberard98: If it gets merged, there's no need to change anything in scaled_grouped_mm.py when you push your reverted changes again.

Comment on lines +53 to +54
from triton.language import ( # noqa: F401
_experimental_make_tensor_descriptor,
Contributor

@davidberard98 davidberard98 Jun 11, 2025


This change causes has_triton_tma_device() to always return False with Triton 3.2, because Triton 3.2 doesn't have _experimental_make_tensor_descriptor.

To test this, you can build triton from source on the release/3.2.x branch. AFAIK we don't actually have any tests right now that fail due to this change, but I think it's still something we care about (we should probably add a test at some point...)

Collaborator Author


Reverted - this change is not needed any more in this PR.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Jun 11, 2025
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Update meta registration
7. Updated synthetic offsets creation

ghstack-source-id: 0d0b772
Pull Request resolved: #150944
[ghstack-poisoned]
Contributor

@davidberard98 davidberard98 left a comment


Thanks for the fixes!

Comment on lines -62 to -68
_AMD_CONFIGS = [
Config(
{
"BLOCK_M": block_size_m,
"BLOCK_N": block_size_n,
"BLOCK_K": block_size_k,
"waves_per_eu": waves_per_cu,
Contributor


curious, why did you remove the AMD configs?

Collaborator Author

@alexsamardzic alexsamardzic Jun 11, 2025


The grouped MM ATen operators are written in CUTLASS; they won't work on AMD, so there is nothing to auto-tune, and it's not possible to keep these configs tested. The configs are a relic of the Triton grouped MM FBGEMM kernel that this kernel started from.

Edit: So there was nothing specific to the current updates that required removing them; it's just that I noticed now that it's better without them, for now.

@davidberard98
Contributor

@alexsamardzic just wanted to double check with you: are you planning to add striding support for the mm / flex_attention kernels?

In this PR, your approach is to skip TMA if we only have the 3.2 TMA APIs (experimental_device_tensormap_create2d), because experimental_device_tensormap_create2d doesn't support striding.

I think we're okay with skipping TMA for Triton 3.2 for scaled_grouped_mm, but not for the other templates (persistent matmul & flex_attention). For persistent matmul & flex_attention, we'd probably need to either (a) wait until we can remove our 3.2 support (I'd estimate on the order of 1-2 months) or (b) add more complex handling to enable striding only when a new enough Triton version is available.
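Option (b) above would amount to a version gate along these lines. The helper is hypothetical, and the (3, 3) threshold is an assumption about when the newer tensor-descriptor API became available; only the general shape of the check is the point here.

```python
def triton_supports_strided_tma(version, min_version=(3, 3)):
    # Parse version strings like "3.2.0" or "3.4.0+git123" and compare the
    # (major, minor) pair against the first version assumed to ship the
    # newer tensor-descriptor API. Unparseable strings gate conservatively.
    parts = version.split("+")[0].split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return False
    return (major, minor) >= min_version
```

A template could consult such a check once at lowering time and only emit strided TMA descriptors when it returns True, falling back to the 3.2-compatible path otherwise.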

@alexsamardzic
Collaborator Author

@alexsamardzic just wanted to double check with you: are you planning to add striding support for the mm / flex_attention kernels?

...

I have no intention to touch other kernels. For grouped MM, it's just that the 3.2 TMA APIs are inadequate, and fortunately the non-TMA variant is already there to support 3.2.

@alexsamardzic
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

davidberard98 added a commit that referenced this pull request Jun 11, 2025
…py templates"


Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API).

For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Jun 12, 2025
#155723)

Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API).

For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)
Pull Request resolved: #155723
Approved by: https://github.com/NikhilAPatel
