
Conversation

yaoyaoding commented Oct 22, 2025

This PR adds an example that uses tcgen05.mma and the primitives necessary to write a matrix multiplication kernel with Tilus.

The example demonstrates a minimal matmul on B200 using tcgen05:

import os

import pandas
import tilus
import torch
from tilus import float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv

# from tilus.extensions.hidet.utils.ncu_utils import ncu_run

if not tilus.target.get_current_target().supports(tilus.target.nvgpu_sm100a):
    # skip this example if the current target does not support nvgpu_sm100a
    exit(0)

tilus.option.cache_dir(os.path.join(os.path.dirname(__file__), "cache"))
tilus.option.debug.dump_ir()

# tilus.target.set_current_target(tilus.target.nvgpu_sm100a)


class BlackwellMatmul(tilus.Script):
    def __init__(self):
        super().__init__()
        self.block_m = 128
        self.block_n = 64
        self.block_k = 16

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        s_a = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        s_b = self.shared_tensor(dtype=float16, shape=[self.block_n, self.block_k])

        # allocate a tensor in tensor memory (tmem)
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # allocate one barrier in shared memory
        mbarriers = self.mbarrier.alloc(counts=[1])

        # use a phase to record the current phase of the barrier
        phase: uint32 = 0

        self.sync()

        for offset_k in range(0, k_size, self.block_k):
            self.copy_async(src=g_a, dst=s_a, offsets=[offset_m, offset_k])
            self.copy_async(src=g_b, dst=s_b, offsets=[offset_n, offset_k])
            self.copy_async_wait_all()
            self.sync()

            # perform tcgen05 mma on two shared tensors
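            # s_b is laid out as [block_n, block_k] in shared memory, so its transpose provides the [block_k, block_n] operand; the result accumulates into t_acc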
            self.tcgen05.mma(s_a, s_b.transpose(), t_acc)

            # commit the mma operation; the completion of the committed operations will trigger an arrive event on the barrier
            self.tcgen05.commit(mbarrier=mbarriers[0])

            # wait for all pending arrivals to finish (here the expected count is 1: the mma operation committed above)
            self.mbarrier.wait(mbarriers[0], phase=phase)
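            # the barrier parity flips each time it completes, so toggle our local phase to track it for the next iteration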
            phase ^= 1

        # load the result from tensor memory to register
        r_acc = self.tcgen05.load(
            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
        )

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)


def main(bench=True):
    matmul = BlackwellMatmul()

    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    rows = []

    for m_size, n_size, k_size in [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
    ]:
        print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}")
        a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda")
        b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda")
        c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

        matmul(m_size, n_size, k_size, a, b, c)
        torch.cuda.synchronize()

        c_ref = a @ b.T

        torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=1e-2)

        # benchmark
        if bench:
            for name, func in [
                ("torch", lambda: a @ b.T),
                ("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c)),
            ]:
                latency = benchmark_func(func, warmup=5, repeat=20)
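                # latency is in milliseconds: FLOPs / (latency * 1e-3) / 1e12 = FLOPs / latency * 1e-9 TFLOP/s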
                tflops = 2 * m_size * n_size * k_size / latency * 1e-9
                rows.append([m_size, n_size, k_size, name, latency, tflops])

    if bench:
        df = pandas.DataFrame(rows, columns=headers)
        print(df)


if __name__ == "__main__":
    main(bench=True)

Measured performance on B200:

      m     n      k   name  latency (ms)       tflops
0  4096  4096   4096  torch      0.152496   901.262663
1  4096  4096   4096  tilus      2.647648    51.909829
2  4096  4096  14336  torch      0.440480  1092.073066
3  4096  4096  14336  tilus      9.190816    52.338807

More performant versions will be added gradually.

To make this happen, this PR also adds some internal pieces:

  • add syntactic sugar, self.mbarrier.alloc(...), to create mbarriers
  • add an init parameter to self.tcgen05.alloc(...) to specify the initial value
  • customize the check_launch_configuration pass to use TVM-FFI's error mechanism
  • fix a bug in tcgen05.mma emission (its issue granularity is a single thread)
  • add a swizzle primitive function to compute swizzled addresses in shared memory, and handle it in the grid_analyzer and veceval sub-modules (see the sketch below)
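
The swizzle primitive itself is not exercised in the example above. As a rough, hypothetical illustration only (a minimal sketch of a generic XOR-based 128-byte swizzle, not Tilus' actual primitive or its signature), the address computation looks roughly like this:

def swizzle_128b(byte_offset: int) -> int:
    """Map a linear shared-memory byte offset to a 128-byte-swizzled offset.

    Generic XOR swizzle sketch: the 16-byte chunk index within each
    128-byte row is XORed with the low 3 bits of the row index, so
    consecutive rows land in different bank groups.
    """
    row = byte_offset // 128            # 128-byte row index
    chunk = (byte_offset // 16) % 8     # 16-byte chunk index within the row
    byte_in_chunk = byte_offset % 16    # byte position inside the chunk
    swizzled_chunk = chunk ^ (row % 8)  # XOR with the low 3 bits of the row
    return row * 128 + swizzled_chunk * 16 + byte_in_chunk


if __name__ == "__main__":
    # the first chunk of each of 8 consecutive rows maps to a different chunk
    for r in range(8):
        print(r, swizzle_128b(r * 128))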

Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
yaoyaoding merged commit 24ce7ee into main Oct 23, 2025
11 checks passed