
[sgl-kernel][1/N]Support Expert Specialization Grouped GEMM #11432

Merged
zhyncs merged 4 commits into sgl-project:main from HydraQYH:dev_support_expert_specilization_op
Oct 13, 2025
Conversation

@HydraQYH HydraQYH commented Oct 10, 2025

Motivation

There are two motivations for implementing Expert Specialization Grouped GEMM.

1. CUTLASS MoE Benchmark

When I used the benchmark script to test the performance of FP8 Blockwise Grouped GEMM, I found that performance was poor when the batch size was between 512 and 1024:
(benchmark chart: test_cutlass_moe)
The reason is that, in this regime, FP8 Blockwise Grouped GEMM uses a large TiledMMAShape:

using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;

In this case, the number of tokens each expert needs to process is very small (M is very small), so a large TiledMMAShape results in unnecessary extra computation. A hypothetical smaller configuration for the low-M case is sketched below.
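
For the low-M bucket, a configuration with a smaller M tile wastes far less work. The following is a purely illustrative sketch in the same CuTe notation as above; the names and exact shapes are hypothetical, not the traits actually chosen by this PR:

// Hypothetical low-M configuration; the real PerfConfig traits in this PR may differ.
using LowMMmaTileShape = Shape<_64, _128, _128>;  // smaller M tile wastes less work when M is tiny
using LowMClusterShape = Shape<_1, _1, _1>;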

2. Unbalanced load

Inspired by this PR, expert load can vary by scenario and changes dynamically. Moreover, the distribution is usually uneven: typically, most experts process only a small number of tokens, while a few experts handle a very large number of tokens.

Implementation

The idea behind Expert Specialization is to select an appropriate kernel based on the number of tokens each expert processes. We launch multiple grouped GEMM kernels, each with a different configuration, and each expert is handled by the kernel whose configuration matches its token count.
One detail: a grouped GEMM kernel requires num_groups as a host parameter, which would normally require a MemcpyD2H and incur significant performance overhead. To avoid this, we use another approach, masking the problem sizes on the device, as illustrated in the sketch below.
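
The following is a rough illustration of the masking idea only; the kernel name and signature below are hypothetical, and the actual pre-compute kernel in this PR differs. One thread handles one expert and zeroes out the M dimension of any expert whose M falls outside the target range, so the grouped GEMM launched on that problem-size buffer simply skips it and num_groups can stay fixed at num_experts on the host.

// Illustrative sketch only: hypothetical names, not the kernel shipped in this PR.
// One thread per expert; experts whose M is outside [m_low, m_high) are masked
// by setting M to 0, so the grouped GEMM for this bucket does no work for them.
__global__ void mask_problem_sizes(
    const int* __restrict__ problem_sizes,   // [num_experts, 3] laid out as (M, N, K)
    int* __restrict__ masked_problem_sizes,  // output buffer for one M bucket
    int m_low,
    int m_high) {
  int e = threadIdx.x;
  int m = problem_sizes[3 * e + 0];
  bool in_bucket = (m >= m_low) && (m < m_high);
  masked_problem_sizes[3 * e + 0] = in_bucket ? m : 0;  // mask: zero M outside the bucket
  masked_problem_sizes[3 * e + 1] = problem_sizes[3 * e + 1];
  masked_problem_sizes[3 * e + 2] = problem_sizes[3 * e + 2];
}

Launched once per bucket with <<<1, num_experts, 0, stream>>> (or fused into a single pre-compute kernel, as this PR's groupedGemmPreComputeKernel does), every bucket's grouped GEMM can then be enqueued back-to-back without synchronizing with the host.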

Modifications

Implementation:
sgl-kernel/CMakeLists.txt
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/expert_specilization.py
Benchmark:
sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py
Unit test:
sgl-kernel/tests/test_es_fp8_blockwise_moe.py

Accuracy Tests

pytest -s test_es_fp8_blockwise_moe.py
(screenshot of test results)

Benchmarking and Profiling

I tested it using the benchmark script on the H200 and found that in some scenarios, the performance gain of Expert Specialization was significant:
(benchmark results screenshot)
After that, I modified this call site:

fp8_blockwise_scaled_grouped_mm(

replacing fp8_blockwise_scaled_grouped_mm with es_fp8_blockwise_scaled_grouped_mm, and tested it further.
Before: (screenshot)
After: (screenshot)
Compared to the native fp8_blockwise_scaled_grouped_mm of sgl-kernel, a general performance improvement can be observed.

TODO List

There is still a lot of work to be done next.

  • Fine-tune the M ranges for H20 and H200.
  • Adapt CUDA Graph.
  • Adapt the Python-side code for H20 and H200, since the two have significantly different configurations and applicable scenarios.
  • Adapt to the SM100+ architecture.

Co-Author

I had many helpful discussions with @yuan-luo, @FlamingoPg, and @BBuf; thank you all very much.

Original Project

https://github.com/HydraQYH/expert_specialization_moe


@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @HydraQYH, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a novel 'Expert Specialization' mechanism to optimize the performance of FP8 Blockwise Grouped GEMM operations. By intelligently adapting the kernel selection based on the varying computational demands of individual experts, it effectively mitigates performance bottlenecks observed in scenarios with small batch sizes and uneven expert loads. The changes involve new CUDA kernels, their integration into the existing framework, and comprehensive testing, demonstrating notable speedups.

Highlights

  • Performance Optimization for Grouped GEMM: Introduced 'Expert Specialization Grouped GEMM' to address performance issues in FP8 Blockwise Grouped GEMM, especially when dealing with small batch sizes (512-1024) and unbalanced expert loads.
  • Dynamic Kernel Selection: The core implementation dynamically selects the most appropriate kernel based on the number of tokens processed by each expert, launching multiple grouped GEMM kernels with different TiledMMAShapes.
  • Memory Overhead Reduction: Utilizes a 'Mask Problem Sizes' technique to avoid costly MemcpyD2H operations, which would otherwise be required for Host parameters in Grouped GEMM Kernels.
  • Significant Performance Gains: Benchmarking on H200 shows significant performance improvements in certain scenarios and a general uplift compared to the native fp8_blockwise_scaled_grouped_mm.
  • New Kernel and Integration: Added new CUDA kernel files (es_fp8_blockwise.cu, es_fp8_blockwise_functor.cuh, es_fp8_blockwise_launcher.cuh, es_fp8_blockwise_traits.cuh) and integrated them into the build system (CMakeLists.txt), PyTorch extension (common_extension.cc), and Python bindings (sgl_kernel/__init__.py, expert_specilization.py).
  • Comprehensive Testing and Benchmarking: Includes new benchmark scripts (bench_es_fp8_blockwise_grouped_gemm.py) and unit tests (test_es_fp8_blockwise_moe.py) to validate accuracy and measure performance.

@gemini-code-assist gemini-code-assist Bot left a comment
Code Review

This pull request introduces "Expert Specialization" for grouped GEMM to enhance performance, particularly for unbalanced workloads in Mixture-of-Experts (MoE) models. The core idea is to dynamically select optimized CUDA kernels with different TiledMMAShape configurations based on the problem sizes of each expert, which is a robust strategy. The implementation is comprehensive, adding new C++ kernels, Python bindings, benchmarks, and tests.

My review has identified a critical correctness bug in the kernel dispatch logic for H20 devices and a separate issue that could lead to compilation failures on older GPU architectures. Additionally, I've provided suggestions to improve code maintainability by reducing duplication and to fix minor inconsistencies in documentation and comments. Despite these issues, this is a valuable contribution with clear performance benefits.

Comment on lines +123 to +124
TORCH_CHECK_NOT_IMPLEMENTED(
can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
critical

The TORCH_CHECK_NOT_IMPLEMENTED macro here uses variables can_implement and sm_version which are not defined within this #else block. This will cause a compilation error on systems where CUTLASS_ARCH_MMA_SM90_SUPPORTED is not defined. Please use a static message instead.

  TORCH_CHECK(false, "es_fp8_blockwise_scaled_grouped_mm requires SM90+ architecture support.");

Comment on lines +269 to +280
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_d,
layout_sfa,
layout_sfb,
mm_problem_sizes);
critical

There appears to be a copy-paste error here. For H20 devices (is_h20_device is true), you are launching HighMGemmHx00Traits for mm_problem_sizes (middle-M problems). This should likely be MiddleMGemmH20Traits, which is defined but currently unused. Using the wrong traits can lead to incorrect results or suboptimal performance as the kernel configurations (like MMA tile shape) are mismatched for the intended problem sizes.

    launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmH20Traits>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_d,
        layout_sfa,
        layout_sfb,
        mm_problem_sizes);

@@ -494,6 +494,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"bool silu_activation,"
"int pad_slot_id) -> ()");
m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);

/*
* From csrc/expert_sepcialization
medium

There's a typo in this comment. sepcialization should be specialization.

Suggested change
* From csrc/expert_sepcialization
* From csrc/expert_specialization

* @param scales_b Scaling factors for tensor B, float32 per expert group.
* @param stride_a Stride information for tensor A (int32).
* @param stride_b Stride information for tensor B (int32).
* @param stride_c Stride information for output tensor C (int32).
medium

The documentation for this parameter is @param stride_c, but the actual parameter name in the function signature (line 41) is stride_d. Please update the docstring to use stride_d for consistency.

 * @param stride_d       Stride information for output tensor D (int32).

Comment on lines +67 to +117
if (!is_h20_device) {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
} else if (out_tensors.dtype() == torch::kFloat16) {
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::half_t> of(
static_cast<int*>(expert_offsets.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
static_cast<cutlass::half_t*>(out_tensors.data_ptr()),
static_cast<float*>(a_scales.data_ptr()),
static_cast<float*>(b_scales.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
static_cast<float**>(a_scales_ptrs.data_ptr()),
static_cast<float**>(b_scales_ptrs.data_ptr()),
static_cast<cutlass::half_t**>(out_ptrs.data_ptr()));
if (!is_h20_device) {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
static_cast<int*>(lm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
static_cast<int*>(mm_problem_sizes.data_ptr()));
struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
static_cast<int*>(hm_problem_sizes.data_ptr()));
groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
medium

There is significant code duplication within the if (out_tensors.dtype() == torch::kBFloat16) and else if (out_tensors.dtype() == torch::kFloat16) blocks for handling H20 vs. non-H20 devices. This reduces maintainability and increases the chance of errors.

Consider refactoring this logic into a template function that accepts the PerfConfig types as template parameters. This would eliminate the repeated code blocks and make the logic easier to follow and update.

HydraQYH and others added 2 commits October 11, 2025 10:53
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: PGFLMG <1106310035@qq.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
@HydraQYH HydraQYH force-pushed the dev_support_expert_specilization_op branch from 2ecf486 to ab41d97 Compare October 11, 2025 02:53
@yuan-luo yuan-luo changed the title [1/N]Support Expert Specialization Grouped GEMM [sgl-kernel][1/N]Support Expert Specialization Grouped GEMM Oct 11, 2025
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
// Dispatch
if (out_tensors.dtype() == torch::kBFloat16) {
@BBuf BBuf commented Oct 11, 2025
Maybe we can simplify the replicated code here:

// BFloat16 branch
if (out_tensors.dtype() == torch::kBFloat16) {
    // ... create functors ...
    if (!is_h20_device) {
        // Hx00 config
    } else {
        // H20 config
    }
} else if (out_tensors.dtype() == torch::kFloat16) {
    // ... create functors ...
    if (!is_h20_device) {
        // Hx00 config
    } else {
        // H20 config
    }
}

->

template <typename ElementD>
void dispatch_pre_compute(
    const torch::Tensor& out_tensors,
    const torch::Tensor& a_tensors,
    const torch::Tensor& b_tensors,
    // ... other params
    bool is_h20_device,
    cudaStream_t stream) {
    
    struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, ElementD> of(...);
    
    if (!is_h20_device) {
        dispatch_problem_size_filters<PerfConfigLowMHx00, PerfConfigMiddleMHx00, PerfConfigHighMHx00>(...);
    } else {
        dispatch_problem_size_filters<PerfConfigLowMH20, PerfConfigMiddleMH20, PerfConfigHighMH20>(...);
    }
}
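
For completeness, one possible body for the dispatch_problem_size_filters helper referenced in the sketch above (hypothetical code, not part of this PR; it only reuses the functor and kernel names that already appear in the diff):

// Hypothetical helper: the three PerfConfig types become template parameters,
// so the H20 and Hx00 branches each collapse into a single call.
template <typename LowMCfg, typename MiddleMCfg, typename HighMCfg,
          typename OffsetFunctor, typename SfLayout>
void dispatch_problem_size_filters(
    torch::Tensor& problem_sizes,
    torch::Tensor& lm_problem_sizes,
    torch::Tensor& mm_problem_sizes,
    torch::Tensor& hm_problem_sizes,
    OffsetFunctor& of,
    SfLayout& sf_layout,
    int num_experts,
    cudaStream_t stream) {
  // Build one problem-size filter per M bucket.
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<LowMCfg> lm_psf(
      static_cast<int*>(lm_problem_sizes.data_ptr()));
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<MiddleMCfg> mm_psf(
      static_cast<int*>(mm_problem_sizes.data_ptr()));
  Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<HighMCfg> hm_psf(
      static_cast<int*>(hm_problem_sizes.data_ptr()));
  // Single pre-compute launch, one thread per expert.
  groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}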

using HighMGemmHx00Traits =
ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;

const std::string H20_device_type_str("NVIDIA H20");
Better to create a utility function in sglang/sgl-kernel/include/utils.h.
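
A minimal sketch of what such a utility could look like (the name is_device_name and its placement are only suggestions, not an existing sgl-kernel API):

#include <cuda_runtime.h>
#include <string>

// Hypothetical helper for sgl-kernel/include/utils.h.
// Compares the current (or given) device's reported name against an expected
// string, e.g. is_device_name("NVIDIA H20").
inline bool is_device_name(const std::string& expected, int device_id = -1) {
  if (device_id < 0) {
    cudaGetDevice(&device_id);
  }
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return std::string(prop.name) == expected;
}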

@@ -0,0 +1,27 @@
import torch
File name should be expert_specialization.py?

rmsnorm,
silu_and_mul,
)
from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
Suggested change
from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm

@zhyncs zhyncs left a comment
Overall LGTM, we just need to address @BBuf's review comment

@zhyncs zhyncs merged commit 9a30914 into sgl-project:main Oct 13, 2025
83 of 89 checks passed
lpc0220 pushed a commit to lpc0220/sglang that referenced this pull request Oct 29, 2025