Add CUTLASS kernel as choice for _int_mm() Inductor autotuning#119685
Add CUTLASS kernel as choice for _int_mm() Inductor autotuning#119685alexsamardzic wants to merge 10 commits into
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/119685
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (8 Unrelated Failures)As of commit 6a3745b with merge base 5b90074 ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This PR enables generating CUTLASS kernels as candidates for auto-tuning of Example codeimport torch
from torch._inductor import config
_CUTLASS_DIR = ".../pytorch/third_party/cutlass"
max_autotune_gemm_backends = "CUTLASS,Triton"
dynamic = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
op = torch._int_mm
def my_op(a, b):
return op(a, b.T )
a = torch.randint(-10, 10, (512, 2048), dtype=torch.int8).cuda()
b = torch.randint(-10, 10, (1024, 2048), dtype=torch.int8).cuda()
with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 8,
}
):
Y_compiled = torch.compile(my_op, dynamic=dynamic)(a, b)
Y = my_op(a, b)
torch.testing.assert_close(Y_compiled, Y)The alignments checking is yet to be fixed (marked with In my experimenting over a limited set of shapes, a CUTLASS kernel typically gets chosen by autotuning for larger input sizes, while for small input sizes a Triton kernel typically wins. Edit: Note that for int8 inputs, CUTLASS can only handle the case when first operand is in row-major, and second operand in column-major layout. So it should work fine for linear operator, but it's not as general as for floating point inputs. |
|
As far as alignment checking mentioned in previous comment concerned: I may be wrong, but it seems to me these are all set by So my question/suggestion here is that we remove checking alignments, would that be OK? |
…ning" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
|
Looks good in principle, but we would definitely need additional unit tests. You can add a few new test methods to test_max_autotune.py by copying and modifying pre-existing cutlass backend tests in there. Do you expect performance benefits compared to Aten and Triton? |
|
Thanks! I'll add tests, and also some benchmark results for both this and other PR. As mentioned above, in my initial quick benchmarking, generated CUTLASS kernels are indeed the fastest for some input shapes. |
…ning" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
| choices = [] | ||
|
|
||
| if m * n != 0 and use_cutlass_template(layout): | ||
| CUTLASSGemmTemplate.add_cutlass_gemm_choices( |
There was a problem hiding this comment.
Where are the various configs defined?
Below we use int8_mm_configs to see the various Triton kernels. Is there an equivalent for CUTLASS? If so, we likely need to update or unify that.
There was a problem hiding this comment.
For CUTLASS, these are specified within the CUTLASS itself, in the generator code.
While the MMA PTX instructions eventually used to implement given operator should be the same for Triton and CUTLASS, there is number of possible differences: how are tiles transferred between global and shared memory, and then between shared memory and registers, then how is shared memory used so that single thread-block handles multiple tiles, etc. For the same reason, CUTLASS generator suggests different tiles for different data-types combinations etc. So I think there is no point in the unification. (On the other side, while working on this and other PR, I found some small omissions and inconsitencies in CUTLASS generator code, so I created this PR for CUTLASS to fix these - there may be more minor changes coming for CUTLASS generator like this).
There was a problem hiding this comment.
Ok, that makes sense, but we should have a sense for the number of configurations that are being tried.
There was a problem hiding this comment.
For Triton, I believe all of the configs listed are tried during auto-tuning. CUTLASS generator code typically enumerates 10-15 configurations for given operation, but a disadvantage of CUTLASS is that generated code is heavily templated C++, that takes long to compile. For this reason, there is cuda.cutlass_max_profiling_configs parameter introduced into torch._inductor.config, making it possible for user to adjust the number of CUTLASS configurations tried during auto-tuning; the default setting for this parameter is that all CUTLASS configurations are tried too. So this is a trade-off for CUTLASS, admittedly it's still somewhat unsatisfactory in the sense that if cuda.cutlass_max_profiling_configs is set by user, in order to save compilation time, to less than number of configurations that CUTLASS generator lists for given operation, then there will be configurations listed by CUTLASS that won't be tried, and some among them could actually perform the best. (For the record, I set this parameter to 10 for the benchmarking presented below.)
In general, I still think that down the road it would be nice to build a heuristic, probably the best approach would be some kind of machine learning, that could be trained up-front on various GPUs, and that would make it possible for a more educated guess on the good configuration(s) to try/use, for both eager and compiled mode. It should not be hard to build, but would take time and resources to train, so it could be a nice idea for a side project.
There was a problem hiding this comment.
Agreed, that machine learning based approach ( a fast Decision Tree, maybe combined with a Linear Regression
should suffice ) is also something that I suggested and was discussed a bit Pytorch internally. The cuda.cutlass_max_profiling_configs option is more something that should be used to limit computation time for unit tests, but it's not a good idea to use this when aiming for actual performance benefits through the Cutlass backend.
…ning" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
|
I've added a test into
The latest ones are for SM90, as there is no support for SM80 epilogues in the CUTLASS auto-tuning code in PyTorch yet - this is something that I'm intending to work on as a follow-up, thus I'm only adding MM operation test for I also did some benchmarking. Here is my benchmarking script, the set of shapes I used for benchmarking are shapes that I have as "Llama shapes". Benchmarking script (`bench.py`)import sys
import io
from contextlib import redirect_stdout, redirect_stderr
import torch
from torch._inductor import config
int8 = mixed = None
if len(sys.argv) >= 2:
int8 = sys.argv[1].lower() == "int8"
mixed = sys.argv[1].lower() == "mixed"
assert (int8 or mixed)
_CUTLASS_DIR = "./third_party/cutlass"
max_autotune_gemm_backends = "Triton,CUTLASS"
dynamic = False
if mixed:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
shapes = [
# llama shapes
(4096, 11008, 64),
(12288, 4096, 64),
(22016, 4096, 64),
(4096, 4096, 64),
(65536, 4096, 64),
(4096, 11008, 16),
(12288, 4096, 16),
(22016, 4096, 16),
(4096, 4096, 16),
]
if int8:
def fn(a, b):
return torch._int_mm(a, b)
if mixed:
def fn(a, b):
return torch.mm(a, b.to(a.dtype))
torch._dynamo.config.cache_size_limit = len(shapes)
print("shape|kernel|time")
for m, k, n in shapes:
if int8:
a = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda()
if mixed:
a = torch.rand((n, k), dtype=torch.float16).cuda()
b = torch.randint(0, 5, (m, k), dtype=torch.int8).cuda().T
dtype = a.dtype if a.element_size() >= b.element_size() else b.dtype
with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"compile_threads": 10,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 10,
"use_mixed_mm": mixed,
}
):
f = io.StringIO()
with redirect_stdout(None), redirect_stderr(f):
Y_compiled = torch.compile(fn, dynamic=dynamic)(a, b)
#Y = fn(a.to(dtype), b.to(dtype))
#torch.testing.assert_close(Y_compiled, Y)
lines = f.getvalue().splitlines()
triton_line = cutlass_line = ""
for line in lines:
if triton_line and cutlass_line:
break
line = line.strip()
if not triton_line and line.startswith("triton_mm"):
triton_line = line
if not cutlass_line and line.startswith("cuda_cutlass_gemm"):
cutlass_line = line
triton_time = triton_line.split()[1] if triton_line else triton_line
cutlass_time = cutlass_line.split()[1] if cutlass_line else cutlass_line
print(f"{n},{k},{m}|Triton|{triton_time}")
print(f"{n},{k},{m}|CUTLASS|{cutlass_time}")Here are the results (smaller values are better, and values are for the best performing Triton/CUTLASS kernel generated; note that I've put 10 CUTLASS kernels in the mix for auto-tuning): As mentioned above, Triton generated code is on average slightly faster than CUTLASS generated code for smaller shapes. If I increase the batch sizes from 16 to 256 and from 64 to 1024, and keep remaining dimensions the same, then CUTLASS is visibly better than Triton, i.e. a CUTLASS kernel get selected over the Triton kernel for the most of the time: One has to keep in mind that CUTLASS supports only row-major/column-major combination of layouts for For my record, here is an R script that I used for creating plots above: Plotting script (`plot.R`)library(ggplot2)
df <- read.csv(file("stdin"), sep="|")
p <- ggplot(data=df, aes(x=shape, y=time, fill=kernel)) +
geom_bar(stat="identity", position=position_dodge()) +
ggtitle("@torch.compile F16(nxk) @ S8(kxm): Triton vs. CUTLASS") +
xlab("n,k,m") +
ylab("time (ms)")
ggsave("plot.png", width=10, height=6)and I would run benchmarking and plotting in a pipeline: ./bench.py int8 | Rscript plot.R |
|
How did you measure the time spent on triton kernels? |
|
I wonder if the time obtained from |
|
Values depicted on charts are timings reported by Inductor auto-tuning procedure - these timings are actually used to decide which kernel should be selected to execute given operation in |
|
Profiled with for Triton kernel, from for CUTLASS kernel, from I then temporarily removed other configurations from mentioned files, in order to have less clutter in the and these were indeed best the Triton and CUTLASS timings reported by auto-tuning initially. Now I was able to run Here is a screenshot of Auto-tuning apparently consists of running given kernel number of times, first in succession over the same data, and later seemingly on different data, namely this There is not much to see when zooming in either in Triton or CUTLASS auto-tuning segments - just a bunch of successive calls of the corresponding kernel, later interspersed with above mentioned fills. Here is how it looks like in However, one could see from the timeline that Triton kernel execution time is indeed close to 0.37ms, reported as an average for Triton kernels for the auto-tuning process. Here is a copy of the text from the tooltip that Alike for CUTLASS kernel, here is a screenshot of and here is the contents of corresponding tooltip, again confirming that the time reported by auto-tuning is actually the kernel execution time: Thus, my conclusion is that auto-tuning reports kernel execution times only, and that values on the charts above are valid comparison between Triton and CUTLASS kernels, for given set of configurations used for auto-tuning. All of the other caveats mentioned above stay, in particular this doesn't mean that either Triton or CUTLASS is better option for |
|
The kernel activities from cutlass and triton make sense to me. Thanks for investigation. My thought was that in theory triton should be as efficient as cutlass, but in this case it's 30% slower. Not sure if it's an important case for your guys? Maybe @bertmaher will be interested to investigate? |
|
@Jokeren - I suspect there's just not enough configs that the autotuner tires. Looks like CUTLASS tries more / different configs. So, maybe Triton would get the exact same perf (or better) if it was given the same config. |
|
|
||
| if m * n != 0 and use_cutlass_template(layout): | ||
| CUTLASSGemmTemplate.add_cutlass_gemm_choices( | ||
| choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True |
There was a problem hiding this comment.
Why is both fuseable and non_fuseable True? I'm mostly just confused by the names here.
There was a problem hiding this comment.
It's that an ordinary _int_mm() operation is found by parser, so no fusion needed here. On the other side, CUTLASS code generator differentiates between creating fuseable or non-fuseable MM kernels; fusion is not needed in this particular case, so auto-tuner could (and should) try both fuseable and non-fuseable kernels, in order to find the best performing one. (I was confused by names on the first sight too, but they actually make sense.)
|
A Triton configuration matching the CUTLASS one should be: With this configuration, Triton kernel performance is indeed practically the same as for CUTLASS kernel, here is the output of auto-tuning: |
|
@pytorchbot merge |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 4 jobs have failed, first few of them are: linux-binary-manywheel, trunk, linux-binary-libtorch-pre-cxx11, linux-binary-libtorch-cxx11-abi Details for Dev Infra teamRaised by workflow job |
|
@cpuhrsch The merge I started failed immediately. I assume it could be because the diff was reverted in a previous version, or it could be a temporary infra issue. Let's retry a bit later.. |
I don't get what are these failures about - if I get to the CI page, everything looks fine. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 4 jobs have failed, first few of them are: linux-binary-manywheel, trunk, linux-binary-libtorch-pre-cxx11, linux-binary-libtorch-cxx11-abi Details for Dev Infra teamRaised by workflow job |
|
Let me rebase on latest main, in the hope that CI will pass on the next try. |
…ning" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 2 jobs have failed, first few of them are: trunk, linux-binary-manywheel Details for Dev Infra teamRaised by workflow job |
…ning" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…ng (#119986) Pull Request resolved: #119986 Approved by: https://github.com/kadeng ghstack dependencies: #119685





Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang