Conversation

yaoyaoding (Member) commented on Nov 26, 2025

Add example v5 that:

  1. uses cluster-launch-control (CLC) instructions to support persistent thread blocks (see the control-flow sketch after this list), and
  2. double-buffers the tensor-memory accumulator so that the MMA and the write-back can be pipelined.
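
A minimal control-flow sketch of the persistent-block pattern (plain Python, not the tilus API; all names here are illustrative): each CTA processes the block it was launched with, then keeps asking the scheduler for a new block until the CLC response is invalid.

from collections import deque

def run_persistent_worker(pending_blocks, process_block):
    """Process the launched block, then keep draining CLC-assigned blocks."""
    block = pending_blocks.popleft()          # the block this CTA was launched with
    while True:
        process_block(block)                  # one tile of the GEMM (load / mma / epilogue)
        if pending_blocks:                    # clc.try_cancel + query_response in the kernel
            block = pending_blocks.popleft()  # a new (blockIdx.x, blockIdx.y) to work on
        else:
            break                             # invalid response: leave the persistent loop

# Example: a single worker draining a 2x3 grid of tiles.
blocks = deque((bx, by) for bx in range(2) for by in range(3))
run_persistent_worker(blocks, lambda b: print("processing tile", b))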

Kernel:

"""
Scheduler:
- shared by all workers.
- LoadWorker fetch the next block
- All workers query the next block

Pipelines:
- LoadPipeline: load A and B from global memory to shared memory
- MmaPipeline: compute MMA from shared memory to tensor memory

Workers:
- LoadWorker (warp 1): producer of LoadPipeline, consumer of BlockPipeline
- MmaWorker (warp 2): consumer of LoadPipeline, producer of MmaPipeline, consumer of BlockPipeline
- EpilogueWorker (warp 4-7): consumer of MmaPipeline, consumer of BlockPipeline
- warp 0 and 3 are idle
"""


class Params(tilus.Class):
    def __init__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        block_m: int,
        block_n: int,
        block_k: int,
        g_a: GlobalTensor,
        g_b: GlobalTensor,
        g_c: GlobalTensor,
    ):
        self.m_size: int32 = m_size
        self.n_size: int = n_size
        self.k_size: int = k_size
        self.block_m: int = block_m
        self.block_n: int = block_n
        self.block_k: int = block_k
        self.g_a: GlobalTensor = g_a
        self.g_b: GlobalTensor = g_b
        self.g_c: GlobalTensor = g_c


class Scheduler(tilus.Class):
    def __init__(self):
        self.barrier = self.mbarrier.alloc(count=1)
        self.phase: uint32 = self.mbarrier.consumer_initial_phase
        self.s_response = self.shared_tensor(dtype=int32, shape=[4])

    def fetch_next(self):
        self.clc.try_cancel(self.s_response, self.barrier, multicast=False)

    def query_next(self) -> tuple[Expr, Dim3]:
        """Utility function can be used by consumers. Need to be executed within consumer's thread group."""
        # wait for the response
        self.mbarrier.wait(self.barrier, phase=self.phase)
        self.phase = self.phase ^ uint32(1)

        # decode the response
        is_valid, blockIdx = self.clc.query_response(self.s_response)
        return is_valid, blockIdx
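
# (Illustrative model, plain Python, not part of the kernel.)
# query_next follows the standard mbarrier phase-parity protocol: the consumer
# remembers which phase parity it is waiting on, blocks until that phase has
# completed, then flips its parity for the next round (self.phase ^ uint32(1)).
class ToyMbarrier:
    def __init__(self, count: int):
        self.count = count                    # arrivals needed to complete one phase
        self.pending = count
        self.completed_phases = 0

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:                 # phase completes, barrier re-arms
            self.completed_phases += 1
            self.pending = self.count

    def test_wait(self, phase_parity: int) -> bool:
        # True once the phase the waiter is waiting on has completed.
        return (self.completed_phases % 2) != phase_parity

bar, consumer_phase = ToyMbarrier(count=1), 0  # consumer_initial_phase
assert not bar.test_wait(consumer_phase)       # response has not arrived yet
bar.arrive()                                   # e.g. the CLC response lands
assert bar.test_wait(consumer_phase)           # wait succeeds
consumer_phase ^= 1                            # flip parity for the next query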


class LoadPipeline(tilus.Pipeline):
    def __init__(self, num_stages: int, params: Params):
        super().__init__(
            num_stages=num_stages,
            producer_arrive_count=2,  # two tma loads
            consumer_arrive_count=1,  # one commit in MmaWorker
        )
        self.s_a = self.shared_tensor(
            dtype=float16, shape=[num_stages, params.block_m, params.block_k]
        )
        self.s_b = self.shared_tensor(
            dtype=float16, shape=[num_stages, params.block_n, params.block_k]
        )


class MmaPipeline(tilus.Pipeline):
    def __init__(self, num_stages: int, params: Params):
        super().__init__(
            num_stages=num_stages,
            producer_arrive_count=1,  # one commit in MmaWorker
            consumer_arrive_count=128,  # epilogue has 128 threads
        )
        self.t_acc: TMemoryTensor = self.tcgen05.alloc(
            float32, shape=[num_stages, params.block_m, params.block_n], init=0.0
        )

    def finalize(self):
        self.sync()
        self.tcgen05.dealloc(self.t_acc)


class LoadWorker(tilus.Class):
    def __init__(self, params: Params, load_pipe: LoadPipeline, scheduler: Scheduler):
        self.params: Params = params
        self.load_pipe: LoadPipeline = load_pipe
        self.scheduler: Scheduler = scheduler

    def async_run(self):
        params, load_pipe, scheduler = self.params, self.load_pipe, self.scheduler
        s_a, s_b = load_pipe.s_a, load_pipe.s_b
        num_stages: int = load_pipe.num_stages
        offset_m: int32 = self.blockIdx.x * params.block_m
        offset_n: int32 = self.blockIdx.y * params.block_n
        with self.thread_group(thread_begin=32, num_threads=32):
            while True:
                for offset_k in self.range(
                    0, params.k_size, params.block_k, unroll=num_stages
                ):
                    load_pipe.producer_acquire()
                    with self.single_thread():
                        self.tma.global_to_shared(
                            src=params.g_a,
                            dst=s_a[load_pipe.producer_stage],
                            offsets=[offset_m, offset_k],
                            mbarrier=load_pipe.producer_release_barrier(),
                        )
                        self.tma.global_to_shared(
                            src=params.g_b,
                            dst=s_b[load_pipe.producer_stage],
                            offsets=[offset_n, offset_k],
                            mbarrier=load_pipe.producer_release_barrier(),
                        )
                    load_pipe.producer_advance()

                scheduler.fetch_next()
                is_valid, blockIdx = scheduler.query_next()
                if is_valid:
                    offset_m = blockIdx.x * params.block_m
                    offset_n = blockIdx.y * params.block_n
                else:
                    break


class MmaWorker(tilus.Class):
    def __init__(
        self,
        params: Params,
        load_pipe: LoadPipeline,
        mma_pipe: MmaPipeline,
        scheduler: Scheduler,
    ):
        self.params: Params = params
        self.load_pipe: LoadPipeline = load_pipe
        self.mma_pipe: MmaPipeline = mma_pipe
        self.scheduler: Scheduler = scheduler

    def async_run(self):
        params, load_pipe, mma_pipe, scheduler = (
            self.params,
            self.load_pipe,
            self.mma_pipe,
            self.scheduler,
        )
        with self.thread_group(thread_begin=64, num_threads=32):
            while True:
                mma_pipe.producer_acquire()
                for _ in self.range(
                    0, params.k_size, params.block_k, unroll=load_pipe.num_stages
                ):
                    load_pipe.consumer_acquire()
                    with self.single_thread():
                        self.tcgen05.mma(
                            load_pipe.s_a[load_pipe.consumer_stage],
                            load_pipe.s_b[load_pipe.consumer_stage].transpose(),
                            mma_pipe.t_acc[mma_pipe.producer_stage],
                        )
                        self.tcgen05.commit(mbarrier=load_pipe.consumer_release_barrier())
                    load_pipe.consumer_advance()
                with self.single_thread():
                    self.tcgen05.commit(mbarrier=mma_pipe.producer_release_barrier())
                mma_pipe.producer_advance()

                # check if there is a new block to process
                is_valid, blockIdx = scheduler.query_next()
                if not is_valid:
                    break


class EpilogueWorker(tilus.Class):
    def __init__(self, params: Params, mma_pipe: MmaPipeline, scheduler: Scheduler):
        super().__init__()
        self.params: Params = params
        self.mma_pipe: MmaPipeline = mma_pipe
        self.scheduler: Scheduler = scheduler

    def async_run_stg(self):
        params, mma_pipe, scheduler = self.params, self.mma_pipe, self.scheduler
        with self.thread_group(thread_begin=128, num_threads=128):
            offset_m: int32 = self.blockIdx.x * params.block_m
            offset_n: int32 = self.blockIdx.y * params.block_n
            while True:
                mma_pipe.consumer_acquire()
                t_acc = mma_pipe.t_acc[mma_pipe.consumer_stage]

                # load the accumulator from tensor memory to registers, then store to global memory
                r_acc = self.tcgen05.load(t_acc)
                self.tcgen05.wait_load()
                self.store_global(
                    params.g_c, r_acc.to(float16), offsets=[offset_m, offset_n]
                )

                # reset tmem to 0.0 for next accumulation
                self.tcgen05.store(
                    t_acc,
                    src=self.register_tensor(
                        dtype=float32, shape=[params.block_m, params.block_n], init=0.0
                    ),
                )
                self.tcgen05.wait_store()
                self.sync()

                self.mbarrier.arrive(mma_pipe.consumer_release_barrier())
                mma_pipe.consumer_advance()

                is_valid, blockIdx = scheduler.query_next()
                if is_valid:
                    offset_m = blockIdx.x * params.block_m
                    offset_n = blockIdx.y * params.block_n
                else:
                    break


@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("load_stages", [2, 3, 4])
class BlackwellMatmulV5(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int, load_stages: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.load_stages = load_stages

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 8

        params = Params(
            m_size=m_size,
            n_size=n_size,
            k_size=k_size,
            block_m=self.block_m,
            block_n=self.block_n,
            block_k=self.block_k,
            g_a=self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]),
            g_b=self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]),
            g_c=self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]),
        )

        scheduler = Scheduler()
        load_pipe = LoadPipeline(num_stages=self.load_stages, params=params)
        mma_pipe = MmaPipeline(num_stages=1, params=params)

        load_worker = LoadWorker(params, load_pipe=load_pipe, scheduler=scheduler)
        mma_worker = MmaWorker(params, load_pipe, mma_pipe, scheduler)
        epilogue_worker = EpilogueWorker(params, mma_pipe, scheduler)

        load_worker.async_run()
        mma_worker.async_run()
        epilogue_worker.async_run_stg()

        self.sync()
        mma_pipe.finalize()

Performance on B200:

       m      n      k   name  latency (ms)       tflops
0   4096   4096   4096  torch      0.094208  1458.888315
1   4096   4096   4096  tilus      0.113088  1215.327469
2   4096   4096  14336  torch      0.309248  1555.503468
3   4096   4096  14336  tilus      0.364560  1319.498372
4   8192   8192   8192  torch      0.677952  1621.813405
5   8192   8192   8192  tilus      0.829872  1324.917157
6  10240  10240  10240  torch      1.422320  1509.845629
7  10240  10240  10240  tilus      1.679408  1278.714660
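
The tflops column matches the usual 2*m*n*k FLOP count divided by the measured latency; a quick check of the first tilus row (assuming that is how the benchmark computes it):

# Sanity check of the tflops column, assuming tflops = 2*m*n*k / latency.
m = n = k = 4096
latency_ms = 0.113088                         # tilus, first 4096x4096x4096 row
tflops = 2 * m * n * k / (latency_ms * 1e-3) / 1e12
print(f"{tflops:.2f} TFLOPS")                 # ~1215.33, matching the table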

copy-pr-bot bot commented Nov 26, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


yaoyaoding marked this pull request as ready for review on December 1, 2025.
yaoyaoding merged commit 02c9db4 into main on December 1, 2025; 10 checks passed.
yaoyaoding deleted the blackwell-gemm-v5 branch on December 1, 2025.
yaoyaoding mentioned this pull request on December 4, 2025.