
perf: TRT-LLM MoE Block-FP8 activation optimization #2063

Merged
yzh119 merged 4 commits into flashinfer-ai:main from nekorobov:nkorobov/optimize-ds-activation-kernel on Nov 8, 2025

Conversation

Collaborator

@nekorobov nekorobov commented Nov 7, 2025

📌 Description

  • Small optimization to the activation kernel for block-FP8 MoE at large batch sizes.

    | BS    | Baseline, us | Optimized, us |
    | ----- | ------------ | ------------- |
    | 1     | 2.4          | 2.1           |
    | 32    | 3.5          | 2.6           |
    | 256   | 21.7         | 8.7           |
    | 1024  | 84.4         | 23.8          |
    | 4096  | 333          | 87.0          |
    | 16384 | 1330         | 365           |

  • Adding a micro-benchmark for DS FP8, implemented by @IwakuraRein.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Improved Mixture-of-Experts inference with configurable multi-token batching per CTA (CUDA thread block) for higher throughput.
    • Expanded FP8 quantization with a new block-scale mode and dynamic, hardware-aware kernel scheduling for better utilization and numerical stability.
    • Vectorized max-reduction and per-block scaling to accelerate reductions and improve output scaling precision.
    • Autotuner/CLI now exposes the FP8 block quantization option for tuning.

Contributor

coderabbitai Bot commented Nov 7, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Extends the DeepSeek activation path to support configurable tokens-per-CTA and vectorized max reductions via packed types; introduces KernelTraits and packing helpers; changes kernel launch to hardware-aware dynamic grid sizing with a NumTokensPerCta-aware launcher; and adds an FP8 block-scale autotuning path in the benchmark script.
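To make the walkthrough concrete, here is a minimal sketch of what such max functors and packing helpers could look like. This is an illustration derived from the summary above, not the PR's actual code; the real helpers, their signatures, and the KernelTraits wiring may differ.

#include <cuda_runtime.h>  // make_float2 / make_float4

// Illustrative only: bundle the per-token |out| values owned by one thread
// into a float / float2 / float4 so a single block reduction yields all maxima.
struct Float4Max {
  __device__ float4 operator()(float4 const& a, float4 const& b) const {
    return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
  }
};

struct Float2Max {
  __device__ float2 operator()(float2 const& a, float2 const& b) const {
    return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y));
  }
};

// Pack NumTokensPerCta per-token floats into the packed reduction type.
template <typename PackedType, int N>
__device__ inline PackedType packedTypeFromArray(float const (&arr)[N]);

template <>
__device__ inline float packedTypeFromArray<float, 1>(float const (&arr)[1]) {
  return arr[0];
}
template <>
__device__ inline float2 packedTypeFromArray<float2, 2>(float const (&arr)[2]) {
  return make_float2(arr[0], arr[1]);
}
template <>
__device__ inline float4 packedTypeFromArray<float4, 4>(float const (&arr)[4]) {
  return make_float4(arr[0], arr[1], arr[2], arr[3]);
}

// Unpack the reduced result back into per-token scalars (float4 case shown;
// the PR's arrayFromPackedType is templated and may return an array type).
__device__ inline void arrayFromPackedType(float4 const& p, float (&out)[4]) {
  out[0] = p.x; out[1] = p.y; out[2] = p.z; out[3] = p.w;
}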

Changes

| Cohort | File(s) | Summary |
| ------ | ------- | ------- |
| Activation kernel (vectorized reductions & CTA partitioning) | csrc/trtllm_fused_moe_dev_kernel.cu | Added Float4Max/Float2Max functors; packing helpers (packedTypeFromArray, arrayFromPackedType) with specializations for float4/float2/float; introduced `KernelTraits<1 |
| Kernel launcher & runtime sizing | csrc/trtllm_fused_moe_dev_kernel.cu | Reworked the run() FP8 path to compute gridSizeX/gridSizeY from SM count and outputDim, choose numTokensPerCta (1/2/4) via a heuristic (see the sketch after this table), and launch activationDeepSeekKernel through the new LAUNCH_ACTIVATION macro; the non-FP8 path is unified to use the launcher with numTokensPerCta=1. |
| Public API / macros / KernelParams signature | include/flashinfer/trtllm/fused_moe/DevKernel.h | Added LAUNCH_NUM_TOKENS_PER_CTA and LAUNCH_ACTIVATION macros to dispatch across NumTokensPerCta; changed the activation::KernelParams template from <Type_, bool UsePdl_> to <Type_, int32_t NumTokensPerCta_, bool UsePdl_> and exposed static constexpr NumTokensPerCta. Updated LAUNCH_PDL instantiations accordingly. |
| Autotuner benchmark: FP8 block-scale path | benchmarks/bench_trtllm_gen_fused_moe_autotuner.py | Added "Fp8-Block" quant_mode support with dispatch to trtllm_fp8_block_scale_moe; constructs block-scale tensors for block grouping and threads routing_bias into the block path; extended CLI choices and the function signature to accept the new quant_mode. |
| Includes / minor integration | csrc/trtllm_fused_moe_dev_kernel.cu | Added #include <algorithm> and #include "flashinfer/utils.cuh". |
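The launcher heuristic referenced in the table above, sketched under illustrative assumptions: the ceilDiv helper, the 2x/4x SM-count thresholds, and the function name are hypothetical, not the PR's actual grid math.

#include <cuda_runtime.h>

inline int ceilDiv(int a, int b) { return (a + b - 1) / b; }

// Hypothetical launcher-side sizing: query the active device's SM count,
// cover the hidden dimension with 128-thread CTAs, and batch 1/2/4 tokens
// per CTA so large batches launch fewer, better-occupied CTAs.
void chooseActivationLaunchConfig(int numTokens, int outputDim) {
  int deviceId = 0;
  cudaGetDevice(&deviceId);  // query the current device rather than hardcoding 0
  int smCount = 0;
  cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId);

  int const numThreadsPerCta = 128;
  int const gridSizeX = ceilDiv(outputDim, numThreadsPerCta);

  int numTokensPerCta = 1;
  int const ctasAtOneTokenEach = gridSizeX * numTokens;
  if (ctasAtOneTokenEach >= 4 * smCount) {
    numTokensPerCta = 4;  // plenty of work per wave: amortize per-CTA overhead
  } else if (ctasAtOneTokenEach >= 2 * smCount) {
    numTokensPerCta = 2;
  }
  int const gridSizeY = ceilDiv(numTokens, numTokensPerCta);

  // The PR then dispatches through LAUNCH_ACTIVATION so numTokensPerCta becomes
  // the compile-time NumTokensPerCta template parameter of the kernel.
  dim3 const grid(gridSizeX, gridSizeY, 1);
  dim3 const block(numThreadsPerCta, 1, 1);
  (void)grid;
  (void)block;
}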

Sequence Diagram

sequenceDiagram
    participant Host as Host
    participant Launcher as Launcher (LAUNCH_ACTIVATION)
    participant Grid as Grid (kernel grid)
    participant CTA as CTA (thread block)
    Note over Host,Launcher: Host computes gridX/gridY and selects NumTokensPerCta (1/2/4)
    Host->>Launcher: request launch (numTokensPerCta)
    Launcher->>Grid: start activationDeepSeekKernel(gridX,gridY,NumTokensPerCta)
    par Per CTA (NumTokensPerCta tokens)
        CTA->>CTA: for each tokenInCtaIdx load data -> compute absOutArr
        CTA->>CTA: pack absOutArr into PackedType
        CTA->>CTA: BlockReduce.Reduce(absOutPacked, MaxOp) -> aMaxPacked
        CTA->>CTA: unpack aMaxArr = arrayFromPackedType(aMaxPacked)
        CTA->>CTA: compute per-CTA scaleOut and write params.outDqSfsPtr
        CTA->>CTA: scale and write per-token outputs
    end
    Grid->>Host: kernel complete
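The reduction step in the diagram, sketched in isolation. Assumptions: NumTokensPerCta == 4, 128 threads per CTA, PackedType == float4, and the Float4Max / packedTypeFromArray helpers sketched earlier; the PR's actual kernel additionally handles invalid tokens and strided loops, so this is not its literal code.

#include <cub/cub.cuh>

__global__ void reduceStepSketch(float const* absOut,  // [4][128]: |silu(x2) * x1| per token, per thread
                                 float* dqScaleOut) {  // [4]: one dequant scale per token for this block
  constexpr int kNumThreads = 128;
  using BlockReduce = cub::BlockReduce<float4, kNumThreads>;
  __shared__ typename BlockReduce::TempStorage tempStorage;

  // Pack this thread's four per-token values into one float4.
  float absOutArr[4];
  for (int t = 0; t < 4; ++t) {
    absOutArr[t] = absOut[t * kNumThreads + threadIdx.x];
  }
  float4 const absOutPacked = packedTypeFromArray<float4, 4>(absOutArr);

  // One block reduction yields the block-wide max of all four tokens at once.
  float4 const aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, Float4Max{});

  if (threadIdx.x == 0) {
    float constexpr E4m3MaxVal{448.f};  // largest finite E4M3 value
    dqScaleOut[0] = aMaxPacked.x / E4m3MaxVal;
    dqScaleOut[1] = aMaxPacked.y / E4m3MaxVal;
    dqScaleOut[2] = aMaxPacked.z / E4m3MaxVal;
    dqScaleOut[3] = aMaxPacked.w / E4m3MaxVal;
  }
}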

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • Correctness and symmetry of packing/unpacking (packedTypeFromArray / arrayFromPackedType) and MaxOp semantics.
    • BlockReduce usage with packed types and per-CTA TempStorage placement.
    • Heuristic and correctness for choosing numTokensPerCta and gridSizeX/gridSizeY across SM counts and edge cases.
    • Propagation of new KernelParams template (NumTokensPerCta) and macro dispatch correctness in all callers.
    • Shared memory sizing and per-CTA arrays to prevent overruns.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • yongwww
  • cyx-6
  • wenscarl
  • bkryu

Poem

🐰 I hopped through kernels, packing four by four,

Maxes found in bundles, CTA doors ajar.
Grids learned to stretch to silicon light,
FP8 blocks now hum, scaling just right.
A tiny rabbit cheered: "Launch it bright!" ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 13.33%, which is insufficient; the required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title 'perf: TRT-LLM MoE Block-FP8 activation optimization' clearly and specifically describes the main change: a performance optimization to the activation kernel for block-FP8 MoE, which aligns with the substantial kernel refactoring and infrastructure additions in the changeset.
  • Description check ✅ Passed: The PR description covers the primary motivation (activation kernel optimization for block-FP8 MoE) with performance benchmarks and mentions a new micro-benchmark. All required template sections are present with checklist items marked.


@gemini-code-assist
Contributor

Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a substantial performance optimization to the TRT-LLM MoE Block-FP8 activation kernel. By implementing vectorized processing of multiple tokens per CUDA thread block and an adaptive kernel launch strategy, the changes significantly reduce execution time, particularly for large batch sizes. The core idea is to improve data locality and parallelism within the kernel, leading to more efficient utilization of GPU resources.

Highlights

  • Vectorized Activation Kernel: The activationDeepSeekKernel has been refactored to process multiple tokens concurrently within a single CUDA thread block, leveraging packed data types (float4, float2) and unrolled loops for improved data parallelism.
  • Adaptive Token Processing: A dynamic dispatch mechanism has been implemented to select the optimal number of tokens processed per CTA (1, 2, or 4) based on the total number of CTAs and the device's Streaming Multiprocessor (SM) count, optimizing for different batch sizes and hardware configurations (see the dispatch sketch after this list).
  • Performance Improvements: The changes yield significant performance gains for the TRT-LLM MoE Block-FP8 activation, especially for large batch sizes, with speedups of up to ~3.8x (about 3.8x at batch size 4096 and 3.6x at batch size 16384).
  • Generalized Block Reduction: New helper structs and templates (Float4Max, Float2Max, packedTypeFromArray, arrayFromPackedType, KernelTraits) were introduced to generalize block-level reduction operations for packed data types, enabling efficient maximum computation across multiple tokens.
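A possible shape for the adaptive dispatch mentioned above, turning the runtime choice of 1/2/4 tokens per CTA into a template argument. The PR's actual LAUNCH_NUM_TOKENS_PER_CTA / LAUNCH_ACTIVATION macros in DevKernel.h are likely structured differently, and launchActivation is a hypothetical helper used only for illustration.

#include <cuda_runtime.h>

// Hypothetical templated launcher; the real code passes KernelParams and PDL flags.
template <int NumTokensPerCta>
void launchActivation(dim3 grid, dim3 block, cudaStream_t stream /*, kernel params... */);

#define LAUNCH_NUM_TOKENS_PER_CTA_SKETCH(numTokensPerCta, ...) \
  do {                                                         \
    if ((numTokensPerCta) == 4) {                              \
      launchActivation<4>(__VA_ARGS__);                        \
    } else if ((numTokensPerCta) == 2) {                       \
      launchActivation<2>(__VA_ARGS__);                        \
    } else {                                                   \
      launchActivation<1>(__VA_ARGS__);                        \
    }                                                          \
  } while (0)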
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| ------- | ------- | ----------- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a performance optimization for the Block-FP8 MoE activation kernel, targeting large batch sizes. The core idea is to process multiple tokens per CTA to improve vectorization and memory access patterns. This is achieved by introducing a NumTokensPerCta template parameter and using packed data types (float2, float4) for reductions.

The changes are well-structured, but I've identified two critical issues that could lead to incorrect behavior or deadlocks:

  1. Uninitialized local arrays in activationDeepSeekKernel can be read, causing undefined behavior.
  2. A break statement within an if (threadIdx.x == 0) block can cause a deadlock due to thread divergence before a __syncthreads().

Additionally, I've pointed out a medium-severity issue with a hardcoded device ID.

After addressing these points, the optimization should be robust and provide the significant performance improvements shown in the description.

Comment on lines +222 to +228
float scale1Arr[NumTokensPerCta];
float scale2Arr[NumTokensPerCta];
float dataX1Arr[NumTokensPerCta];
float dataX2Arr[NumTokensPerCta];
float outArr[NumTokensPerCta];
float absOutArr[NumTokensPerCta];
int permutedIdxArr[NumTokensPerCta];

critical

The local arrays declared here are not initialized. In the loop at lines 237-262, if the break (L240) or continue (L247) statements are executed, some elements of these arrays will not be written to. The subsequent loops (e.g., L265-272) will then read from this uninitialized memory, leading to undefined behavior and incorrect results.

To fix this, these arrays should be initialized. The float arrays can be zero-initialized using {}. permutedIdxArr should also be initialized to -1 (for example with a loop) to ensure correctness of checks like if (permutedIdx == -1) in later parts of the kernel.

  float scale1Arr[NumTokensPerCta]{};
  float scale2Arr[NumTokensPerCta]{};
  float dataX1Arr[NumTokensPerCta]{};
  float dataX2Arr[NumTokensPerCta]{};
  float outArr[NumTokensPerCta]{};
  float absOutArr[NumTokensPerCta]{};
  int permutedIdxArr[NumTokensPerCta];
#pragma unroll
  for (int i = 0; i < NumTokensPerCta; ++i) {
    permutedIdxArr[i] = -1;  // mark slots invalid until a real permuted index is loaded
  }

Comment on lines +280 to +293
if (threadIdx.x == 0) {
  auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
  if (tokenIdx >= params.numTokens) {
    break;
  }
  int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
  if (permutedIdx == -1) {
    continue;
  }
  s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
  int const scaleOut_idx =
      permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
  params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
}

critical

The break statement at line 283 is inside an if (threadIdx.x == 0) block. This means only thread 0 will break out of the tokenInCtaIdx loop, while other threads in the block will continue. This divergence will cause a deadlock when __syncthreads() is called at line 295, as not all threads in the block will reach it. The continue at line 287 has the same problem.

A simple fix is to remove the break and continue and wrap the logic inside the if (threadIdx.x == 0) block with an additional check for tokenIdx and permutedIdx.

          if (threadIdx.x == 0) {
            auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
            if (tokenIdx < params.numTokens) {
              int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
              if (permutedIdx != -1) {
                s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
                int const scaleOut_idx =
                    permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
                params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
              }
            }
          }

Comment thread on csrc/trtllm_fused_moe_dev_kernel.cu (Outdated)
constexpr int NUM_ELTS_PER_SF = 128;
int const NUM_THREADS_PER_CTA = 128;

int const deviceId = 0;

medium

The deviceId is hardcoded to 0. This might cause issues in a multi-GPU environment, as it will always query the attributes of device 0, regardless of which device the kernel is running on. It's safer to get the current device ID.

    int deviceId;
    cudaGetDevice(&deviceId);


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 36d2463 and 99abadc97f3f8ef150735a1d83ceb3bca3fd2660.

📒 Files selected for processing (2)
  • csrc/trtllm_fused_moe_dev_kernel.cu (3 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
include/flashinfer/trtllm/common/cudaUtils.h (3)
  • float4 (173-175)
  • float2 (169-171)
  • float (165-167)

Comment on lines +222 to +310
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}

int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2_idx =
permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
float const scale1 = params.inDqSfsPtr[scale1_idx];
float const scale2 = params.inDqSfsPtr[scale2_idx];

float x1 = scale1 * (float)params.inPtr[baseIdx];
float x2 = scale2 * (float)params.inPtr[baseIdx + params.innerDim / 2];

float act = silu(x2);
float out = act * x1;
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
permutedIdxArr[tokenInCtaIdx] = permutedIdx;
if (permutedIdx == -1) {
continue;
}

// Process blocks for this CTA
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;

int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2Idx = permutedIdx + totalNumPaddedTokens *
((hiddenIdx / 128) + (params.innerDim / 2 / 128));

scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
dataX2Arr[tokenInCtaIdx] =
static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
}

// The largest (finite) value that can be represented using E4m3.
float constexpr E4m3MaxVal{448.f};
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
float act = silu(x2);
float out = act * x1;
outArr[tokenInCtaIdx] = out;
absOutArr[tokenInCtaIdx] = fabsf(out);
}

// Compute the absolute max
#if CUDA_VERSION >= 12090
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum<>{});
#else
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cub::Max{});
#endif
if (threadIdx.x == 0) {
s_scaleOut = aMax / E4m3MaxVal;
int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);

#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
if (threadIdx.x == 0) {
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1) {
continue;
}
s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
int const scaleOut_idx =
permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
}
}
__syncthreads();
float const scaleOut = s_scaleOut;
__syncthreads();
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = (Type)(out / scaleOut);

#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1) {
continue;
}
float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
}

⚠️ Potential issue | 🔴 Critical

Fix uninitialized per-token buffers

When permutedIdx == -1 we continue before populating scale1/scale2/dataX*/out/absOut for that slot. The second pass still multiplies those arrays and feeds them into BlockReduce, so threads read stale values from the previous hidden block, producing bogus aMaxArr and corrupt outDqSfsPtr/outPtr for the remaining tokens in the CTA. Please zero/mark invalid tokens before skipping the load and short-circuit the later math.

@@
-  float outArr[NumTokensPerCta];
-  float absOutArr[NumTokensPerCta];
-  int permutedIdxArr[NumTokensPerCta];
+  float outArr[NumTokensPerCta];
+  float absOutArr[NumTokensPerCta];
+  bool validTokenArr[NumTokensPerCta];
+  int permutedIdxArr[NumTokensPerCta];
@@
-          int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
-          permutedIdxArr[tokenInCtaIdx] = permutedIdx;
-          if (permutedIdx == -1) {
-            continue;
-          }
+          int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
+          permutedIdxArr[tokenInCtaIdx] = permutedIdx;
+          validTokenArr[tokenInCtaIdx] = (permutedIdx != -1);
+          if (!validTokenArr[tokenInCtaIdx]) {
+            scale1Arr[tokenInCtaIdx] = 0.f;
+            scale2Arr[tokenInCtaIdx] = 0.f;
+            dataX1Arr[tokenInCtaIdx] = 0.f;
+            dataX2Arr[tokenInCtaIdx] = 0.f;
+            continue;
+          }
@@
-        for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
-          float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
-          float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
+        for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
+          if (!validTokenArr[tokenInCtaIdx]) {
+            outArr[tokenInCtaIdx] = 0.f;
+            absOutArr[tokenInCtaIdx] = 0.f;
+            continue;
+          }
+          float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
+          float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
float scale1Arr[NumTokensPerCta];
float scale2Arr[NumTokensPerCta];
float dataX1Arr[NumTokensPerCta];
float dataX2Arr[NumTokensPerCta];
float outArr[NumTokensPerCta];
float absOutArr[NumTokensPerCta];
int permutedIdxArr[NumTokensPerCta];
// Loop over tokens
for (int k = blockIdx.z; k < params.topK; k += gridDim.z) {
for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens;
tokenCtaIdx += gridDim.y * NumTokensPerCta) {
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
hiddenIdx += blockDim.x * gridDim.x) {
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2_idx =
permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
float const scale1 = params.inDqSfsPtr[scale1_idx];
float const scale2 = params.inDqSfsPtr[scale2_idx];
float x1 = scale1 * (float)params.inPtr[baseIdx];
float x2 = scale2 * (float)params.inPtr[baseIdx + params.innerDim / 2];
float act = silu(x2);
float out = act * x1;
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
permutedIdxArr[tokenInCtaIdx] = permutedIdx;
if (permutedIdx == -1) {
continue;
}
// Process blocks for this CTA
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2Idx = permutedIdx + totalNumPaddedTokens *
((hiddenIdx / 128) + (params.innerDim / 2 / 128));
scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
dataX2Arr[tokenInCtaIdx] =
static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
}
// The largest (finite) value that can be represented using E4m3.
float constexpr E4m3MaxVal{448.f};
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
float act = silu(x2);
float out = act * x1;
outArr[tokenInCtaIdx] = out;
absOutArr[tokenInCtaIdx] = fabsf(out);
}
// Compute the absolute max
#if CUDA_VERSION >= 12090
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum<>{});
#else
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cub::Max{});
#endif
if (threadIdx.x == 0) {
s_scaleOut = aMax / E4m3MaxVal;
int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
if (threadIdx.x == 0) {
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1) {
continue;
}
s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
int const scaleOut_idx =
permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
}
}
__syncthreads();
float const scaleOut = s_scaleOut;
__syncthreads();
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = (Type)(out / scaleOut);
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1) {
continue;
}
float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
}
float scale1Arr[NumTokensPerCta];
float scale2Arr[NumTokensPerCta];
float dataX1Arr[NumTokensPerCta];
float dataX2Arr[NumTokensPerCta];
float outArr[NumTokensPerCta];
float absOutArr[NumTokensPerCta];
bool validTokenArr[NumTokensPerCta];
int permutedIdxArr[NumTokensPerCta];
// Loop over tokens
for (int k = blockIdx.z; k < params.topK; k += gridDim.z) {
for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens;
tokenCtaIdx += gridDim.y * NumTokensPerCta) {
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
hiddenIdx += blockDim.x * gridDim.x) {
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
permutedIdxArr[tokenInCtaIdx] = permutedIdx;
validTokenArr[tokenInCtaIdx] = (permutedIdx != -1);
if (!validTokenArr[tokenInCtaIdx]) {
scale1Arr[tokenInCtaIdx] = 0.f;
scale2Arr[tokenInCtaIdx] = 0.f;
dataX1Arr[tokenInCtaIdx] = 0.f;
dataX2Arr[tokenInCtaIdx] = 0.f;
continue;
}
// Process blocks for this CTA
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2Idx = permutedIdx + totalNumPaddedTokens *
((hiddenIdx / 128) + (params.innerDim / 2 / 128));
scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
dataX2Arr[tokenInCtaIdx] =
static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
}
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
if (!validTokenArr[tokenInCtaIdx]) {
outArr[tokenInCtaIdx] = 0.f;
absOutArr[tokenInCtaIdx] = 0.f;
continue;
}
float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
float act = silu(x2);
float out = act * x1;
outArr[tokenInCtaIdx] = out;
absOutArr[tokenInCtaIdx] = fabsf(out);
}
auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
if (threadIdx.x == 0) {
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1) {
continue;
}
s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
int const scaleOut_idx =
permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
}
}
__syncthreads();
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1) {
continue;
}
float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
}
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_dev_kernel.cu around lines 222–310, the per-token
scratch arrays are left uninitialized when permutedIdx == -1, which lets stale
values flow into the later math and BlockReduce; fix by marking/clearing those
slots when skipping loads and short-circuiting later computation: when
permutedIdx == -1 set scale1Arr/scale2Arr/dataX1Arr/dataX2Arr/outArr/absOutArr
entries to 0 (and keep permutedIdxArr == -1), then in the subsequent math loops
check permutedIdxArr[token]==-1 and either skip computing x1/x2/act/out (leaving
zeros) or continue, ensure BlockReduce sees zeros for invalid tokens, and avoid
writing to params.outDqSfsPtr and params.outPtr for permutedIdx == -1 (i.e.,
gate those writes with the same check).

Collaborator

yzh119 commented Nov 7, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !121 has been created, and the CI pipeline #38087397 is currently running. I'll report back once the pipeline job completes.

@nekorobov nekorobov force-pushed the nkorobov/optimize-ds-activation-kernel branch from 99abadc to 161bf9a on November 7, 2025 at 19:37
@nekorobov
Collaborator Author

/bot kill

@nekorobov
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

@nekorobov is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

1 similar comment
@flashinfer-bot
Collaborator

@nekorobov is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

Collaborator

jiahanc commented Nov 7, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !121 has been updated with latest changes, and the CI pipeline #38094149 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2)

51-51: Consider conditional creation of routing_bias.

The routing_bias tensor is created unconditionally but only used in the block-scale path (line 102). The per-tensor path explicitly passes None (line 129). Consider creating this tensor only when is_block_scale is True to avoid unnecessary allocation.

Apply this diff to conditionally create routing_bias:

-    routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
+    routing_bias = (
+        torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
+        if quant_mode == "Fp8-Block"
+        else None
+    )

59-81: Consider using actual block quantization for realistic benchmarking.

The comment indicates that block-scale quantization is "too slow" for benchmarking, so per-tensor quantization is used and then broadcast to block shape. This approach creates artificial block scales where all blocks have identical values, which may not accurately represent the performance characteristics of real block-scale quantization in production.

For more realistic benchmarking results, consider either:

  1. Implementing optimized block quantization for the benchmark
  2. Documenting that this benchmark measures kernel performance in isolation, not end-to-end quantization overhead
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 161bf9a and 4a72c06.

📒 Files selected for processing (1)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (3)
flashinfer/fused_moe/core.py (4)
  • trtllm_fp8_block_scale_moe (1761-1843)
  • WeightLayout (161-168)
  • RoutingMethodType (58-72)
  • trtllm_fp8_per_tensor_scale_moe (1680-1758)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp8_block_scale_moe (709-760)
  • trtllm_fp8_block_scale_moe (709-716)
  • trtllm_fp8_per_tensor_scale_moe (352-412)
  • trtllm_fp8_per_tensor_scale_moe (352-360)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • hidden_size (265-265)
  • num_experts (263-263)
  • intermediate_size (275-275)
  • top_k (270-270)
  • RoutingMethodType (37-136)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

100-125: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)


127-149: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)

🔇 Additional comments (8)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (8)

14-15: LGTM!

The new imports are correctly added to support the FP8-Block quantization path.


26-26: LGTM!

The simplified max calculation is cleaner and functionally equivalent.


34-34: LGTM!

The quant_mode parameter with Literal type hint provides clear API documentation.


83-97: LGTM!

The conditional scale tensor setup correctly handles the different API requirements for per-tensor vs block-scale kernels.


99-149: Document the different routing configurations.

The block-scale path uses DeepSeekV3 routing with group parameters (n_group=8, topk_group=4) and routed_scaling_factor=2.5, while the per-tensor path uses TopK routing with routed_scaling_factor=1.0. These different configurations mean the two quantization modes are being benchmarked under different routing scenarios, which may not provide a direct performance comparison.

If this is intentional to reflect typical usage patterns for each quantization mode, consider adding a comment explaining the rationale. Otherwise, consider aligning the routing configurations for a more apples-to-apples comparison.


324-330: LGTM!

The CLI choices are correctly extended to support the new FP8 quantization modes.


355-366: LGTM!

The conditional dispatch correctly routes FP8 quantization modes to the appropriate benchmark function.


195-195: LGTM!

The addition of is_sf_swizzled_layout=False parameter propagates the new API requirement without changing existing behavior.

Collaborator

@jiahanc jiahanc left a comment


LGTM thanks for the work!

@IwakuraRein IwakuraRein self-requested a review November 8, 2025 00:00
Collaborator

@IwakuraRein IwakuraRein left a comment


Thx for your contribution

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38094149: 13/17 passed

@yzh119 yzh119 merged commit ba011d1 into flashinfer-ai:main Nov 8, 2025
4 checks passed
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
