Skip to content

Add CUTLASS kernel as choice for _int_mm() Inductor autotuning#119685

Closed
alexsamardzic wants to merge 10 commits into
gh/alexsamardzic/24/basefrom
gh/alexsamardzic/24/head
Closed

Add CUTLASS kernel as choice for _int_mm() Inductor autotuning#119685
alexsamardzic wants to merge 10 commits into
gh/alexsamardzic/24/basefrom
gh/alexsamardzic/24/head

Conversation

@pytorch-bot

pytorch-bot Bot commented Feb 12, 2024

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/119685

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit 6a3745b with merge base 5b90074 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@alexsamardzic

alexsamardzic commented Feb 12, 2024

Copy link
Copy Markdown
Collaborator Author

This PR enables generating CUTLASS kernels as candidates for auto-tuning of _int_mm() op (for int8 inputs).

Example code
import torch

from torch._inductor import config

_CUTLASS_DIR = ".../pytorch/third_party/cutlass"
max_autotune_gemm_backends = "CUTLASS,Triton"
dynamic = False

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

op = torch._int_mm
def my_op(a, b):
    return op(a, b.T )


a = torch.randint(-10, 10, (512, 2048), dtype=torch.int8).cuda()
b = torch.randint(-10, 10, (1024, 2048), dtype=torch.int8).cuda()

with config.patch(
    {
        "max_autotune": True,
        "autotune_in_subproc": False,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 8,
    }
):
    Y_compiled = torch.compile(my_op, dynamic=dynamic)(a, b)
    Y = my_op(a, b)
    torch.testing.assert_close(Y_compiled, Y)

The alignments checking is yet to be fixed (marked with FIXME in the code).

In my experimenting over a limited set of shapes, a CUTLASS kernel typically gets chosen by autotuning for larger input sizes, while for small input sizes a Triton kernel typically wins.

Edit: Note that for int8 inputs, CUTLASS can only handle the case when first operand is in row-major, and second operand in column-major layout. So it should work fine for linear operator, but it's not as general as for floating point inputs.

@ipiszy @cpuhrsch

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

As far as alignment checking mentioned in previous comment concerned: I may be wrong, but it seems to me these are all set by cutlass_library, so what would be the point of checking them against some values that we enumerate (in get_alignments() method)? In particular, for cases where accumulator is set by CUTLASS to be wider than one (in mixed dtypes GEMM case) or both (in 4-bit or 8-bit integer GEMM cases) inputs, CUTLASS seems to be choosing operand C alignment different than its usual 128 / cutlass::sizeof_bits<ElementC>::value (see op.C.alignment assignments in python/cutlass_library/generator.py in the CUTLASS source tree), and this is the reason why I had to add 8 in the list, in the second FIXME line in my PR, to make it work.

So my question/suggestion here is that we remove checking alignments, would that be OK?

@alexsamardzic alexsamardzic requested a review from ipiszy February 14, 2024 09:52
…ning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@kadeng

kadeng commented Feb 23, 2024

Copy link
Copy Markdown
Contributor

Looks good in principle, but we would definitely need additional unit tests. You can add a few new test methods to test_max_autotune.py by copying and modifying pre-existing cutlass backend tests in there. Do you expect performance benefits compared to Aten and Triton?

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

Thanks! I'll add tests, and also some benchmark results for both this and other PR. As mentioned above, in my initial quick benchmarking, generated CUTLASS kernels are indeed the fastest for some input shapes.

…ning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
choices = []

