
[Kernel Slimming] Migrate marlin moe kernel to JIT #19181

Merged
BBuf merged 9 commits into sgl-project:main from celve:jit-marlin-fused-moe
Feb 26, 2026
Conversation

@celve
Collaborator

@celve celve commented Feb 23, 2026

Motivation

See #17865

Modifications

New files:

  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh — JIT-compiled CUDA kernel ported from sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h — Marlin MoE kernel template (ported from sgl-kernel)
  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h — Kernel header definitions
  • python/sglang/jit_kernel/moe_wna16_marlin.py — Python wrapper with JIT loading and output tensor allocation
  • python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py — Unit tests (JIT vs AOT bitwise equality)
  • python/sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py — Benchmark (JIT vs AOT latency comparison)

Modified files:

  • python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py — Switch moe_wna16_marlin_gemm import from sgl_kernel (AOT) to sglang.jit_kernel (JIT)

Accuracy Tests

Passes all tests defined in python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py, which verify bitwise equality (rtol=0, atol=0) between the JIT and AOT kernels across:

  • Batch sizes: m={1, 123}
  • Output dims: n={128, 1024}
  • Data types: {float16, bfloat16}
  • Activation ordering: {True, False}
  • Quantization types: {uint4, uint4b8}
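The test sweep above can be outlined as the Cartesian product of these configurations. The sketch below is a hypothetical illustration (names like `CONFIGS` and `check_bitwise_equal` are illustrative; the actual tests live in python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py):

```python
# Hypothetical sketch of the accuracy-test sweep described above;
# names are illustrative, not the real test code.
import itertools

CONFIGS = list(itertools.product(
    [1, 123],                 # batch sizes m
    [128, 1024],              # output dims n
    ["float16", "bfloat16"],  # data types
    [True, False],            # activation ordering
    ["uint4", "uint4b8"],     # quantization types
))

def check_bitwise_equal(jit_out, aot_out):
    # Bitwise equality means rtol=0, atol=0: every element must match exactly
    # (with torch: torch.testing.assert_close(jit_out, aot_out, rtol=0, atol=0)).
    assert jit_out == aot_out
```

Each of the 32 configurations runs both kernels on the same inputs and requires exact agreement, so any numerical divergence introduced by the JIT port would fail immediately.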

Benchmarking and Profiling

moe-wna16-marlin-gemm-performance (E=8, K=N=4096, uint4b8, fp16, unit: us):
   size_m  JIT Kernel  AOT Kernel
        1      63.40       63.72
        2     113.41      114.46
        4     190.85      166.12
        8     218.28      219.80
       16     218.71      220.54
       32     219.76      222.61
       64     222.16      224.63
      128     225.73      228.57
      256     337.17      316.78
      512     547.23      525.07

Checklist

Copilot AI review requested due to automatic review settings February 23, 2026 10:34
@gemini-code-assist
Contributor

Summary of Changes


This pull request significantly refactors the Marlin Mixture-of-Experts (MoE) kernel by transitioning its compilation strategy from Ahead-Of-Time (AOT) to Just-In-Time (JIT). This change aims to improve the adaptability and potential performance of the kernel within the SGLang framework. The new JIT implementation is thoroughly validated through extensive unit tests, confirming bitwise equivalence with the previous AOT version, and its performance is benchmarked to ensure no regressions and identify potential gains. The integration into the existing fused MoE layer ensures a seamless transition for users.

Highlights

  • Marlin MoE Kernel Migration to JIT: The core Marlin Mixture-of-Experts (MoE) kernel, moe_wna16_marlin_gemm, has been migrated from Ahead-Of-Time (AOT) compilation to Just-In-Time (JIT) compilation, enhancing flexibility and dynamic optimization capabilities.
  • New JIT Kernel Implementation: Introduced new CUDA C++ source files (kernel.h, marlin_template.h, moe_wna16_marlin.cuh) that implement the JIT-compiled Marlin MoE GEMM operation, adapted from the original Marlin project.
  • Python Wrapper and Integration: A new Python wrapper (moe_wna16_marlin.py) was added to facilitate JIT loading, handle output tensor allocation, and manage parameter passing for the new kernel. The existing fused_marlin_moe.py was updated to utilize this JIT-compiled version.
  • Comprehensive Testing and Benchmarking: New unit tests (test_moe_wna16_marlin.py) were implemented to ensure bitwise equality between the JIT and AOT kernels across various configurations (batch sizes, output dimensions, data types, activation ordering, quantization types). A dedicated benchmark script (bench_moe_wna16_marlin.py) was also added to compare JIT and AOT performance.


Changelog
  • python/sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py
    • Added a new benchmark script to compare the performance of the JIT-compiled Marlin MoE kernel against its AOT counterpart.
  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h
    • Added header definitions for the Marlin MoE kernel parameters and template declaration.
  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h
    • Added the core Marlin MoE kernel template, including CUDA device functions for matrix multiply-accumulate (MMA) operations, shared memory management, and global reduction logic.
  • python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh
    • Added the CUDA kernel implementation for Marlin MoE, which includes logic for determining optimal execution configurations and launching the kernel.
  • python/sglang/jit_kernel/moe_wna16_marlin.py
    • Added a Python wrapper for the JIT-compiled Marlin MoE GEMM operation, handling JIT loading, output tensor allocation, and parameter conversion for the CUDA kernel.
  • python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py
    • Added unit tests to verify the bitwise equality of the JIT-compiled Marlin MoE kernel against the AOT kernel across various input shapes, data types, and quantization configurations.
  • python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
    • Updated imports to include the new JIT-compiled moe_wna16_marlin_gemm function.
    • Modified calls to moe_wna16_marlin_gemm to use the JIT-compiled version instead of the AOT torch.ops.sgl_kernel implementation.
    • Adjusted parameter passing for b_q_type from b_q_type_id to the ScalarType object for the JIT kernel calls.
Activity
  • The pull request addresses issue [Feature] sgl-kernel wheel slimming plan tracking #17865, focusing on kernel slimming by migrating the Marlin MoE kernel to JIT.
  • New files were introduced for the JIT-compiled CUDA kernel, its template, header definitions, Python wrapper, unit tests, and benchmarks.
  • Existing files were modified to switch from AOT to JIT kernel imports and usage.
  • Accuracy tests passed, verifying bitwise equality between JIT and AOT kernels for various batch sizes, output dimensions, data types, activation ordering, and quantization types.
  • Benchmarking results were provided, showing comparable or slightly better performance for the JIT kernel across different batch sizes.
  • The author completed checklist items for code formatting, unit tests, and code style guidance.

@celve
Collaborator Author

celve commented Feb 23, 2026

Serve TheBloke/dolphin-2.7-mixtral-8x7b-AWQ with JIT:

subject: abstract_algebra, #q:100, acc: 0.340                                                                                                                              
subject: anatomy, #q:135, acc: 0.644                                                                                                                                       
subject: astronomy, #q:152, acc: 0.763                                                                                                                                     
subject: business_ethics, #q:100, acc: 0.700                                                                                                                               
subject: clinical_knowledge, #q:265, acc: 0.800                                                                                                                            
subject: college_biology, #q:144, acc: 0.785                                                                                                                               
subject: college_chemistry, #q:100, acc: 0.540                                                                                                                             
subject: college_computer_science, #q:100, acc: 0.630                                                                                                                      
subject: college_mathematics, #q:100, acc: 0.400                                                                                                                           
subject: college_medicine, #q:173, acc: 0.699
subject: college_physics, #q:102, acc: 0.392
subject: computer_security, #q:100, acc: 0.770
subject: conceptual_physics, #q:235, acc: 0.626
subject: econometrics, #q:114, acc: 0.614
subject: electrical_engineering, #q:145, acc: 0.641
subject: elementary_mathematics, #q:378, acc: 0.463
subject: formal_logic, #q:126, acc: 0.516
subject: global_facts, #q:100, acc: 0.450
subject: high_school_biology, #q:310, acc: 0.819
subject: high_school_chemistry, #q:203, acc: 0.581
subject: high_school_computer_science, #q:100, acc: 0.740
subject: high_school_european_history, #q:165, acc: 0.770
subject: high_school_geography, #q:198, acc: 0.854
subject: high_school_government_and_politics, #q:193, acc: 0.927
subject: high_school_macroeconomics, #q:390, acc: 0.703
subject: high_school_mathematics, #q:270, acc: 0.396
subject: high_school_microeconomics, #q:238, acc: 0.752
subject: high_school_physics, #q:151, acc: 0.384
subject: high_school_psychology, #q:545, acc: 0.866
subject: high_school_statistics, #q:216, acc: 0.556
subject: high_school_us_history, #q:204, acc: 0.853
subject: high_school_world_history, #q:237, acc: 0.869
subject: human_aging, #q:223, acc: 0.709
subject: human_sexuality, #q:131, acc: 0.779
subject: international_law, #q:121, acc: 0.876
subject: jurisprudence, #q:108, acc: 0.806
subject: logical_fallacies, #q:163, acc: 0.761
subject: machine_learning, #q:112, acc: 0.482
subject: management, #q:103, acc: 0.864
subject: marketing, #q:234, acc: 0.906
subject: medical_genetics, #q:100, acc: 0.770
subject: miscellaneous, #q:783, acc: 0.872
subject: moral_disputes, #q:346, acc: 0.772
subject: moral_scenarios, #q:895, acc: 0.451
subject: nutrition, #q:306, acc: 0.778
subject: philosophy, #q:311, acc: 0.759
subject: prehistory, #q:324, acc: 0.796
subject: professional_accounting, #q:282, acc: 0.500
subject: professional_law, #q:1534, acc: 0.512
subject: professional_medicine, #q:272, acc: 0.739
subject: professional_psychology, #q:612, acc: 0.752
subject: public_relations, #q:110, acc: 0.673
subject: security_studies, #q:245, acc: 0.739
subject: sociology, #q:201, acc: 0.891
subject: us_foreign_policy, #q:100, acc: 0.870
subject: virology, #q:166, acc: 0.518
subject: world_religions, #q:171, acc: 0.865
Total latency: 679.298
Average accuracy: 0.682

With AOT:

subject: abstract_algebra, #q:100, acc: 0.330                                                                                                                              
subject: anatomy, #q:135, acc: 0.644                                                                                                                                       
subject: astronomy, #q:152, acc: 0.770                                                                                                                                     
subject: business_ethics, #q:100, acc: 0.680                                                                                                                               
subject: clinical_knowledge, #q:265, acc: 0.800                                                                                                                            
subject: college_biology, #q:144, acc: 0.799                                                                                                                               
subject: college_chemistry, #q:100, acc: 0.550                                                                                                                             
subject: college_computer_science, #q:100, acc: 0.620                                                                                                                      
subject: college_mathematics, #q:100, acc: 0.410                                                                                                                           
subject: college_medicine, #q:173, acc: 0.699 
subject: college_physics, #q:102, acc: 0.402
subject: computer_security, #q:100, acc: 0.770
subject: conceptual_physics, #q:235, acc: 0.630
subject: econometrics, #q:114, acc: 0.614
subject: electrical_engineering, #q:145, acc: 0.641
subject: elementary_mathematics, #q:378, acc: 0.471
subject: formal_logic, #q:126, acc: 0.516
subject: global_facts, #q:100, acc: 0.430
subject: high_school_biology, #q:310, acc: 0.823
subject: high_school_chemistry, #q:203, acc: 0.586
subject: high_school_computer_science, #q:100, acc: 0.740
subject: high_school_european_history, #q:165, acc: 0.776
subject: high_school_geography, #q:198, acc: 0.848
subject: high_school_government_and_politics, #q:193, acc: 0.922
subject: high_school_macroeconomics, #q:390, acc: 0.705
subject: high_school_mathematics, #q:270, acc: 0.393
subject: high_school_microeconomics, #q:238, acc: 0.752
subject: high_school_physics, #q:151, acc: 0.384
subject: high_school_psychology, #q:545, acc: 0.864
subject: high_school_statistics, #q:216, acc: 0.565
subject: high_school_us_history, #q:204, acc: 0.863
subject: high_school_world_history, #q:237, acc: 0.869
subject: human_aging, #q:223, acc: 0.709
subject: human_sexuality, #q:131, acc: 0.779
subject: international_law, #q:121, acc: 0.876
subject: jurisprudence, #q:108, acc: 0.806
subject: logical_fallacies, #q:163, acc: 0.748
subject: machine_learning, #q:112, acc: 0.491
subject: management, #q:103, acc: 0.874
subject: marketing, #q:234, acc: 0.906
subject: medical_genetics, #q:100, acc: 0.770
subject: miscellaneous, #q:783, acc: 0.871
subject: moral_disputes, #q:346, acc: 0.772
subject: moral_scenarios, #q:895, acc: 0.449
subject: nutrition, #q:306, acc: 0.781
subject: philosophy, #q:311, acc: 0.756
subject: prehistory, #q:324, acc: 0.799
subject: professional_accounting, #q:282, acc: 0.511
subject: professional_law, #q:1534, acc: 0.512
subject: professional_medicine, #q:272, acc: 0.735
subject: professional_psychology, #q:612, acc: 0.745
subject: public_relations, #q:110, acc: 0.682
subject: security_studies, #q:245, acc: 0.739
subject: sociology, #q:201, acc: 0.891
subject: us_foreign_policy, #q:100, acc: 0.870
subject: virology, #q:166, acc: 0.524
subject: world_religions, #q:171, acc: 0.865
Total latency: 671.739
Average accuracy: 0.682


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

The pull request successfully migrates the Marlin MoE kernel to a JIT-compiled version, which helps reduce the binary size and improve flexibility. The implementation follows the existing Marlin logic and includes comprehensive tests and benchmarks. I have identified a few issues: potential integer overflows in stride calculations, missing validation for quantization group alignment, and some complex boolean expressions that could be clarified with parentheses to improve maintainability.

Comment on lines +372 to +374
const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == host::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4);

high

The calculation of scales_expert_stride and zp_expert_stride involves an intermediate product prob_n * prob_k using 32-bit signed integers. If both dimensions are large (e.g., 65536), this product will overflow, leading to incorrect stride values and potential memory corruption. It is recommended to cast one of the operands to int64_t before multiplication.
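The overflow can be reproduced outside the kernel by emulating 32-bit signed arithmetic in Python. This is a sketch of the failure mode, not the kernel code:

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def int32_mul(a: int, b: int) -> int:
    """Multiply with int32 wraparound, as the C++ expression prob_n * prob_k does."""
    r = (a * b) & 0xFFFFFFFF
    return r - 2**32 if r > INT32_MAX else r

prob_n = prob_k = 65536
# 65536 * 65536 == 2**32, which wraps to 0 in int32: the stride silently
# becomes 0 and every expert would index the same memory.
overflowed = int32_mul(prob_n, prob_k)      # wraps to 0
# The reviewer's suggested fix: widen before multiplying, e.g. in C++:
#   int64_t prod = static_cast<int64_t>(prob_n) * prob_k;
correct = prob_n * prob_k                   # 4294967296, fits in int64
```

Any `prob_n * prob_k` at or above 2^31 corrupts the stride, so even dimensions well below 65536 (e.g. 46341 x 46341) already trigger the wraparound.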

Comment on lines +626 to +639
group_blocks = group_size / 16;
host::RuntimeCheck(
prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks);
} else {
host::RuntimeCheck(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
host::RuntimeCheck(
prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks);

medium

The group_blocks is derived by dividing group_size by 16 without verifying that group_size is actually a multiple of 16. Marlin kernels rely on 16x16 tile alignment for quantization groups. If an unaligned group_size is provided, the truncation will cause incorrect indexing into the scales and zero-points buffers. Additionally, if group_blocks becomes 0, it may lead to division-by-zero errors in the kernel. A runtime check should be added to ensure group_size % 16 == 0 when grouping is enabled (i.e., group_size > 0).
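The guard the reviewer proposes can be sketched in Python. This is a hypothetical helper mirroring the non-act-order path of the C++ `host::RuntimeCheck` logic above (the kernel also has a `group_size == 0` act-order path not modeled here):

```python
def derive_group_blocks(group_size: int, prob_k: int) -> int:
    """Sketch of group_blocks derivation with the missing alignment check added."""
    if group_size == -1:
        # No grouping: a single scale per output column.
        return -1
    # The proposed check: Marlin relies on 16x16 tiles, so a positive
    # group_size must be a multiple of 16 or group_size // 16 silently
    # truncates (and may even reach 0, risking division by zero).
    if group_size <= 0 or group_size % 16 != 0:
        raise ValueError(f"group_size={group_size} must be a positive multiple of 16")
    group_blocks = group_size // 16
    if prob_k % group_blocks != 0:
        raise ValueError(f"prob_k={prob_k} is not divisible by group_blocks={group_blocks}")
    return group_blocks
```

With this guard, an unaligned `group_size` fails loudly at launch time instead of producing misindexed scales and zero-points inside the kernel.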

static constexpr auto w_type = host::ScalarType::from_id(w_type_id);
static constexpr auto s_type = host::ScalarType::from_id(s_type_id);
if constexpr (w_type == host::kFE2M1f) {
static_assert(s_type == host::kFE4M3fn && group_blocks == 1 || s_type == host::kFE8M0fnu && group_blocks == 2);

medium

The static_assert condition is complex and lacks parentheses to clearly define the operator precedence between && and ||. While C++ precedence rules handle this correctly, adding explicit parentheses would improve readability and prevent potential logic errors during future maintenance.

    static_assert((s_type == host::kFE4M3fn && group_blocks == 1) || (s_type == host::kFE8M0fnu && group_blocks == 2));

Comment on lines +359 to +362
constexpr bool dequant_skip_flop = w_type == host::kFE4M3fn ||
w_type == host::kFE2M1f && s_type == host::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == host::kU8);

medium

This complex boolean expression for dequant_skip_flop is difficult to parse due to the mix of && and || operators without parentheses. Adding parentheses to group the logical units would significantly improve maintainability and clarity.

  constexpr bool dequant_skip_flop = (w_type == host::kFE4M3fn) ||
                                     (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) ||
                                     (has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value) ||
                                     (has_zp && !is_zp_float && !(w_type == host::kU8));


Copilot AI left a comment


Pull request overview

This pull request migrates the Marlin MoE (Mixture of Experts) kernel from Ahead-of-Time (AOT) compilation in the sgl-kernel package to Just-in-Time (JIT) compilation. This is part of a larger kernel slimming initiative to reduce the size of the sgl-kernel wheel, which currently takes up 1633 MB on H100 systems. The Marlin MoE kernels alone account for approximately 370 MB (22.67%).

Changes:

  • Migrates the moe_wna16_marlin_gemm kernel from AOT (sgl-kernel) to JIT compilation
  • Changes the API parameter from b_q_type_id (integer) to b_q_type (ScalarType object)
  • Adds comprehensive unit tests verifying bitwise equality between JIT and AOT implementations
  • Includes benchmark code to validate performance parity

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Show a summary per file
File Description
python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py Updates import to use JIT kernel and changes parameter from b_q_type_id to b_q_type
python/sglang/jit_kernel/moe_wna16_marlin.py Python wrapper that handles JIT compilation, tensor allocation, and parameter conversion
python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh Main CUDA kernel implementation ported from sgl-kernel
python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h Marlin MoE kernel template with full implementation
python/sglang/jit_kernel/csrc/gemm/marlin_moe/kernel.h Kernel header definitions
python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py Comprehensive unit tests verifying correctness across multiple configurations
python/sglang/jit_kernel/benchmark/bench_moe_wna16_marlin.py Performance benchmarking code comparing JIT and AOT implementations
Comments suppressed due to low confidence (2)

python/sglang/jit_kernel/moe_wna16_marlin.py:93

  • The has_bias flag is derived from checking if b_bias_or_none is not None (line 93), but this doesn't verify that the tensor actually contains data. An empty tensor would still set has_bias=True even though no bias values are present.

The logic should check if the bias tensor has elements after conversion, similar to how has_zp is determined:

has_zp = b_zeros_or_none is not None and b_zeros_or_none.numel() > 0

Consider changing line 93 to:

has_bias = b_bias_or_none is not None and b_bias_or_none.numel() > 0

This ensures consistency with how other optional tensors are checked and prevents passing empty tensors with has_bias=True to the kernel.

    has_bias = b_bias_or_none is not None

python/sglang/jit_kernel/moe_wna16_marlin.py:52

  • Parameter naming inconsistency between the wrapper and the kernel internals. The Python function parameter (line 52) and the CUDA function signature (moe_wna16_marlin.cuh line 840) both use num_tokens_post_padded, but the internal marlin_mm function refers to it as num_tokens_past_padded_ptr (lines 657, 817).

This inconsistency exists in the original AOT code being ported. While it appears to work (the variable is just passed through), the naming should be consistent: the correct name is num_tokens_post_padded, meaning tokens after padding, not "past" padded.

    num_tokens_post_padded: torch.Tensor,

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3256397814


)

# Determine has_zp
has_zp = b_zeros_or_none is not None and b_zeros_or_none.numel() > 0
The reason will be displayed to describe this comment to others. Learn more.

P1: Treat empty zero-point tensors as quantized weights

has_zp is derived from b_zeros_or_none.numel() > 0, but call sites like fused_marlin_moe choose b_q_type from w*_zeros is not None; if an expert-parallel shard passes an empty zero-point tensor (e.g. shape with 0 experts), this wrapper sets has_zp=False while still passing b_q_type=uint4, which then trips the kernel-side type check (has_zp=False requires uint4b8/uint8b128) and aborts execution for that rank.


@BBuf
Collaborator

BBuf commented Feb 23, 2026

We should serve kimi-k2-thinking and do acc test, similar command is:

python3 -m sglang.launch_server --model-path moonshotai/Kimi-K2-Thinking --tp 8 --trust-remote-code  --tool-call-parser kimi_k2 --reasoning-parser kimi_k2 --model-loader-extra-config='{"enable_multithread_load": "true","num_threads": 64}' 

@BBuf
Collaborator

BBuf commented Feb 23, 2026

Please fix lint

@BBuf
Collaborator

BBuf commented Feb 25, 2026

We should serve kimi-k2-thinking and do acc test, similar command is:

python3 -m sglang.launch_server --model-path moonshotai/Kimi-K2-Thinking --tp 8 --trust-remote-code  --tool-call-parser kimi_k2 --reasoning-parser kimi_k2 --model-loader-extra-config='{"enable_multithread_load": "true","num_threads": 64}' 
 python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8                 
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 36.9MB/s]                                                                                                       
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:55<00:00, 23.71it/s]
Accuracy: 0.935
Invalid: 0.000
Latency: 55.630 s
Output throughput: 2451.688 token/s

@BBuf
Collaborator

BBuf commented Feb 25, 2026

/tag-and-rerun-ci

@celve celve force-pushed the jit-marlin-fused-moe branch from 4c1fd07 to f9ca01d Compare February 25, 2026 06:21
@BBuf
Collaborator

BBuf commented Feb 26, 2026

Merged with CI green and one unrelated failure (DeepSeek V3.2): https://github.com/sgl-project/sglang/actions/runs/22385088337/job/64922987872?pr=19181

@BBuf BBuf merged commit beabaa8 into sgl-project:main Feb 26, 2026
202 of 221 checks passed
@celve celve deleted the jit-marlin-fused-moe branch February 26, 2026 01:48
klhhhhh pushed a commit to klhhhhh/sglang that referenced this pull request Feb 26, 2026
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
