
Conversation

@yaoyaoding
Member

This PR adds the second Blackwell matmul example, which uses TMA.

Example:

@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
class BlackwellMatmul(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        s_a = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        s_b = self.shared_tensor(dtype=float16, shape=[self.block_n, self.block_k])

        # allocate a tensor in tensor memory (tmem)
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # allocate two mbarriers in shared memory, each with an expected arrive count of 1
        tma_barrier, mma_barrier = self.mbarrier.alloc(counts=[1, 1])

        # use a phase variable to track the current parity phase of the barriers
        phase: uint32 = 0

        # make the allocations and barrier initialization visible to all threads before the main loop
        self.sync()

        for offset_k in range(0, k_size, self.block_k):
            with self.single_thread():  # we use a single thread to issue the TMA copy
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a,
                    offsets=[offset_m, offset_k],
                    mbarrier=tma_barrier,
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b,
                    offsets=[offset_n, offset_k],
                    mbarrier=tma_barrier,
                )
                self.mbarrier.arrive(tma_barrier)
                self.mbarrier.wait(tma_barrier, phase=phase)

                # perform tcgen05 mma on two shared tensors
                self.tcgen05.mma(s_a, s_b.transpose(), t_acc)

                # commit the mma operation; its completion will trigger an arrive event on the barrier
                self.tcgen05.commit(mbarrier=mma_barrier)

                # wait for the pending arrival to complete (expected count = 1: the committed mma operation)
                self.mbarrier.wait(mma_barrier, phase=phase)
            self.sync()

            # both barriers completed one phase this iteration, so flip the parity for the next waits
            phase ^= 1

        # load the result from tensor memory into registers
        r_acc = self.tcgen05.load(
            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
        )

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)
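
For reference, the kernel computes C[m, n] = sum_k A[m, k] * B[n, k] (i.e. C = A @ B.T, since both operands are stored K-major) with fp16 inputs, fp32 accumulation in tensor memory, and an fp16 result. Below is a minimal PyTorch reference sketch; the launch lines are hypothetical (the exact calling convention of an autotuned tilus script is not shown in this snippet) and are left commented out.

import torch

m, n, k = 4096, 4096, 4096
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
b = torch.randn(n, k, dtype=torch.float16, device="cuda")
c = torch.empty(m, n, dtype=torch.float16, device="cuda")

# hypothetical launch (calling convention assumed, not part of this PR's diff):
# matmul = BlackwellMatmul()
# matmul(m, n, k, a, b, c)

# reference result: C = A @ B.T with fp32 accumulation, cast back to fp16
c_ref = (a.float() @ b.float().t()).to(torch.float16)
# torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)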

The performance is similar to v0 since the example is implemented in a synchronous way: the TMA copies and the MMA within each iteration are serialized, so memory transfers do not overlap with compute.

      m     n      k   name  latency (ms)       tflops
0  4096  4096   4096  torch      0.155712   882.648432
1  4096  4096   4096  tilus      0.529440   259.593074
2  4096  4096  14336  torch      0.451056  1066.466987
3  4096  4096  14336  tilus      1.818080   264.584801
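
As a sanity check on the tflops column: it is the usual 2 * m * n * k FLOP count divided by the measured latency, e.g. for the first two rows:

# tflops = 2 * m * n * k / latency, reproducing the m = n = k = 4096 rows above
m, n, k = 4096, 4096, 4096
for name, latency_ms in [("torch", 0.155712), ("tilus", 0.529440)]:
    tflops = 2 * m * n * k / (latency_ms * 1e-3) / 1e12
    print(name, round(tflops, 1))  # torch ~882.6, tilus ~259.6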

Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
@yaoyaoding mentioned this pull request Oct 25, 2025
@yaoyaoding merged commit a284935 into main Oct 25, 2025
8 checks passed