Add CUTLASS kernel as choice for (u)int8/(b)float16 mixed MM autotuning #119986
alexsamardzic wants to merge 13 commits
This PR enables generating CUTLASS kernels as candidates for auto-tuning of mixed (u)int8/(b)float16 MM.

Example code:

    import torch
    from torch._inductor import config

    _CUTLASS_DIR = ".../pytorch/third_party/cutlass"
    max_autotune_gemm_backends = "CUTLASS"
    dynamic = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    op = torch.mm

    def my_op(a, b):
        # Transposing the row-major b yields a column-major second operand.
        bt = b.T
        return op(a, bt.to(a.dtype))

    a = torch.rand((512, 2048), dtype=torch.float16).cuda()
    b = torch.randint(0, 10, (1024, 2048), dtype=torch.int8).cuda()
    # Reference dtype: that of the operand with the larger element size.
    dtype = a.dtype if a.element_size() >= b.element_size() else b.dtype

    with config.patch(
        {
            "max_autotune": True,
            "autotune_in_subproc": False,
            "max_autotune_gemm_backends": max_autotune_gemm_backends,
            "cuda.cutlass_dir": _CUTLASS_DIR,
            "cuda.cutlass_max_profiling_configs": 8,
            "use_mixed_mm": True,
        }
    ):
        Y_compiled = torch.compile(my_op, dynamic=dynamic)(a, b)

    # Eager reference computed with both operands in the same dtype.
    Y = my_op(a.to(dtype), b.to(dtype))
    print(Y_compiled[0:5, 0:5])
    print(Y[0:5, 0:5])
    torch.testing.assert_close(Y_compiled, Y)

Note that, as mentioned for the previous PR in this stack, CUTLASS can only handle the case where the first operand is in row-major layout and the second operand is in column-major layout.
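The layout constraint can be illustrated without a GPU. The following is a minimal sketch in plain Python, using the shapes from the example above; the helper names (`row_major_strides`, `transpose`, `layout`) are hypothetical and not part of PyTorch or Inductor. It shows why taking `b.T` of a contiguous `(1024, 2048)` tensor produces a column-major `(2048, 1024)` second operand.

```python
def row_major_strides(shape):
    """Element strides of a contiguous (row-major) 2-D array."""
    rows, cols = shape
    return (cols, 1)


def transpose(shape, strides):
    """A transpose is a view: it swaps the shape and the strides, no copy."""
    return (shape[1], shape[0]), (strides[1], strides[0])


def layout(shape, strides):
    """Classify a 2-D layout from its shape and strides."""
    rows, cols = shape
    if strides == (cols, 1):
        return "row-major"
    if strides == (1, rows):
        return "column-major"
    return "other"


# First operand: a is contiguous with shape (512, 2048).
a_shape = (512, 2048)
a_strides = row_major_strides(a_shape)
a_layout = layout(a_shape, a_strides)

# Second operand: b is contiguous with shape (1024, 2048); my_op uses b.T.
b_shape = (1024, 2048)
b_strides = row_major_strides(b_shape)
bt_shape, bt_strides = transpose(b_shape, b_strides)
bt_layout = layout(bt_shape, bt_strides)

print(a_layout, a_strides)    # row-major (2048, 1)
print(bt_layout, bt_strides)  # column-major (1, 2048)
```

In PyTorch, `Tensor.stride()` reports the same kind of tuple, so the same check can be made on real tensors before deciding whether an input fits the layout combination described above.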
kadeng left a comment

Also looks good in principle, but I would like to see some tests added to test_max_autotune.py. Are CUTLASS kernels actually picked during autotuning if you don't force them, e.g. are there cases where they are fastest?
Added a pure MM operation test. Note that the results here actually come from two rounds of benchmarking: Triton only supports the row-major/row-major combination of layouts here, while CUTLASS only supports the row-major/column-major combination. So, while CUTLASS is faster most of the time, the comparison is not exactly apples-to-apples; on the other hand, CUTLASS provides an auto-tuning option here that is not available without it.
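To make the "two rounds" point concrete, here is a small sketch that encodes only the layout constraints stated in this comment. The function `eligible_backends` and the backend labels are hypothetical illustrations, not Inductor APIs.

```python
def eligible_backends(a_layout, b_layout):
    """Return which backends can handle a layout combination.

    Hypothetical helper: it encodes only the constraints described in
    this thread for mixed MM, not Inductor's actual selection logic.
    """
    backends = []
    if (a_layout, b_layout) == ("row-major", "row-major"):
        backends.append("TRITON")
    if (a_layout, b_layout) == ("row-major", "column-major"):
        backends.append("CUTLASS")
    return backends


# Round 1: both operands row-major, so only Triton can be benchmarked.
round_1 = eligible_backends("row-major", "row-major")
# Round 2: second operand column-major, so only CUTLASS can be benchmarked.
round_2 = eligible_backends("row-major", "column-major")

print(round_1)  # ['TRITON']
print(round_2)  # ['CUTLASS']
```

Under these assumptions no single layout combination is eligible for both backends, which is why the timings come from separate runs on differently laid-out inputs.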
@kadeng - Could you take another look please? Thank you.
kadeng left a comment

Looks good to me, thanks for your contribution!
Before we can merge: there's a conflicting PR, #121489, in the process of being merged that moves the CUTLASS backend tests into a separate file called test_cutlass_backend.py. I think we should wait until that one lands and then also move the tests from this PR into test_cutlass_backend.py.
Sure, no problem.
Rebased on latest main, which now includes #121489; the newly added tests were moved as requested.
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).

Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang