
[Kimi K2] num_experts extends to 384 #8007

Open
panpan0000 wants to merge 2 commits into sgl-project:main from panpan0000:k2-topk

Conversation

@panpan0000
Contributor

@panpan0000 panpan0000 commented Jul 14, 2025

Kimi K2 has "n_routed_experts": 384

This change updates the code so that a 384-expert configuration no longer falls through to the "Unexpected num_experts: " error.

Test Result for CUDA code on A800:

python3  sgl-kernel/tests/test_moe_fused_gate.py

sgl-kernel/tests/test_moe_fused_gate.py 

============ warnings summary ===================================

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============== 594 passed, 1 warning in 11.16s =================================
python sgl-kernel/benchmark/bench_moe_fused_gate.py

moe-fused-gate-performance:
   seq_length    Original  SGL Kernel
0      5000.0   39.503999   39.519999
1     10000.0   60.704000   59.967998
2     15000.0   79.712003   79.712003
3     20000.0   98.112002   97.216003
4     25000.0  115.103997  115.808003
5     30000.0  133.504003  132.671997
6     35000.0  150.335997  150.368005
7     40000.0  168.240003  168.096006
···

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Summary of Changes

Hello @panpan0000, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a necessary update to the topk_sigmoid_cpu operation, enabling it to correctly handle configurations with 384 experts. This is a targeted fix to prevent a specific runtime error and expand the operational range of the kernel, likely supporting larger or different model architectures.

Highlights

  • Feature Extension: I've identified that this pull request extends the supported number of experts for the topk_sigmoid_cpu kernel. Specifically, it adds support for 384 experts.
  • Error Prevention: The change directly addresses and prevents the "Unexpected num_experts: " runtime error that would occur if a model configuration with 384 experts was used previously.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command              | Description
Code Review          | /gemini review       | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary      | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist  | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help         | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request extends support for num_experts to 384, specifically within the topk_sigmoid_cpu function. The change is straightforward, adding a new case to a switch statement. My review points out that similar functions within the same file (topk_softmax_cpu and grouped_topk_cpu) have not been updated, which could lead to unexpected errors if they are used with 384 experts. I've recommended updating these as well for consistent behavior.
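For readers following along, the dispatch pattern the review describes can be sketched as follows. This is an illustrative sketch only (the function signature and names are hypothetical; the real kernels in sgl-kernel/csrc/cpu/topk.cpp take tensor arguments): a switch over the runtime num_experts selects a compile-time-specialized kernel, so every supported expert count needs its own case, and unsupported counts hit the "Unexpected num_experts" error.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical sketch (names are illustrative, tensor arguments elided):
// each supported expert count is a compile-time specialization, so new
// counts must be added explicitly to the dispatch switch.
template <int64_t NUM_EXPERTS>
void topk_sigmoid_impl(/* tensor arguments elided */) {
  // compile-time-specialized kernel body would go here
}

void topk_sigmoid_dispatch(int64_t num_experts) {
  switch (num_experts) {
    case 256: topk_sigmoid_impl<256>(); break;  // DeepSeek V3/R1
    case 384: topk_sigmoid_impl<384>(); break;  // Kimi K2 (the case this PR adds)
    default:
      throw std::runtime_error("Unexpected num_experts: " +
                               std::to_string(num_experts));
  }
}
```

The review's point is that topk_softmax_cpu and grouped_topk_cpu would need the analogous case added to behave consistently at 384 experts.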

Comment thread sgl-kernel/csrc/cpu/topk.cpp Outdated
@yhyang201
Collaborator

It looks like dsv3_router_gemm only supports 256 experts.

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
@panpan0000
Contributor Author

panpan0000 commented Jul 14, 2025

@yhyang201
It looks like dsv3_router_gemm only supports 256 experts.

I tried creating another PR, #8013, but I have no confidence in it at all... T_T

@panpan0000
Contributor Author

panpan0000 commented Jul 15, 2025

It looks like dsv3_router_gemm only supports 256 experts.

@yhyang201
After revisiting the code: dsv3_router_gemm is only used on Hopper-generation GPUs or newer (in srt/models/deepseek_v2.py):

if (
           .....
            _device_sm >= 90         # <----------- Hopper-generation GPU or newer
        ):
            logits = dsv3_router_gemm(hidden_states, self.weight).to(
          ...

Unfortunately, the best GPU I have on hand is Ampere, so I cannot test dsv3_router_gemm T_T

But dsv3_router_gemm is NOT on the critical path anyway...
So can we move the dsv3_router_gemm enhancement to another PR?

@yhyang201
Collaborator

It looks like dsv3_router_gemm only supports 256 experts.

@yhyang201 After revisiting the code: dsv3_router_gemm is only used on Hopper-generation GPUs or newer (in srt/models/deepseek_v2.py) [...] dsv3_router_gemm is NOT on the critical path anyway, so can we move the dsv3_router_gemm enhancement to another PR?

This operator is designed to reduce latency for small batches. If that isn’t a major concern, you may simply use torch.nn.Linear for now.

@panpan0000
Contributor Author

panpan0000 commented Jul 16, 2025

It looks like dsv3_router_gemm only supports 256 experts.

@yhyang201 After I revisit the code, Only when above >= Hopper GPU, the dsv3_router_gemm will be used (in srt/models/deepseek_v2.py)

This operator is designed to reduce latency for small batches. If that isn’t a major concern, you may simply use torch.nn.Linear for now.

Yes, @yhyang201. I meant that changing dsv3_router_gemm is not blocking the other parts of this PR, right?

-static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
+// DeepSeek V3/R1: num_experts = 256, n_group = 8
+// Kimi K2: num_experts = 384, n_group = 1
+static constexpr int MAX_VPT = 384; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
Collaborator


@ispobock I feel that directly changing MAX_VPT to 384 here, without any register-reuse strategy, will cause this kernel to suffer register spilling. This spot is extremely sensitive to modification. I believe the closest existing implementation to what we want is #6946: it uses a tiling approach for register reuse and also prefetches data between tiles. With minor modifications it should support the Kimi K2 (384) case here. I'll push this forward as soon as possible.
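To illustrate the tiling idea referenced here (a CPU-side sketch only, not the actual #6946 kernel, which operates on GPU registers and warps; TILE and TOPK are illustrative values): instead of keeping all num_experts scores live at once, walk the scores in fixed-size tiles while only a small running top-k buffer persists between tiles, so the working set stays bounded no matter how large num_experts grows.

```cpp
#include <algorithm>
#include <array>
#include <utility>
#include <vector>

// Illustrative values: a bounded per-tile working set (analogous to MAX_VPT)
// and the number of experts selected per token.
constexpr int TILE = 32;
constexpr int TOPK = 8;

// Returns the TOPK largest scores in descending order, visiting the input in
// TILE-sized chunks so only TILE values plus the small top-k buffer are
// "live" at any point between tiles.
std::array<float, TOPK> tiled_topk(const std::vector<float>& scores) {
  std::array<float, TOPK> best;
  best.fill(-1e30f);
  for (std::size_t base = 0; base < scores.size(); base += TILE) {
    std::size_t end = std::min(base + static_cast<std::size_t>(TILE), scores.size());
    for (std::size_t i = base; i < end; ++i) {  // only this tile is live here
      float v = scores[i];
      if (v > best[TOPK - 1]) {                 // insert into the running top-k
        best[TOPK - 1] = v;
        for (int j = TOPK - 1; j > 0 && best[j] > best[j - 1]; --j)
          std::swap(best[j], best[j - 1]);
      }
    }
  }
  return best;
}
```

On a GPU the same structure lets each tile reuse the same registers, with the next tile's loads prefetched while the current tile is processed.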

Contributor


@ispobock On H20 (SM90) I get a compile error, caused by the 384 case:

/sgl-workspace/sglang/sgl-kernel/build/_deps/repo-cutlass-src/include/cutlass/array.h(2849): error: invalid alignment value specified by attribute
  class alignas(Alignment) AlignedArray : public Array<T, N> { ^
detected during:
  instantiation of class "cutlass::AlignedArray<T, N, Alignment> [with T=bfloat16_t, N=384, Alignment=768]" at line 86 of /sgl-workspace/sglang/sgl-kernel/csrc/moe/moe_fused_gate.cu
  instantiation of "void moe_fused_gate_impl<T,Params>(void *, void *, float *, int32_t *, int64_t, int64_t, int64_t, int64_t, double, Params) [with T=bfloat16_t, Params=KernelParams<384, 384, 1, 32, 192, 6>]" at line 299 of /sgl-workspace/sglang/sgl-kernel/csrc/moe/moe_fused_gate.cu
  instantiation of "void moe_fused_gate_kernel<T,VPT,NUM_EXPERTS,THREADS_PER_ROW,ROWS_PER_WARP,ROWS_PER_CTA,WARPS_PER_CTA>(void *, void *, float *, int32_t *, int64_t, int64_t, int64_t, int64_t, double) [with T=bfloat16_t, VPT=384, NUM_EXPERTS=384, THREADS_PER_ROW=1, ROWS_PER_WARP=32, ROWS_PER_CTA=192, WARPS_PER_CTA=6]" at line 426 of /sgl-workspace/sglang/sgl-kernel/csrc/moe/moe_fused_gate.cu
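Judging from the error message, cutlass::AlignedArray is being instantiated with Alignment = sizeof(bfloat16_t) * N = 2 * 384 = 768 bytes, and C++ alignas only accepts powers of two, which is why the N = 256 case (512-byte alignment) compiles while N = 384 does not. A minimal check of that arithmetic (assuming the alignment is derived that way, as the instantiation trace suggests):

```cpp
#include <cstddef>

// alignas() requires a power-of-two value; AlignedArray's default alignment
// of sizeof(T) * N is only valid when that product happens to be one.
constexpr bool is_pow2(std::size_t x) { return x != 0 && (x & (x - 1)) == 0; }

static_assert(is_pow2(256 * sizeof(short)),  "256 experts -> 512-byte alignment: valid");
static_assert(!is_pow2(384 * sizeof(short)), "384 experts -> 768-byte alignment: invalid");
```

This suggests that simply bumping the expert count cannot work here; the VPT = 384 instantiation would need a different array layout (e.g. tiling, as proposed above) rather than one giant aligned array per thread.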

Contributor


I also encountered this compilation error.

@ltaodream

Kimi K2 has "n_routed_experts": 384. Change the code to avoid hitting "Unexpected num_experts: " [...]

(quoting the PR description and benchmark results above)

May I ask if you are interested in working with me on #6946? It extends num_experts to 512 with tiling.