if m * n != 0 and use_cutlass_template(layout):
CUTLASSGemmTemplate.add_cutlass_gemm_choices(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are the various configs defined?

Below we use int8_mm_configs to see the various Triton kernels. Is there an equivalent for CUTLASS? If so, we likely need to update or unify that.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For CUTLASS, these are specified within the CUTLASS itself, in the generator code.

While the MMA PTX instructions eventually used to implement given operator should be the same for Triton and CUTLASS, there is number of possible differences: how are tiles transferred between global and shared memory, and then between shared memory and registers, then how is shared memory used so that single thread-block handles multiple tiles, etc. For the same reason, CUTLASS generator suggests different tiles for different data-types combinations etc. So I think there is no point in the unification. (On the other side, while working on this and other PR, I found some small omissions and inconsitencies in CUTLASS generator code, so I created this PR for CUTLASS to fix these - there may be more minor changes coming for CUTLASS generator like this).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that makes sense, but we should have a sense for the number of configurations that are being tried.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Triton, I believe all of the configs listed are tried during auto-tuning. CUTLASS generator code typically enumerates 10-15 configurations for given operation, but a disadvantage of CUTLASS is that generated code is heavily templated C++, that takes long to compile. For this reason, there is cuda.cutlass_max_profiling_configs parameter introduced into torch._inductor.config, making it possible for user to adjust the number of CUTLASS configurations tried during auto-tuning; the default setting for this parameter is that all CUTLASS configurations are tried too. So this is a trade-off for CUTLASS, admittedly it's still somewhat unsatisfactory in the sense that if cuda.cutlass_max_profiling_configs is set by user, in order to save compilation time, to less than number of configurations that CUTLASS generator lists for given operation, then there will be configurations listed by CUTLASS that won't be tried, and some among them could actually perform the best. (For the record, I set this parameter to 10 for the benchmarking presented below.)

In general, I still think that down the road it would be nice to build a heuristic, probably the best approach would be some kind of machine learning, that could be trained up-front on various GPUs, and that would make it possible for a more educated guess on the good configuration(s) to try/use, for both eager and compiled mode. It should not be hard to build, but would take time and resources to train, so it could be a nice idea for a side project.

@kadeng kadeng Mar 11, 2024

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, that machine learning based approach ( a fast Decision Tree, maybe combined with a Linear Regression
should suffice ) is also something that I suggested and was discussed a bit Pytorch internally. The cuda.cutlass_max_profiling_configs option is more something that should be used to limit computation time for unit tests, but it's not a good idea to use this when aiming for actual performance benefits through the Cutlass backend.

…ning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

I've added a test into test_max_autotune.py. There are three kind of CUTLASS-related tests there:

  1. A smoke test.
  2. MM operation tests.
  3. Test involving MM with arbitrary epilogue.

The latest ones are for SM90, as there is no support for SM80 epilogues in the CUTLASS auto-tuning code in PyTorch yet - this is something that I'm intending to work on as a follow-up, thus I'm only adding MM operation test for _int_mm() for now.

I also did some benchmarking. Here is my benchmarking script, the set of shapes I used for benchmarking are shapes that I have as "Llama shapes".

Benchmarking script (`bench.py`)
import sys

import io
from contextlib import redirect_stdout, redirect_stderr

import torch
from torch._inductor import config

int8 = mixed = None
if len(sys.argv) >= 2:
    int8 = sys.argv[1].lower() == "int8"
    mixed = sys.argv[1].lower() == "mixed"
assert (int8 or mixed)

_CUTLASS_DIR = "./third_party/cutlass"
max_autotune_gemm_backends = "Triton,CUTLASS"
dynamic = False

if mixed:
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

shapes = [
    # llama shapes
    (4096, 11008, 64),
    (12288, 4096, 64),
    (22016, 4096, 64),
    (4096, 4096, 64),
    (65536, 4096, 64),
    (4096, 11008, 16),
    (12288, 4096, 16),
    (22016, 4096, 16),
    (4096, 4096, 16),
]

if int8:
    def fn(a, b):
        return torch._int_mm(a, b)
if mixed:
    def fn(a, b):
        return torch.mm(a, b.to(a.dtype))


torch._dynamo.config.cache_size_limit = len(shapes)
print("shape|kernel|time")
for m, k, n in shapes:
    if int8:
        a = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda()
    if mixed:
        a = torch.rand((n, k), dtype=torch.float16).cuda()
    b = torch.randint(0, 5, (m, k), dtype=torch.int8).cuda().T
    dtype = a.dtype if a.element_size() >= b.element_size() else b.dtype

    with config.patch(
        {
            "max_autotune": True,
            "autotune_in_subproc": False,
            "max_autotune_gemm_backends": max_autotune_gemm_backends,
            "compile_threads": 10,
            "cuda.cutlass_dir": _CUTLASS_DIR,
            "cuda.cutlass_max_profiling_configs": 10,
            "use_mixed_mm": mixed,
        }
    ):
        f = io.StringIO()
        with redirect_stdout(None), redirect_stderr(f):
            Y_compiled = torch.compile(fn, dynamic=dynamic)(a, b)

        #Y = fn(a.to(dtype), b.to(dtype))
        #torch.testing.assert_close(Y_compiled, Y)

        lines = f.getvalue().splitlines()
        triton_line = cutlass_line = ""
        for line in lines:
            if triton_line and cutlass_line:
                break
            line = line.strip()
            if not triton_line and line.startswith("triton_mm"):
                triton_line = line
            if not cutlass_line and line.startswith("cuda_cutlass_gemm"):
                cutlass_line = line
        triton_time = triton_line.split()[1] if triton_line else triton_line
        cutlass_time = cutlass_line.split()[1] if cutlass_line else cutlass_line
        print(f"{n},{k},{m}|Triton|{triton_time}")
        print(f"{n},{k},{m}|CUTLASS|{cutlass_time}")

Here are the results (smaller values are better, and values are for the best performing Triton/CUTLASS kernel generated; note that I've put 10 CUTLASS kernels in the mix for auto-tuning):

image

As mentioned above, Triton generated code is on average slightly faster than CUTLASS generated code for smaller shapes. If I increase the batch sizes from 16 to 256 and from 64 to 1024, and keep remaining dimensions the same, then CUTLASS is visibly better than Triton, i.e. a CUTLASS kernel get selected over the Triton kernel for the most of the time:

image

One has to keep in mind that CUTLASS supports only row-major/column-major combination of layouts for _int_mm() (but that's good for linear operator: X @ W.T, if both X and W are contiguous tensors in row-major order, as usual with PyTorch), and that CUTLASS kernels, as it's heavily templated C++ code, takes quite longer than Triton kernels to compile during auto-tuning. Still, as it could be seen above, CUTLASS can provide performance improvement.


For my record, here is an R script that I used for creating plots above:

Plotting script (`plot.R`)
library(ggplot2)

df <- read.csv(file("stdin"), sep="|")

p <- ggplot(data=df, aes(x=shape, y=time, fill=kernel)) +
    geom_bar(stat="identity", position=position_dodge()) +
    ggtitle("@torch.compile F16(nxk) @ S8(kxm): Triton vs. CUTLASS") +
    xlab("n,k,m") +
    ylab("time (ms)")
ggsave("plot.png", width=10, height=6)

and I would run benchmarking and plotting in a pipeline:

./bench.py int8 | Rscript plot.R

@Jokeren

Jokeren commented Mar 5, 2024

Copy link
Copy Markdown
Contributor

How did you measure the time spent on triton kernels?

@Jokeren

Jokeren commented Mar 5, 2024

Copy link
Copy Markdown
Contributor

I wonder if the time obtained from nsys showing the same histogram? Considering that triton has higher runtime cost.

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

Values depicted on charts are timings reported by Inductor auto-tuning procedure - these timings are actually used to decide which kernel should be selected to execute given operation in @torch.compile-d code. The actual code doing benchmarking is in torch/_inductor/autotune_process.py, pieces to start from are TritonBenchmarkRequest and CUDABenchmarkRequest classes, respectively. Note that, as discussed above, Triton and CUTLASS kernels may not be compared over the same set of run-time configurations - but from Inductor point of view, it's important to come up with fastest kernel to perform an operation, and that's it. I'll run kernels in a profiler, and will get back with results here.

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

Profiled with nsys. I've chosen the leftmost shapes on the bottom chart from above for profiling, so n, k, m = 1024, 11008, 4096. I ran my benchmarking script once with these shapes only, and auto-tuning output printed, and then from the output I found the best performing Triton and CUTLASS kernels. Their configurations were:

{"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},

for Triton kernel, from int8_mm_kernel_configs in torch/_inductor/kernel/mm_common.py, and:

TileDescription([128, 128,  64],  5, [2, 2, 1], math_inst, min_cc, max_cc),

for CUTLASS kernel, from third_party/cutlass/python/cutlass_library/generator.py, respectively.

I then temporarily removed other configurations from mentioned files, in order to have less clutter in the nsys results, and then the auto-tuning output was:

AUTOTUNE int_mm(1024x11008, 11008x4096)
  cuda_cutlass_gemm_0 0.2751 ms 100.0%
  triton_mm_0 0.3735 ms 73.7%

and these were indeed best the Triton and CUTLASS timings reported by auto-tuning initially.

Now I was able to run nsys, to get kernels timings only for these two best candidates:

nsys profile python bench.py int8

Here is a screenshot of nsys-ui after nsys results loaded:

image

Auto-tuning apparently consists of running given kernel number of times, first in succession over the same data, and later seemingly on different data, namely this vectorized_elementwise_kernel that appears on the screenshot above for both Triton and CUTLASS auto-tuning is about applying FillFunctor.

There is not much to see when zooming in either in Triton or CUTLASS auto-tuning segments - just a bunch of successive calls of the corresponding kernel, later interspersed with above mentioned fills. Here is how it looks like in nsys-ui for Triton kernel:

image

However, one could see from the timeline that Triton kernel execution time is indeed close to 0.37ms, reported as an average for Triton kernels for the auto-tuning process. Here is a copy of the text from the tooltip that nsys-ui displays when hovering over a kernel, confirming this:

triton_mm_0d1d2d
Begins: 51.5264s
Ends: 51.5268s (+368.706 μs)
grid:  <<<128, 1, 1>>>
block: <<<256, 1, 1>>>
Launch Type: Regular
Static Shared Memory: 0 bytes
Dynamic Shared Memory: 98,304 bytes
Registers Per Thread: 254
Local Memory Per Thread: 0 bytes
Local Memory Total: 205,258,752 bytes
Shared Memory executed: 167,936 bytes
Shared Memory Bank Size: 4 B
Theoretical occupancy: 12.5 %
Launched from thread: 5981
Latency: ←3.948 ms
Correlation ID: 10837
Stream: Default stream 7

Alike for CUTLASS kernel, here is a screenshot of nsys-ui after zooming in:

image

and here is the contents of corresponding tooltip, again confirming that the time reported by auto-tuning is actually the kernel execution time:

Kernel2
Begins: 49.136s
Ends: 49.1363s (+269.954 μs)
grid:  <<<64, 4, 1>>>
block: <<<128, 1, 1>>>
Launch Type: Regular
Static Shared Memory: 0 bytes
Dynamic Shared Memory: 81,920 bytes
Registers Per Thread: 230
Local Memory Per Thread: 0 bytes
Local Memory Total: 205,258,752 bytes
Shared Memory executed: 167,936 bytes
Shared Memory Bank Size: 4 B
Theoretical occupancy: 12.5 %
Launched from thread: 5981
Latency: ←3.549 ms
Correlation ID: 1182
Stream: Default stream 7

Thus, my conclusion is that auto-tuning reports kernel execution times only, and that values on the charts above are valid comparison between Triton and CUTLASS kernels, for given set of configurations used for auto-tuning. All of the other caveats mentioned above stay, in particular this doesn't mean that either Triton or CUTLASS is better option for @torch.compile-d code in general - it only means that CUTLASS kernels may be a meaningful option to add into the mix.

@Jokeren

Jokeren commented Mar 6, 2024

Copy link
Copy Markdown
Contributor

The kernel activities from cutlass and triton make sense to me. Thanks for investigation.

My thought was that in theory triton should be as efficient as cutlass, but in this case it's 30% slower. Not sure if it's an important case for your guys? Maybe @bertmaher will be interested to investigate?

@cpuhrsch

cpuhrsch commented Mar 6, 2024

Copy link
Copy Markdown
Contributor

@Jokeren - I suspect there's just not enough configs that the autotuner tires. Looks like CUTLASS tries more / different configs. So, maybe Triton would get the exact same perf (or better) if it was given the same config.


if m * n != 0 and use_cutlass_template(layout):
CUTLASSGemmTemplate.add_cutlass_gemm_choices(
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is both fuseable and non_fuseable True? I'm mostly just confused by the names here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's that an ordinary _int_mm() operation is found by parser, so no fusion needed here. On the other side, CUTLASS code generator differentiates between creating fuseable or non-fuseable MM kernels; fusion is not needed in this particular case, so auto-tuner could (and should) try both fuseable and non-fuseable kernels, in order to find the best performing one. (I was confused by names on the first sight too, but they actually make sense.)

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

A Triton configuration matching the CUTLASS one should be:

{"config": (128, 128, 64, 5, 4), "cond": torch.version.hip is None},

With this configuration, Triton kernel performance is indeed practically the same as for CUTLASS kernel, here is the output of auto-tuning:

AUTOTUNE int_mm(1024x11008, 11008x4096)
  cuda_cutlass_gemm_0 0.2744 ms 100.0%
  triton_mm_0 0.2796 ms 98.2%

@cpuhrsch

cpuhrsch commented Mar 8, 2024

Copy link
Copy Markdown
Contributor

@pytorchbot merge

@kadeng

kadeng commented Mar 13, 2024

Copy link
Copy Markdown
Contributor

@pytorchbot merge

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge failed

Reason: 4 jobs have failed, first few of them are: linux-binary-manywheel, trunk, linux-binary-libtorch-pre-cxx11, linux-binary-libtorch-cxx11-abi

Details for Dev Infra team Raised by workflow job

@kadeng

kadeng commented Mar 13, 2024

Copy link
Copy Markdown
Contributor

@cpuhrsch The merge I started failed immediately. I assume it could be because the diff was reverted in a previous version, or it could be a temporary infra issue. Let's retry a bit later..

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

The merge I started failed immediately. I assume it could be because the diff was reverted in a previous version, or it could be a temporary infra issue. Let's retry a bit later..

I don't get what are these failures about - if I get to the CI page, everything looks fine.

@lezcano

lezcano commented Mar 13, 2024

Copy link
Copy Markdown
Collaborator

@pytorchbot merge

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge failed

Reason: 4 jobs have failed, first few of them are: linux-binary-manywheel, trunk, linux-binary-libtorch-pre-cxx11, linux-binary-libtorch-cxx11-abi

Details for Dev Infra team Raised by workflow job

@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

Let me rebase on latest main, in the hope that CI will pass on the next try.

…ning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@cpuhrsch

Copy link
Copy Markdown
Contributor

@pytorchbot merge

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk, linux-binary-manywheel

Details for Dev Infra team Raised by workflow job

…ning"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@alexsamardzic

Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Mar 14, 2024
@github-actions github-actions Bot deleted the gh/alexsamardzic/24/head branch April 14, 2024 02:20
@cpuhrsch cpuhrsch mentioned this pull request May 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants