Skip to content

MoE: Skip SiLU/GELU activation for masked experts#15539

Merged
ch-wan merged 17 commits intosgl-project:mainfrom
yuchengz816-bot:unified-activation-kernel-tma
Dec 23, 2025
Merged

MoE: Skip SiLU/GELU activation for masked experts#15539
ch-wan merged 17 commits intosgl-project:mainfrom
yuchengz816-bot:unified-activation-kernel-tma

Conversation

@yuchengz816-bot
Copy link
Copy Markdown
Contributor

Motivation

When filter_expert is enabled, some tokens are routed to masked experts (expert_id == -1) and should skip activation computation. This PR adds Triton SiLU/GELU multiply kernels that correctly bypass compute for these cases while still producing valid outputs, for both sorted and non-sorted routing layouts.

Modifications

Accuracy Tests

gsm8k's accuracy is correct under different concurrency.

Benchmarking and Profiling

python -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --tp 8 --ep 8 --trust-remote-code --chunked-prefill-size 10000

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 500 --random-input 8000 --random-output 2 --random-range-ratio 1 --host 127.0.0.1 --port 30000 --max-concurrency 1

image

Checklist

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @yuchengz816-bot, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the Mixture-of-Experts (MoE) architecture by optimizing the activation function computation. It introduces a smart bypass for SiLU and GELU activations when tokens are assigned to inactive or 'masked' experts, preventing unnecessary calculations. This optimization is achieved through new, unified Triton kernels that efficiently manage conditional execution based on expert routing, leading to improved performance without compromising accuracy.

Highlights

  • Conditional Activation Skipping: Implemented a mechanism to skip SiLU/GELU activation computations for tokens routed to 'masked experts' (expert_id == -1) in Mixture-of-Experts (MoE) models, improving computational efficiency.
  • New Triton Kernels: Introduced new Triton JIT kernels (act_and_mul_kernel and its Python wrapper act_and_mul_triton) that unify the activation and multiplication logic, supporting both SiLU and GELU activations and handling both sorted and unsorted routing layouts.
  • Unified Logic: The new Triton kernel collapses four variants (two activation types x two routing layouts) into a single, more maintainable implementation using compile-time constants.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new unified Triton kernel, act_and_mul_kernel, and its Python wrapper, act_and_mul_triton, to optimize activation and multiplication operations within the fused Mixture of Experts (MoE) layer. This new kernel consolidates SiLU and GELU activations and handles both sorted and unsorted routing layouts, especially when expert filtering is active. Review comments indicate that topk_ids needs to be flattened to prevent shape mismatches, the docstring for act_and_mul_triton should be updated to remove non-existent parameters, and the BLOCK_SIZE_N parameter should use the optimal value from the config dictionary instead of a hardcoded value for better performance.

"""
grid = (down_input.shape[0],)
hidden_size = gateup_output.shape[1]
expert_ids_row = topk_ids if not down_moe_use_tma else expert_ids
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The act_and_mul_kernel expects a 1D tensor of expert IDs when down_moe_use_tma is false. However, topk_ids is a 2D tensor. You need to flatten it before passing it to the kernel to avoid shape mismatch errors.

    expert_ids_row = topk_ids.flatten() if not down_moe_use_tma else expert_ids

Comment on lines +875 to +890
"""
Unified activation and multiply wrapper that dispatches to the unified kernel
with appropriate compile-time constants.

Args:
gateup_output: Input tensor containing gate and up outputs concatenated
down_input: Output tensor for the result
hidden_size: Size of the hidden dimension
config: Configuration dictionary with BLOCK_SIZE_M and BLOCK_SIZE_N
topk_ids: Expert IDs for unsorted routing (used when down_moe_use_tma=False)
expert_ids: Expert IDs for sorted routing (used when down_moe_use_tma=True)
num_tokens_post_padded: Number of tokens after padding (used for sorted routing)
sorted_token_ids: Sorted token IDs (used for sorted routing)
down_moe_use_tma: Whether to use sorted routing layout
activation: Activation type ("silu" or "gelu")
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for act_and_mul_triton lists several parameters that are not actually part of the function's signature: hidden_size, num_tokens_post_padded, and sorted_token_ids. These should be removed to avoid confusion and keep the documentation accurate.

    """
    Unified activation and multiply wrapper that dispatches to the unified kernel
    with appropriate compile-time constants.

    Args:
        gateup_output: Input tensor containing gate and up outputs concatenated
        down_input: Output tensor for the result
        config: Configuration dictionary with BLOCK_SIZE_M and BLOCK_SIZE_N
        topk_ids: Expert IDs for unsorted routing (used when down_moe_use_tma=False)
        expert_ids: Expert IDs for sorted routing (used when down_moe_use_tma=True)
        down_moe_use_tma: Whether to use sorted routing layout
        activation: Activation type ("silu" or "gelu")
    """

hidden_size,
expert_ids_row,
expert_step,
BLOCK_SIZE_N=512,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The BLOCK_SIZE_N is hardcoded to 512. The config dictionary, which contains an optimal BLOCK_SIZE_N, is passed to this function but not used for this parameter. You should use config["BLOCK_SIZE_N"] to ensure optimal performance.

        BLOCK_SIZE_N=config["BLOCK_SIZE_N"],

@ch-wan
Copy link
Copy Markdown
Collaborator

ch-wan commented Dec 22, 2025

/tag-and-rerun-ci

@ch-wan ch-wan self-assigned this Dec 22, 2025
@ch-wan ch-wan merged commit 061f41a into sgl-project:main Dec 23, 2025
300 of 322 checks passed
Liwansi added a commit to iforgetmyname/sglang that referenced this pull request Dec 23, 2025
…n_eagle3_dp

* 'main' of https://github.com/sgl-project/sglang: (208 commits)
  MoE: Skip SiLU/GELU activation for masked experts (sgl-project#15539)
  [GLM-ASR] GLM-ASR Support  (sgl-project#15570)
  Improve engine customization interface (sgl-project#15635)
  chore: bump sgl-kernel version to 0.3.20 (sgl-project#15590)
  bugfix[schedule]: Refactor sort method and add related UT (sgl-project#13576)
  Adjust wrong `mtp` meaning introduce by mimo (sgl-project#15632)
  Tiny add back missing router per attempt response metric (sgl-project#15621)
  Fix router gRPC mode launch error caused by async loading (sgl-project#15368)
  [model-gateway] return 503 when all workers are circuit-broken (sgl-project#15611)
  [Diffusion] Support peak memory record in offline generate and serving (sgl-project#15610)
  [VLM] Tiny: Unify VLM environment variables (sgl-project#15572)
  [diffusion] chore: remove default post-denoising dit offload in local mode (sgl-project#15573)
  Tiny enable soft watchdog in CI for stuck without logs (sgl-project#15616)
  Tiny add stuck simulation (sgl-project#15613)
  Support soft watchdog for tokenizer/detokenizer/dp-controller processes (sgl-project#15607)
  Tiny avoid EnvField misuse (sgl-project#15612)
  add decode round robin policy (sgl-project#15164)
  Add glm-4.6-fp8 with/without mtp in nightly ci (sgl-project#15566)
  Adapt fixture-kit to gsm8k mixin (sgl-project#15599)
  [model-gateway] add retry support to OpenAI router chat endpoint (sgl-project#15589)
  ...
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
Co-authored-by: Runkai Tao <rt572@physics.rutgers.edu>
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
Co-authored-by: Runkai Tao <rt572@physics.rutgers.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants