[Kernel Slimming] Migrate marlin moe kernel to JIT #19181

BBuf merged 9 commits into sgl-project:main
Conversation
Summary of Changes

Hello @celve, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the Marlin Mixture-of-Experts (MoE) kernel by transitioning its compilation strategy from Ahead-Of-Time (AOT) to Just-In-Time (JIT). This change aims to improve the adaptability and potential performance of the kernel within the SGLang framework. The new JIT implementation is validated through extensive unit tests, confirming bitwise equivalence with the previous AOT version, and its performance is benchmarked to ensure no regressions and identify potential gains. The integration into the existing fused MoE layer ensures a seamless transition for users.
Activity
Serve TheBloke/dolphin-2.7-mixtral-8x7b-AWQ with JIT: [...] With AOT: [...]
Code Review
The pull request successfully migrates the Marlin MoE kernel to a JIT-compiled version, which helps reduce the binary size and improve flexibility. The implementation follows the existing Marlin logic and includes comprehensive tests and benchmarks. I have identified a few issues related to potential integer overflows in stride calculations, missing validation for quantization group alignment, and some complex boolean expressions that could be clarified with parentheses to improve maintainability.
```cpp
const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == host::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
    is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4);
```
The calculation of `scales_expert_stride` and `zp_expert_stride` involves an intermediate product `prob_n * prob_k` using 32-bit signed integers. If both dimensions are large (e.g., 65536), this product will overflow, leading to incorrect stride values and potential memory corruption. It is recommended to cast one of the operands to `int64_t` before multiplication.
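A minimal sketch of the suggested fix, reusing the variable names from the snippet above (it assumes the final per-expert strides still fit in 32 bits after division):

```cpp
// Compute prob_n * prob_k in 64-bit so the product cannot overflow int,
// then divide back down to the per-expert strides.
const int64_t prob_nk = static_cast<int64_t>(prob_n) * prob_k;
const int scales_expert_stride =
    static_cast<int>(prob_nk / group_size / (w_type == host::kFE2M1f ? 16 : 8));
const int zp_expert_stride = static_cast<int>(
    is_zp_float ? prob_nk / group_size / 8
                : prob_nk / group_size / (pack_factor * 4));
```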
```cpp
    group_blocks = group_size / 16;
    host::RuntimeCheck(
        prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks);
  } else {
    host::RuntimeCheck(group_size == 0);
    group_blocks = 0;
  }
} else {
  if (group_size == -1) {
    group_blocks = -1;
  } else {
    group_blocks = group_size / 16;
    host::RuntimeCheck(
        prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks);
```
`group_blocks` is derived by dividing `group_size` by 16 without verifying that `group_size` is actually a multiple of 16. Marlin kernels rely on 16x16 tile alignment for quantization groups. If an unaligned `group_size` is provided, the truncation will cause incorrect indexing into the scales and zero-points buffers. Additionally, if `group_blocks` becomes 0, it may lead to division-by-zero errors in the kernel. A runtime check should be added to ensure `group_size % 16 == 0` when grouping is enabled (i.e., `group_size > 0`).
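A minimal sketch of the suggested guard, reusing `host::RuntimeCheck` from the snippet above (the message text is illustrative):

```cpp
// Reject group sizes that do not align to the 16x16 Marlin tiles before the
// truncating division; for group_size > 0 this also guarantees group_blocks >= 1.
if (group_size > 0) {
  host::RuntimeCheck(
      group_size % 16 == 0, "group_size = ", group_size, " must be a multiple of 16");
}
group_blocks = group_size / 16;
```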
```cpp
static constexpr auto w_type = host::ScalarType::from_id(w_type_id);
static constexpr auto s_type = host::ScalarType::from_id(s_type_id);
if constexpr (w_type == host::kFE2M1f) {
  static_assert(s_type == host::kFE4M3fn && group_blocks == 1 || s_type == host::kFE8M0fnu && group_blocks == 2);
```
The `static_assert` condition is complex and lacks parentheses to clearly define the operator precedence between `&&` and `||`. While C++ precedence rules handle this correctly, adding explicit parentheses would improve readability and prevent potential logic errors during future maintenance.
Suggested change:

```cpp
static_assert(
    (s_type == host::kFE4M3fn && group_blocks == 1) ||
    (s_type == host::kFE8M0fnu && group_blocks == 2));
```

```cpp
constexpr bool dequant_skip_flop = w_type == host::kFE4M3fn ||
                                   w_type == host::kFE2M1f && s_type == host::kFE4M3fn ||
                                   has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
                                   has_zp && !is_zp_float && !(w_type == host::kU8);
```
This complex boolean expression for `dequant_skip_flop` is difficult to parse due to the mix of `&&` and `||` operators without parentheses. Adding parentheses to group the logical units would significantly improve maintainability and clarity.
Suggested change:

```cpp
constexpr bool dequant_skip_flop = (w_type == host::kFE4M3fn) ||
                                   (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) ||
                                   (has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value) ||
                                   (has_zp && !is_zp_float && !(w_type == host::kU8));
```
Pull request overview
This pull request migrates the Marlin MoE (Mixture of Experts) kernel from Ahead-of-Time (AOT) compilation in the sgl-kernel package to Just-in-Time (JIT) compilation. This is part of a larger kernel slimming initiative to reduce the size of the sgl-kernel wheel, which currently takes up 1633 MB on H100 systems. The Marlin MoE kernels alone account for approximately 370 MB (22.67%).
Changes:
- Migrates the `moe_wna16_marlin_gemm` kernel from AOT (sgl-kernel) to JIT compilation
- Changes the API parameter from `b_q_type_id` (integer) to `b_q_type` (ScalarType object)
- Adds comprehensive unit tests verifying bitwise equality between JIT and AOT implementations
- Includes benchmark code to validate performance parity
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py` | Updates import to use the JIT kernel and changes the parameter from `b_q_type_id` to `b_q_type` |
| `python/sglang/jit_kernel/moe_wna16_marlin.py` | Python wrapper that handles JIT compilation, tensor allocation, and parameter conversion |
| `python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh` | Main CUDA kernel implementation ported from sgl-kernel |
| `python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h` | Marlin MoE kernel template with full implementation |
| `python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h` | Kernel header definitions |
| `python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py` | Comprehensive unit tests verifying correctness across multiple configurations |
| `python/sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py` | Performance benchmarking code comparing JIT and AOT implementations |
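For readers unfamiliar with the JIT side: the wrapper's load-once-then-call structure can be sketched with the generic torch extension loader below. The function name, source path, and flags here are assumptions for illustration, not the actual `sglang.jit_kernel` API.

```python
import functools

from torch.utils.cpp_extension import load


@functools.lru_cache(maxsize=None)
def _load_marlin_moe_module():
    # First call compiles the CUDA source (e.g. a thin .cu translation unit
    # that includes moe_wna16_marlin.cuh) and caches the built extension;
    # subsequent calls reuse the cached module.
    return load(
        name="moe_wna16_marlin",
        sources=["csrc/gemm/marlin_moe/moe_wna16_marlin.cu"],  # hypothetical path
        extra_cuda_cflags=["-O3"],
        verbose=False,
    )
```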
Comments suppressed due to low confidence (2)
python/sglang/jit_kernel/moe_wna16_marlin.py:93
- The `has_bias` flag is derived from checking `b_bias_or_none is not None` (line 93), but this doesn't verify that the tensor actually has valid data. An empty tensor created by `_or_empty` would still cause `has_bias` to be `True` even when the converted `b_bias_t` carries no elements.

The logic should check whether the bias tensor has elements after conversion, similar to how `has_zp` is determined:

```python
has_zp = b_zeros_or_none is not None and b_zeros_or_none.numel() > 0
```

Consider changing line 93 from:

```python
has_bias = b_bias_or_none is not None
```

to:

```python
has_bias = b_bias_or_none is not None and b_bias_or_none.numel() > 0
```

This ensures consistency with how other optional tensors are checked and prevents passing empty tensors with `has_bias=True` to the kernel.
python/sglang/jit_kernel/moe_wna16_marlin.py:52
- Parameter name mismatch between the Python wrapper and the CUDA kernel internals. The Python function parameter is named `num_tokens_post_padded` (line 52), and the CUDA function signature likewise expects `num_tokens_post_padded` (moe_wna16_marlin.cuh, line 840). However, the internal usage in the `marlin_mm` function uses `num_tokens_past_padded_ptr` (lines 657 and 817).

This is an inconsistency that exists in the original AOT code being ported. While it appears to work (likely because the variable is just passed through), the naming should be consistent. The correct name is `num_tokens_post_padded` based on the context (tokens after padding), not "past" padded.

```python
num_tokens_post_padded: torch.Tensor,
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3256397814
```python
)

# Determine has_zp
has_zp = b_zeros_or_none is not None and b_zeros_or_none.numel() > 0
```
Treat empty zero-point tensors as quantized weights
`has_zp` is derived from `b_zeros_or_none.numel() > 0`, but call sites like `fused_marlin_moe` choose `b_q_type` from `w*_zeros is not None`. If an expert-parallel shard passes an empty zero-point tensor (e.g. a shape with 0 experts), this wrapper sets `has_zp=False` while still passing `b_q_type=uint4`, which then trips the kernel-side type check (`has_zp=False` requires `uint4b8`/`uint8b128`) and aborts execution for that rank.
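A minimal sketch of one way to reconcile the two checks, using the names from the snippet above (whether presence or element count should drive the flag is a design decision for the PR):

```python
# Hypothetical: key has_zp on presence of the zeros tensor, matching how
# call sites pick b_q_type, so an empty shard on an EP rank does not flip it.
has_zp = b_zeros_or_none is not None
```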
We should serve kimi-k2-thinking and run an accuracy test; a similar command is:

```
python3 -m sglang.launch_server --model-path moonshotai/Kimi-K2-Thinking --tp 8 --trust-remote-code --tool-call-parser kimi_k2 --reasoning-parser kimi_k2 --model-loader-extra-config='{"enable_multithread_load": "true","num_threads": 64}'
```
Please fix lint.
```
python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 36.9MB/s]
100%|██████████| 1319/1319 [00:55<00:00, 23.71it/s]
Accuracy: 0.935
Invalid: 0.000
Latency: 55.630 s
Output throughput: 2451.688 token/s
```
/tag-and-rerun-ci |
Force-pushed 4c1fd07 to f9ca01d.
Merged with CI green and one unrelated error (ds v3.2): https://github.com/sgl-project/sglang/actions/runs/22385088337/job/64922987872?pr=19181
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Motivation
See #17865
Modifications
New files:
- `python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh` — JIT-compiled CUDA kernel ported from `sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu`
- `python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h` — Marlin MoE kernel template (ported from sgl-kernel)
- `python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h` — Kernel header definitions
- `python/sglang/jit_kernel/moe_wna16_marlin.py` — Python wrapper with JIT loading and output tensor allocation
- `python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py` — Unit tests (JIT vs AOT bitwise equality)
- `python/sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py` — Benchmark (JIT vs AOT latency comparison)

Modified files:

- `python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py` — Switch the `moe_wna16_marlin_gemm` import from `sgl_kernel` (AOT) to `sglang.jit_kernel` (JIT); see the sketch below
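The import change is essentially a one-line switch; a sketch of the before/after (the exact module path under `sglang.jit_kernel` is assumed):

```python
# Before: AOT kernel shipped in the prebuilt sgl-kernel wheel
# from sgl_kernel import moe_wna16_marlin_gemm

# After: JIT kernel, compiled and cached on first use
from sglang.jit_kernel.moe_wna16_marlin import moe_wna16_marlin_gemm
```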
Accuracy Tests

Pass all tests defined in `python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py`, which verify bitwise equality (rtol=0, atol=0) between the JIT and AOT kernels across multiple configurations; a sketch of the core check follows.
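The bitwise check amounts to comparing the two outputs with zero tolerance; a minimal sketch (tensor names are illustrative, not the test's actual fixtures):

```python
import torch


def assert_bitwise_equal(jit_out: torch.Tensor, aot_out: torch.Tensor) -> None:
    # With rtol=0 and atol=0, assert_close requires exact equality:
    # matching shape, dtype, and every element identical.
    torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0)
```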
Benchmarking and Profiling

Checklist