Add CUTLASS kernel as choice for (u)int8/(b)float16 mixed MM autotuning#119986

Closed
alexsamardzic wants to merge 13 commits into gh/alexsamardzic/25/base from gh/alexsamardzic/25/head

Conversation

@pytorch-bot

pytorch-bot Bot commented Feb 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/119986

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit 6ca9970 with merge base 5b90074:

FLAKY - The following job failed, but likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@alexsamardzic

alexsamardzic commented Feb 15, 2024

Collaborator Author

This PR enables generating CUTLASS kernels as candidates for auto-tuning of the mixed mm() op, for cases where one of the inputs is either int8 or uint8 and the other input is either float16 or bfloat16.

Example code
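import torch

from torch._inductor import config

# Path to the CUTLASS checkout bundled with PyTorch (adjust to your tree).
_CUTLASS_DIR = ".../pytorch/third_party/cutlass"
# Restrict GEMM autotuning to CUTLASS so that its kernels are the ones picked.
max_autotune_gemm_backends = "CUTLASS"
dynamic = False

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

op = torch.mm
def my_op(a, b):
    # b.T is a column-major view, so the second mm() operand is column-major.
    bt = b.T
    return op(a, bt.to(a.dtype))


# First operand: float16, row-major. Second operand: int8, transposed in my_op.
a = torch.rand((512, 2048), dtype=torch.float16).cuda()
b = torch.randint(0, 10, (1024, 2048), dtype=torch.int8).cuda()
# The wider of the two dtypes is used for the eager reference computation.
dtype = a.dtype if a.element_size() >= b.element_size() else b.dtype

with config.patch(
    {
        "max_autotune": True,
        "autotune_in_subproc": False,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 8,
        "use_mixed_mm": True,
    }
):
    # Compare the compiled (autotuned) result against the eager reference.
    Y_compiled = torch.compile(my_op, dynamic=dynamic)(a, b)
    Y = my_op(a.to(dtype), b.to(dtype))
    print(Y_compiled[0:5, 0:5])
    print(Y[0:5, 0:5])
    torch.testing.assert_close(Y_compiled, Y)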

Note that, just as mentioned for the previous PR in this stack, CUTLASS can only handle the case where the first operand is in row-major layout and the second operand is in column-major layout.
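
To make the two conditions above concrete (the dtype pairing and the row-major/column-major layout requirement), here is a minimal sketch in plain Python. It is not code from this PR: the helper names are hypothetical, dtypes are represented as plain strings, and each operand is described by a `(dtype, shape, strides)` triple, with strides counted in elements as returned by `tensor.stride()` in PyTorch.

```python
# Illustrative sketch only; helper names are hypothetical, not from the PR.
INT8_DTYPES = {"int8", "uint8"}
HALF_DTYPES = {"float16", "bfloat16"}


def is_mixed_dtype_pair(dtype_a, dtype_b):
    """True if one dtype is (u)int8 and the other is (b)float16."""
    return (dtype_a in INT8_DTYPES and dtype_b in HALF_DTYPES) or (
        dtype_a in HALF_DTYPES and dtype_b in INT8_DTYPES
    )


def layout_2d(shape, strides):
    """Classify a 2-D operand from its shape and element strides."""
    rows, cols = shape
    strides = tuple(strides)
    if strides == (cols, 1):
        return "row-major"
    if strides == (1, rows):
        return "column-major"
    return "other"


def cutlass_mixed_mm_applicable(a, b):
    """a and b are (dtype, shape, strides) triples for the two mm() operands."""
    dtype_a, shape_a, strides_a = a
    dtype_b, shape_b, strides_b = b
    return (
        is_mixed_dtype_pair(dtype_a, dtype_b)
        and shape_a[1] == shape_b[0]  # inner dimensions must match
        and layout_2d(shape_a, strides_a) == "row-major"
        and layout_2d(shape_b, strides_b) == "column-major"
    )


# Operands from the example in this comment: a is a contiguous float16
# (512, 2048) tensor; b.T is an int8 (2048, 1024) view with strides (1, 2048).
a_operand = ("float16", (512, 2048), (2048, 1))
bt_operand = ("int8", (2048, 1024), (1, 2048))
print(cutlass_mixed_mm_applicable(a_operand, bt_operand))  # True
```

This is why the example transposes `b` inside `my_op`: the transposed view of a contiguous tensor has column-major strides, which is the layout the second operand needs here.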

@ipiszy @cpuhrsch

@alexsamardzic alexsamardzic requested a review from ipiszy February 15, 2024 14:22
…nductor autotuning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Feb 15, 2024
@alexsamardzic alexsamardzic changed the title from "Add CUTLASS kernel as choice for (u)int8/(b)float16 mm() Inductor autotuning" to "Add CUTLASS kernel as choice for (u)int8/(b)float16 mixed MM autotuning" Feb 15, 2024
…MM autotuning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]

@kadeng kadeng left a comment

Contributor

This also looks good in principle, but I would like to see some tests added to test_max_autotune.py. Are CUTLASS kernels actually picked during autotuning if you don't force them, i.e. are there cases where they are the fastest?

alexsamardzic added a commit that referenced this pull request Feb 26, 2024
@alexsamardzic

alexsamardzic commented Feb 28, 2024

Collaborator Author

Added a pure MM operation test to test_max_autotune.py, similar to the one mentioned in this comment for _int_mm() operator tuning. Also, here are benchmarking results for CUTLASS vs. Triton generated kernels on "Llama shapes" (the benchmarking script is given in the same comment; it only has to be run with the mixed command-line argument instead of int8):

[image: benchmarking results, CUTLASS vs. Triton generated kernels for "Llama shapes"]

Note that these results actually come from two rounds of benchmarking: Triton only supports the row-major/row-major combination of layouts here, while CUTLASS only supports the row-major/column-major combination. So, while CUTLASS is faster most of the time, the comparison is not exactly apples-to-apples; on the other hand, CUTLASS here provides an auto-tuning option that is not available without it.
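
For readers who want to reproduce this kind of comparison, below is a minimal timing sketch in plain Python. It is not the benchmarking script referenced above; `median_time` and `speedup` are hypothetical helper names, and the two workloads are CPU stand-ins. When timing CUDA kernels, each timed callable would also need to wait for the GPU to finish (for example by calling `torch.cuda.synchronize()` at the end), since kernel launches are asynchronous.

```python
import statistics
import time


def median_time(fn, warmup=3, iters=10):
    """Return the median wall-clock time, in seconds, of one call to fn()."""
    # Warm-up calls absorb one-time costs (compilation, caching, autotuning).
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline_s, candidate_s):
    """How many times faster the candidate is than the baseline."""
    return baseline_s / candidate_s


# Stand-in workloads; in a real comparison these would be the two compiled
# mm() variants, each run with the operand layouts its backend supports.
def baseline():
    return sum(i * i for i in range(20000))


def candidate():
    return sum(i * i for i in range(10000))


t_base = median_time(baseline)
t_cand = median_time(candidate)
print(f"baseline {t_base:.6f}s, candidate {t_cand:.6f}s")
```

Using the median rather than the mean makes the figure less sensitive to occasional slow iterations, and the warm-up calls keep one-time setup costs out of the measurement.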

alexsamardzic added a commit that referenced this pull request Feb 29, 2024
@cpuhrsch cpuhrsch requested a review from kadeng March 9, 2024 04:02
@cpuhrsch

cpuhrsch commented Mar 9, 2024

Contributor

@kadeng - Could you take another look please? Thank you.

Comment thread torch/_inductor/kernel/mm.py
alexsamardzic added a commit that referenced this pull request Mar 11, 2024
alexsamardzic added a commit that referenced this pull request Mar 12, 2024

@kadeng kadeng left a comment

Contributor

Looks good to me, thanks for your contribution!

Before we can merge: there is a conflicting PR, #121489, in the process of being merged, which moves the CUTLASS backend tests into a separate file called test_cutlass_backend.py. I think we should wait until that one lands and then also move the tests from this PR into test_cutlass_backend.py.

@alexsamardzic

Collaborator Author

Sure, no problem.

@alexsamardzic

alexsamardzic commented Mar 13, 2024

Collaborator Author

Rebased on latest main, which now includes #121489; the newly added tests have been moved into test_cutlass_backend.py.

alexsamardzic added a commit that referenced this pull request Mar 13, 2024
alexsamardzic added a commit that referenced this pull request Mar 14, 2024
@alexsamardzic

Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 14, 2024
@pytorchmergebot

Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions Bot deleted the gh/alexsamardzic/25/head branch April 14, 2024 02:20
4 participants