
[2/2] Use moe_sum_reduce cuda kernel #10654

Merged

yuan-luo merged 5 commits into sgl-project:main from antgroup:use_moe_sum_reduce on Oct 28, 2025

Conversation

@yuan-luo (Collaborator) commented Sep 19, 2025

Motivation

This PR uses the moe_sum_reduce CUDA kernel implemented in #10321.

GSM8K result:

➜  sglang_dev2 git:(use_moe_sum_reduce) ✗ python3 -m sglang.launch_server --model Qwen/Qwen3-30B-A3B --tp-size 8 --port 30000 --mem-fraction-static 0.85 --disable-radix-cache

➜  sglang_dev2 git:(use_moe_sum_reduce) ✗ python3 benchmark/gsm8k/bench_sglang.py
PR:
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:19<00:00, 10.18it/s]
Accuracy: 0.945
Invalid: 0.000
Latency: 19.802 s
Output throughput: 1134.071 token/s

Main:
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:22<00:00,  8.96it/s]
Accuracy: 0.940
Invalid: 0.000
Latency: 23.300 s
Output throughput: 1013.138 token/s

Tried TP4 for Qwen3 MoE; the accuracy and throughput look more reasonable.

PR:
➜  sglang git:(main) ✗ python3 benchmark/gsm8k/bench_sglang.py
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:24<00:00,  8.05it/s]
Accuracy: 0.935
Invalid: 0.000
Latency: 25.064 s
Output throughput: 918.339 token/s

Main:
➜  sglang git:(main) ✗ python3 benchmark/gsm8k/bench_sglang.py
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:27<00:00,  7.31it/s]
Accuracy: 0.940
Invalid: 0.000
Latency: 27.659 s
Output throughput: 834.190 token/s

Modifications
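The actual diff is not reproduced on this page. As a rough sketch of the change described in this thread (tensor names and the previous branching are assumptions; only the sgl_kernel import path and the (input, output, routed_scaling_factor) signature are taken from the CI log and the verification script further down), the per-token-count dispatch between torch.compile and Triton is replaced by a single call to the CUDA kernel:

import torch
from sgl_kernel import moe_sum_reduce  # import path as shown in the CI log below

# Illustrative shapes only; the real call site lives in
# python/sglang/srt/layers/moe/moe_runner/triton.py.
num_tokens, topk, hidden_dim = 16, 8, 4096
routed_scaling_factor = 1.0

intermediate = torch.randn(num_tokens, topk, hidden_dim, device="cuda", dtype=torch.bfloat16)
output = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

# Before: small batches used a torch.compile'd sum+scale, larger batches a
# Triton kernel. After: the reduction over the topk dimension is always this
# single CUDA kernel call.
moe_sum_reduce(intermediate, output, routed_scaling_factor)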

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist (Contributor)

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the efficiency of Mixture-of-Experts (MoE) operations by integrating a dedicated moe_sum_reduce CUDA kernel. The change streamlines the MoE summation process, replacing previous conditional implementations with a unified, high-performance kernel. This optimization leads to notable improvements in processing speed and throughput, as evidenced by the provided benchmark results.

Highlights

  • Performance Optimization: This pull request integrates a new moe_sum_reduce CUDA kernel, which significantly improves the performance of Mixture-of-Experts (MoE) computations. Benchmarks show a reduction in latency from 23.300s to 19.802s and an increase in output throughput from 1013.138 token/s to 1134.071 token/s on the GSM8K dataset.
  • Code Simplification: The existing conditional logic for moe_sum_reduce (switching between torch.compile for small tokens and Triton for larger ones) has been replaced with a single, direct call to the new moe_sum_reduce CUDA kernel, simplifying the codebase.
  • CUDA Kernel Integration: The moe_sum_reduce function is now imported from the sgl_kernel module, specifically leveraging a CUDA-optimized implementation for efficient execution on NVIDIA GPUs.

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request replaces the conditional logic for MoE sum reduction, which previously switched between Triton and torch.compile implementations, with a single, more efficient custom CUDA kernel. This change simplifies the codebase in both fused_moe.py and triton.py and, according to the provided benchmarks, improves performance. My review includes suggestions to remove some redundant code for better clarity and style.

Comment thread python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py Outdated
Comment thread python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py Outdated
Comment thread python/sglang/srt/layers/moe/moe_runner/triton.py Outdated
@BBuf (Collaborator) commented Sep 19, 2025

cc @ch-wan

@BBuf (Collaborator) left a comment

LGTM!

@yuan-luo (Collaborator, Author)

The CI error is because sgl-kernel has not yet been upgraded, so the new moe_sum_reduce kernel cannot be recognized. Waiting for the new kernel release.

  File "/public_sglang_ci/runner-l1e-gpu-4567/_work/sglang/sglang/python/sglang/srt/layers/moe/moe_runner/triton.py", line 40, in <module>
    from sgl_kernel import gelu_and_mul, moe_sum_reduce, silu_and_mul
ImportError: cannot import name 'moe_sum_reduce' from 'sgl_kernel' (/usr/local/lib/python3.10/dist-packages/sgl_kernel/__init__.py)
ERROR

@yuan-luo force-pushed the use_moe_sum_reduce branch 2 times, most recently from f2f5751 to cb30bf3, on September 20, 2025 01:32
@yuan-luo (Collaborator, Author) commented Sep 29, 2025

The CI failed because the tensor data type includes np.float64, but the CUDA kernel doesn't support float64 for the moment. A separate PR will add float64 support.

[2025-09-29 02:36:08] INFO:     127.0.0.1:35010 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-29 02:36:08] INFO:     127.0.0.1:35022 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-29 02:36:08] INFO:     127.0.0.1:35044 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-29 02:36:08] INFO:     127.0.0.1:35068 - "POST /v1/chat/completions HTTP/1.1" 200 OK
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:30<00:00,  2.12it/s]
Writing report to /tmp/mmlu_deepseek-ai_DeepSeek-Coder-V2-Lite-Instruct.html
{'other': np.float64(0.0), 'other:std': np.float64(0.0), 'score:std': np.float64(0.0), 'stem': np.float64(0.0), 'stem:std': np.float64(0.0), 'humanities': np.float64(0.0), 'humanities:std': np.float64(0.0), 'social_sciences': np.float64(0.0), 'social_sciences:std': np.float64(0.0), 'score': np.float64(0.0)}
Writing results to /tmp/mmlu_deepseek-ai_DeepSeek-Coder-V2-Lite-Instruct.json
Total latency: 30.266 s
Score: 0.000
E
======================================================================
ERROR: test_mgsm_en (__main__.TestHierarchicalMLA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/utils.py", line 2259, in retry
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/sglang/test/test_utils.py", line 1437, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
AssertionError: np.float64(0.0) not greater than 0.8

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sglang/test/test_utils.py", line 1436, in _callTestMethod
    retry(
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/utils.py", line 2262, in retry
    raise Exception(f"retry() exceed maximum number of retries.")
Exception: retry() exceed maximum number of retries.

@yuan-luo changed the title from "[2/2] Use moe_sum_reduce cuda kernel" to "[WIP][2/2] Use moe_sum_reduce cuda kernel" on Oct 8, 2025
@yuan-luo force-pushed the use_moe_sum_reduce branch 2 times, most recently from fe2e3db to 7f55d35, on October 8, 2025 05:09
@yuan-luo force-pushed the use_moe_sum_reduce branch from 7f55d35 to 3f26045 on October 8, 2025 06:01
@yuan-luo (Collaborator, Author) commented Oct 8, 2025

The CI failed because the tensor data type includes np.float64, but the CUDA kernel doesn't support float64 for the moment. A separate PR will add float64 support.


With the CUDA kernel now supporting float64, this issue has been fixed.

[2025-10-07 22:32:23] Decode batch. #running-req: 2, #token: 3489, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:32:36] Decode batch. #running-req: 2, #token: 3569, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:32:49] Decode batch. #running-req: 2, #token: 3649, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:33:01] Decode batch. #running-req: 2, #token: 3729, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:33:14] Decode batch. #running-req: 2, #token: 3809, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:33:27] Decode batch. #running-req: 2, #token: 3889, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:33:40] Decode batch. #running-req: 2, #token: 3969, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:33:53] Decode batch. #running-req: 2, #token: 4049, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:34:06] Decode batch. #running-req: 2, #token: 4129, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:34:19] Decode batch. #running-req: 2, #token: 4209, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:34:31] Decode batch. #running-req: 2, #token: 4289, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.23, #queue-req: 0,
[2025-10-07 22:34:37] INFO:     127.0.0.1:57532 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-10-07 22:34:37] INFO:     127.0.0.1:57650 - "POST /v1/chat/completions HTTP/1.1" 200 OK
100%|██████████| 250/250 [11:01<00:00,  2.65s/it]
.
----------------------------------------------------------------------
Ran 1 test in 743.456s

OK
Writing report to /tmp/mgsm_en_neuralmagic_DeepSeek-Coder-V2-Lite-Instruct-FP8.html
{'en': np.float64(0.864), 'en:std': np.float64(0.3427885645700568), 'group_latin': np.float64(0.864), 'group_latin:std': np.float64(0.3427885645700568), 'score:std': np.float64(0.3427885645700568), 'score': np.float64(0.864)}
Writing results to /tmp/mgsm_en_neuralmagic_DeepSeek-Coder-V2-Lite-Instruct-FP8.json
Total latency: 661.457 s
Score: 0.864

@yuan-luo (Collaborator, Author) commented Oct 9, 2025

The CUDA kernel still has an accuracy issue which makes some CI checks unhappy. I'll follow up.

@yuan-luo changed the title from "[WIP][2/2] Use moe_sum_reduce cuda kernel" to "[2/2] Use moe_sum_reduce cuda kernel" on Oct 27, 2025
@yuan-luo (Collaborator, Author)

The CUDA kernel still has an accuracy issue which makes some CI checks unhappy. I'll follow up.

The CUDA kernel's accuracy is actually better than the Triton kernel's. Moving forward with this PR.

@huangtingwei9988 (Collaborator)

The CUDA kernel still has an accuracy issue which makes some CI checks unhappy. I'll follow up.

The CUDA kernel's accuracy is actually better than the Triton kernel's. Moving forward with this PR.

Yes! According to my test, the CUDA kernel's accuracy is more precise than the Triton kernel's:

import os

import torch
import triton
import triton.language as tl
from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
from triton.testing import do_bench

@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)

    mask_token = offs_token < token_num
    mask_dim = offs_dim < hidden_dim

    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]

    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
        tile = tl.load(
            base_ptrs + i * input_stride_1,
            mask=mask_token[:, None] & mask_dim[None, :],
            other=0.0,
        )
        accumulator += tile.to(tl.float32)
    accumulator *= routed_scaling_factor

    # -------- Write back --------
    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
    tl.store(
        store_ptrs,
        accumulator.to(input_ptr.dtype.element_ty),
        mask=mask_token[:, None] & mask_dim[None, :],
    )


# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 16

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out



def verify_correctness(num_tokens=1024, dtype=torch.bfloat16):
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_cuda = torch.empty_like(out_baseline)
    moe_sum_reduce_cuda(x, out_cuda, scaling_factor)

    # The Triton kernel is exercised for every dtype, including float64.
    triton_skipped = False
    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce_triton(x, out_triton, scaling_factor)

    if dtype == torch.float64:
        atol, rtol = 1e-12, 1e-12
    elif dtype == torch.float32:
        atol, rtol = 1e-6, 1e-6
    else:  # bfloat16 / float16
        atol, rtol = 1e-2, 1e-2

    ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol)
    ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol)
    ok_triton = torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol)

    if ok_compiled and ok_triton and ok_cuda:
        msg = "✅ All implementations match"
        if triton_skipped:
            msg += " (Triton skipped for float64)"
        print(msg)
    else:
        print(ok_compiled)
        print(ok_triton)
        print(ok_cuda)
        print("❌ Implementations differ")
        print(
            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
        print(
            f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}"
        )
        print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")


if __name__ == "__main__":

    print("Running correctness verification for float64...")
    verify_correctness(dtype=torch.float64)

Result:

Running correctness verification for float64...
True
False
True
❌ Implementations differ
Baseline vs Compiled: 1.3322676295501878e-15
Baseline vs Triton: 8.264526956125451e-07
Baseline vs Cuda: 1.7763568394002505e-15
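A likely explanation for the gap (my reading of the script above, not stated explicitly in the thread): the Triton kernel accumulates in float32 (accumulator = tl.zeros(..., dtype=tl.float32)) even for float64 inputs, whereas the baseline and, apparently, the CUDA kernel retain full precision. A minimal sketch of that accumulator effect, independent of any kernel:

import torch

# Summing float64 inputs through a float32 intermediate loses ~1e-7 of
# precision, the same order as the "Baseline vs Triton" error reported above.
x = torch.randn(9, 4096, dtype=torch.float64)                 # topk=9, hidden_dim=4096, one token
ref = x.sum(dim=0)                                            # float64 accumulation (baseline)
via_fp32 = x.to(torch.float32).sum(dim=0).to(torch.float64)   # float32 path, like the Triton accumulator
print((ref - via_fp32).abs().max().item())                    # on the order of 1e-7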

@yuan-luo yuan-luo merged commit 813bd6f into sgl-project:main Oct 28, 2025
97 of 107 checks passed
@yuan-luo yuan-luo deleted the use_moe_sum_reduce branch November 2, 2025 12:12