[sgl-kernel][1/N] Support Expert Specialization Grouped GEMM #11432
zhyncs merged 4 commits into sgl-project:main from
Conversation
Summary of Changes

Hello @HydraQYH, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a novel 'Expert Specialization' mechanism to optimize the performance of FP8 Blockwise Grouped GEMM operations. By adapting kernel selection to the varying computational demands of individual experts, it mitigates performance bottlenecks observed in scenarios with small batch sizes and uneven expert loads. The changes involve new CUDA kernels, their integration into the existing framework, and comprehensive testing, demonstrating notable speedups.

Highlights
Code Review
This pull request introduces "Expert Specialization" for grouped GEMM to enhance performance, particularly for unbalanced workloads in Mixture-of-Experts (MoE) models. The core idea is to dynamically select optimized CUDA kernels with different TiledMMAShape configurations based on the problem sizes of each expert, which is a robust strategy. The implementation is comprehensive, adding new C++ kernels, Python bindings, benchmarks, and tests.
My review has identified a critical correctness bug in the kernel dispatch logic for H20 devices and a separate issue that could lead to compilation failures on older GPU architectures. Additionally, I've provided suggestions to improve code maintainability by reducing duplication and to fix minor inconsistencies in documentation and comments. Despite these issues, this is a valuable contribution with clear performance benefits.
TORCH_CHECK_NOT_IMPLEMENTED(
    can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
The TORCH_CHECK_NOT_IMPLEMENTED macro here uses variables can_implement and sm_version which are not defined within this #else block. This will cause a compilation error on systems where CUTLASS_ARCH_MMA_SM90_SUPPORTED is not defined. Please use a static message instead.
TORCH_CHECK(false, "es_fp8_blockwise_scaled_grouped_mm requires SM90+ architecture support.");
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
    out_ptrs,
    a_ptrs,
    b_ptrs,
    a_scales_ptrs,
    b_scales_ptrs,
    stride_a,
    stride_b,
    stride_d,
    layout_sfa,
    layout_sfb,
    mm_problem_sizes);
There appears to be a copy-paste error here. For H20 devices (is_h20_device is true), you are launching HighMGemmHx00Traits for mm_problem_sizes (middle-M problems). This should likely be MiddleMGemmH20Traits, which is defined but currently unused. Using the wrong traits can lead to incorrect results or suboptimal performance as the kernel configurations (like MMA tile shape) are mismatched for the intended problem sizes.
launch_sm90_fp8_blockwise_scaled_group_mm<MiddleMGemmH20Traits>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
stride_a,
stride_b,
stride_d,
layout_sfa,
layout_sfb,
mm_problem_sizes);
@@ -494,6 +494,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "bool silu_activation,"
      "int pad_slot_id) -> ()");
  m.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);

  /*
   * From csrc/expert_sepcialization
   * @param scales_b Scaling factors for tensor B, float32 per expert group.
   * @param stride_a Stride information for tensor A (int32).
   * @param stride_b Stride information for tensor B (int32).
   * @param stride_c Stride information for output tensor C (int32).
if (!is_h20_device) {
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
      static_cast<int*>(lm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
      static_cast<int*>(mm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
      static_cast<int*>(hm_problem_sizes.data_ptr()));
  groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
      static_cast<int*>(lm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
      static_cast<int*>(mm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
      static_cast<int*>(hm_problem_sizes.data_ptr()));
  groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
} else if (out_tensors.dtype() == torch::kFloat16) {
struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, cutlass::half_t> of(
    static_cast<int*>(expert_offsets.data_ptr()),
    static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),
    static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),
    static_cast<cutlass::half_t*>(out_tensors.data_ptr()),
    static_cast<float*>(a_scales.data_ptr()),
    static_cast<float*>(b_scales.data_ptr()),
    static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),
    static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),
    static_cast<float**>(a_scales_ptrs.data_ptr()),
    static_cast<float**>(b_scales_ptrs.data_ptr()),
    static_cast<cutlass::half_t**>(out_ptrs.data_ptr()));
if (!is_h20_device) {
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMHx00> lm_psf(
      static_cast<int*>(lm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMHx00> mm_psf(
      static_cast<int*>(mm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMHx00> hm_psf(
      static_cast<int*>(hm_problem_sizes.data_ptr()));
  groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
} else {
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigLowMH20> lm_psf(
      static_cast<int*>(lm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigMiddleMH20> mm_psf(
      static_cast<int*>(mm_problem_sizes.data_ptr()));
  struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor<PerfConfigHighMH20> hm_psf(
      static_cast<int*>(hm_problem_sizes.data_ptr()));
  groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
}
There is significant code duplication within the if (out_tensors.dtype() == torch::kBFloat16) and else if (out_tensors.dtype() == torch::kFloat16) blocks for handling H20 vs. non-H20 devices. This reduces maintainability and increases the chance of errors.
Consider refactoring this logic into a template function that accepts the PerfConfig types as template parameters. This would eliminate the repeated code blocks and make the logic easier to follow and update.
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
// Dispatch
if (out_tensors.dtype() == torch::kBFloat16) {
Maybe we can simplify the replicated code here:

// BFloat16 branch
if (out_tensors.dtype() == torch::kBFloat16) {
  // ... create functors ...
  if (!is_h20_device) {
    // Hx00 config
  } else {
    // H20 config
  }
} else if (out_tensors.dtype() == torch::kFloat16) {
  // ... create functors ...
  if (!is_h20_device) {
    // Hx00 config
  } else {
    // H20 config
  }
}

->

template <typename ElementD>
void dispatch_pre_compute(
    const torch::Tensor& out_tensors,
    const torch::Tensor& a_tensors,
    const torch::Tensor& b_tensors,
    // ... other params
    bool is_h20_device,
    cudaStream_t stream) {
  struct Fp8BlockwiseGroupedGemmOffsetFunctor<cutlass::float_e4m3_t, float, ElementD> of(...);
  if (!is_h20_device) {
    dispatch_problem_size_filters<PerfConfigLowMHx00, PerfConfigMiddleMHx00, PerfConfigHighMHx00>(...);
  } else {
    dispatch_problem_size_filters<PerfConfigLowMH20, PerfConfigMiddleMH20, PerfConfigHighMH20>(...);
  }
}

using HighMGemmHx00Traits =
    ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits<OutType, cutlass::layout::RowMajor, PerfConfigHighMHx00>;

const std::string H20_device_type_str("NVIDIA H20");
Better to create a utility function in sglang/sgl-kernel/include/utils.h, for example:
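A minimal sketch of what such a helper could look like (the function name and its placement are assumptions for illustration, not part of this PR):

// Hypothetical helper for sgl-kernel/include/utils.h: checks whether the
// given CUDA device's name contains a substring (e.g. "H20").
#include <cuda_runtime.h>
#include <string>

inline bool device_name_contains(int device_id, const std::string& substr) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return std::string(prop.name).find(substr) != std::string::npos;
}

// Possible call site in es_fp8_blockwise.cu (hypothetical):
//   bool is_h20_device = device_name_contains(a_tensors.device().index(), "H20");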
@@ -0,0 +1,27 @@
import torch
File name should be expert_specialization.py?
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
- from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
+ from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
Motivation
There are two motivations for implementing Expert Specialization Grouped GEMM.
1. CUTLASS MoE Benchmark
When I used the benchmark to test the performance of FP8 Blockwise Grouped GEMM, I found that performance was poor when the batch size was between 512 and 1024:

The reason is that in this range FP8 Blockwise Grouped GEMM uses a larger TiledMMAShape:

sglang/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu, lines 488 to 489 in a1a20b4

In this regime, the number of tokens each expert needs to process is very small (M is very small), and using a large TiledMMAShape results in additional unnecessary computation.
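For a rough illustration (the numbers are hypothetical, not taken from this PR's kernel configs): if a CTA tile spans 128 rows along M but an expert holds only 16 tokens, the kernel still computes a full 128-row tile, so roughly 7/8 of that tile's M extent is padding; a kernel variant built with a smaller M tile avoids most of this wasted work.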
2. Unbalanced load
Inspired by this PR, expert load may vary depending on the scenario and changes dynamically. Furthermore, this distribution is usually uneven: typically, most experts process only a small number of tokens, while a few experts handle a very large number of tokens.
Implementation
The idea of Expert Specialization is to select an appropriate kernel for each expert based on the number of tokens it processes. We launch multiple grouped GEMM kernels, each with a different configuration, and route each expert to the kernel that best matches its token count.
One detail: since the grouped GEMM kernel requires num_groups as a host parameter, a naive implementation needs a MemcpyD2H, which brings significant performance overhead. Instead, we use another method, masking the problem sizes, to avoid the MemcpyD2H (see the sketch below).
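A minimal sketch of the masking idea (all names and the problem_sizes layout are assumptions for illustration; the actual PR implements this inside its pre-compute kernel with filter functors):

// Hypothetical pre-compute step: build one masked copy of the problem sizes per
// specialized kernel, zeroing M for experts whose token count falls outside this
// kernel's M range. Masked groups degenerate to empty GEMMs and are skipped.
__global__ void mask_problem_sizes(
    const int* __restrict__ problem_sizes,  // [num_experts, 3] as (M, N, K)
    int* __restrict__ masked_sizes,         // [num_experts, 3] for one kernel variant
    int m_lo,                               // exclusive lower bound on M
    int m_hi) {                             // inclusive upper bound on M
  int e = threadIdx.x;  // one thread per expert, launched as <<<1, num_experts, 0, stream>>>
  int m = problem_sizes[3 * e + 0];
  bool in_range = (m > m_lo) && (m <= m_hi);
  masked_sizes[3 * e + 0] = in_range ? m : 0;  // M == 0 => this expert contributes no work
  masked_sizes[3 * e + 1] = problem_sizes[3 * e + 1];
  masked_sizes[3 * e + 2] = problem_sizes[3 * e + 2];
}

With M forced to zero for out-of-range experts, every specialized kernel can still be launched with num_groups equal to the total number of experts, so no per-kernel group count ever needs to be copied back to the host.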
Modifications
Implementation:
sgl-kernel/CMakeLists.txt
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_functor.cuh
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/expert_specilization.py
Benchmark:
sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py
Unit test:
sgl-kernel/tests/test_es_fp8_blockwise_moe.py
Accuracy Tests
pytest -s test_es_fp8_blockwise_moe.py

Benchmarking and Profiling
I tested it using the benchmark script on the H200 and found that in some scenarios, the performance gain of Expert Specialization was significant:

After that, I modified sglang/python/sglang/srt/layers/moe/cutlass_moe.py (line 184 in 3b9d97f), replacing fp8_blockwise_scaled_grouped_mm with es_fp8_blockwise_scaled_grouped_mm, and tested it further.

Before:

After:

Compared to the native fp8_blockwise_scaled_grouped_mm of sgl-kernel, a general performance improvement can be observed.

TODO List
There is still a lot of work to be done next.
Co-Author
I had many discussions with @yuan-luo @FlamingoPg @BBuf; thank you all very much.
Original Project
https://github.com/HydraQYH/expert_specialization_moe
Checklist