perf: TRT-LLM MoE Block-FP8 activation optimization #2063

yzh119 merged 4 commits into flashinfer-ai:main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Extends the DeepSeek activation path to support configurable tokens-per-CTA and vectorized max reductions via packed types; introduces KernelTraits and packing helpers; changes the kernel launch to hardware-aware dynamic grid sizing with a NumTokensPerCta-aware launcher; and adds an FP8 block-scale autotuning path to the benchmark script.
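To make the walkthrough concrete, here is a minimal sketch of what a NumTokensPerCta-aware launcher could look like. The threshold values, the `ActivationParams` struct name, and the `launchActivation` helper are illustrative assumptions, not the PR's actual code; only the kernel name and the 1/2/4 choices come from the walkthrough.

```cuda
// Hypothetical host-side dispatch (a sketch, not the PR's implementation):
// pick NumTokensPerCta from the batch size, then size gridY so that each
// CTA covers NumTokensPerCta tokens.
constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; }

template <typename Type>
void launchActivation(ActivationParams const& params, cudaStream_t stream) {
  // Assumed heuristic: larger batches amortize the block reduction better.
  int const numTokensPerCta = params.numTokens >= 1024 ? 4
                            : params.numTokens >= 128  ? 2
                                                       : 1;
  dim3 grid(ceilDiv(params.innerDim / 2, /*NumThreadsPerCta=*/128),
            ceilDiv(params.numTokens, numTokensPerCta),
            params.topK);
  switch (numTokensPerCta) {
    case 4: activationDeepSeekKernel<Type, 4><<<grid, 128, 0, stream>>>(params); break;
    case 2: activationDeepSeekKernel<Type, 2><<<grid, 128, 0, stream>>>(params); break;
    default: activationDeepSeekKernel<Type, 1><<<grid, 128, 0, stream>>>(params); break;
  }
}
```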
Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host
    participant Launcher as Launcher (LAUNCH_ACTIVATION)
    participant Grid as Grid (kernel grid)
    participant CTA as CTA (thread block)
    Note over Host,Launcher: Host computes gridX/gridY and selects NumTokensPerCta (1/2/4)
    Host->>Launcher: request launch (numTokensPerCta)
    Launcher->>Grid: start activationDeepSeekKernel(gridX,gridY,NumTokensPerCta)
    par Per CTA (NumTokensPerCta tokens)
        CTA->>CTA: for each tokenInCtaIdx load data -> compute absOutArr
        CTA->>CTA: pack absOutArr into PackedType
        CTA->>CTA: BlockReduce.Reduce(absOutPacked, MaxOp) -> aMaxPacked
        CTA->>CTA: unpack aMaxArr = arrayFromPackedType(aMaxPacked)
        CTA->>CTA: compute per-CTA scaleOut and write params.outDqSfsPtr
        CTA->>CTA: scale and write per-token outputs
    end
    Grid->>Host: kernel complete
```
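The pack/unpack steps in the diagram map to small helpers named `packedTypeFromArray` and `arrayFromPackedType`. The sketch below shows one plausible shape for them; the `FloatArray` return wrapper and the exact signatures are assumptions, since the PR's real helpers may differ.

```cuda
#include <cuda_runtime.h>

// Small fixed-size array wrapper so the unpack helper can return by value
// (an assumption for this sketch; the PR may use a different container).
template <int N>
struct FloatArray {
  float data[N];
  __device__ float operator[](int i) const { return data[i]; }
};

// Reinterpret N floats as one packed vector type (float2 for N==2, float4
// for N==4) so a single BlockReduce pass reduces all N lanes at once.
template <typename PackedType, int N>
__device__ PackedType packedTypeFromArray(float const (&arr)[N]) {
  static_assert(sizeof(PackedType) == N * sizeof(float), "size mismatch");
  PackedType packed;
#pragma unroll
  for (int i = 0; i < N; ++i) reinterpret_cast<float*>(&packed)[i] = arr[i];
  return packed;
}

template <typename PackedType, int N>
__device__ FloatArray<N> arrayFromPackedType(PackedType const& packed) {
  FloatArray<N> arr;
#pragma unroll
  for (int i = 0; i < N; ++i) arr.data[i] = reinterpret_cast<float const*>(&packed)[i];
  return arr;
}
```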
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a substantial performance optimization to the TRT-LLM MoE Block-FP8 activation kernel. By implementing vectorized processing of multiple tokens per CUDA thread block and an adaptive kernel launch strategy, the changes significantly reduce execution time, particularly for large batch sizes. The core idea is to improve data locality and parallelism within the kernel, leading to more efficient utilization of GPU resources.
Code Review
This pull request introduces a performance optimization for the Block-FP8 MoE activation kernel, targeting large batch sizes. The core idea is to process multiple tokens per CTA to improve vectorization and memory access patterns. This is achieved by introducing a NumTokensPerCta template parameter and using packed data types (float2, float4) for reductions.
The changes are well-structured, but I've identified two critical issues that could lead to incorrect behavior or deadlocks:

- Uninitialized local arrays in `activationDeepSeekKernel` can be read, causing undefined behavior.
- A `break` statement within an `if (threadIdx.x == 0)` block can cause a deadlock due to thread divergence before a `__syncthreads()`.
Additionally, I've pointed out a medium-severity issue with a hardcoded device ID.
After addressing these points, the optimization should be robust and provide the significant performance improvements shown in the description.
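For reference, the packed reduction mentioned above can be built from a component-wise max functor passed to `cub::BlockReduce`. This is an illustrative sketch with assumed types and a hypothetical `blockMax4` wrapper, not the PR's exact code.

```cuda
#include <cub/cub.cuh>

// Component-wise max over float4: each vector lane carries the running
// |out| maximum for one of the CTA's tokens.
struct MaxOp {
  __device__ float4 operator()(float4 const& a, float4 const& b) const {
    return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y),
                       fmaxf(a.z, b.z), fmaxf(a.w, b.w));
  }
};

template <int NumThreads>
__device__ float4 blockMax4(float4 absOutPacked) {
  using BlockReduce = cub::BlockReduce<float4, NumThreads>;
  __shared__ typename BlockReduce::TempStorage tempStorage;
  // Only thread 0 receives the valid reduction result, which matches the
  // `if (threadIdx.x == 0)` consumers in the kernel under review.
  return BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
}
```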
```cuda
float scale1Arr[NumTokensPerCta];
float scale2Arr[NumTokensPerCta];
float dataX1Arr[NumTokensPerCta];
float dataX2Arr[NumTokensPerCta];
float outArr[NumTokensPerCta];
float absOutArr[NumTokensPerCta];
int permutedIdxArr[NumTokensPerCta];
```
The local arrays declared here are not initialized. In the loop at lines 237-262, if the break (L240) or continue (L247) statements are executed, some elements of these arrays will not be written to. The subsequent loops (e.g., L265-272) will then read from this uninitialized memory, leading to undefined behavior and incorrect results.
To fix this, these arrays should be initialized. The float arrays can be zero-initialized using {}. permutedIdxArr should also be initialized to -1s, for example with a loop, to ensure correctness of checks like if (permutedIdx == -1) in later parts of the kernel.
```cuda
float scale1Arr[NumTokensPerCta]{};
float scale2Arr[NumTokensPerCta]{};
float dataX1Arr[NumTokensPerCta]{};
float dataX2Arr[NumTokensPerCta]{};
float outArr[NumTokensPerCta]{};
float absOutArr[NumTokensPerCta]{};
int permutedIdxArr[NumTokensPerCta];
```
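The snippet above still leaves `permutedIdxArr` undefined; one way to add the -1 initialization the comment describes (a sketch, not a committed suggestion) is:

```cuda
int permutedIdxArr[NumTokensPerCta];
#pragma unroll
for (int i = 0; i < NumTokensPerCta; ++i) {
  permutedIdxArr[i] = -1;  // so later `permutedIdx == -1` checks see a defined value
}
```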
```cuda
if (threadIdx.x == 0) {
  auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
  if (tokenIdx >= params.numTokens) {
    break;
  }
  int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
  if (permutedIdx == -1) {
    continue;
  }
  s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
  int const scaleOut_idx =
      permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
  params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
}
```
The break statement at line 283 is inside an if (threadIdx.x == 0) block. This means only thread 0 will break out of the tokenInCtaIdx loop, while other threads in the block will continue. This divergence will cause a deadlock when __syncthreads() is called at line 295, as not all threads in the block will reach it. The continue at line 287 has the same problem.
A simple fix is to remove the break and continue and wrap the logic inside the if (threadIdx.x == 0) block with an additional check for tokenIdx and permutedIdx.
```cuda
if (threadIdx.x == 0) {
  auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
  if (tokenIdx < params.numTokens) {
    int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
    if (permutedIdx != -1) {
      s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
      int const scaleOut_idx =
          permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
      params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
    }
  }
}
```
```cuda
constexpr int NUM_ELTS_PER_SF = 128;
int const NUM_THREADS_PER_CTA = 128;
// ...
int const deviceId = 0;
```
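The hardcoded `deviceId` above is the medium-severity issue flagged earlier. A typical fix, sketched here as an assumption about the surrounding launcher code, queries the active device instead:

```cuda
// Query the device bound to the current context instead of assuming
// device 0, so SM-count-based grid sizing stays correct on multi-GPU
// hosts. Error checking elided for brevity.
int deviceId = 0;
cudaGetDevice(&deviceId);
int numSms = 0;
cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, deviceId);
```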
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 36d2463 and 99abadc97f3f8ef150735a1d83ceb3bca3fd2660.
📒 Files selected for processing (2)
- csrc/trtllm_fused_moe_dev_kernel.cu (3 hunks)
- include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
include/flashinfer/trtllm/common/cudaUtils.h (3)
float4 (173-175), float2 (169-171), float (165-167)
```cuda
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens) {
  break;
}

int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2_idx =
    permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
float const scale1 = params.inDqSfsPtr[scale1_idx];
float const scale2 = params.inDqSfsPtr[scale2_idx];

float x1 = scale1 * (float)params.inPtr[baseIdx];
float x2 = scale2 * (float)params.inPtr[baseIdx + params.innerDim / 2];

float act = silu(x2);
float out = act * x1;
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
permutedIdxArr[tokenInCtaIdx] = permutedIdx;
if (permutedIdx == -1) {
  continue;
}

// Process blocks for this CTA
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;

int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2Idx = permutedIdx + totalNumPaddedTokens *
                                        ((hiddenIdx / 128) + (params.innerDim / 2 / 128));

scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
dataX2Arr[tokenInCtaIdx] =
    static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
}

// The largest (finite) value that can be represented using E4m3.
float constexpr E4m3MaxVal{448.f};
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
  float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
  float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
  float act = silu(x2);
  float out = act * x1;
  outArr[tokenInCtaIdx] = out;
  absOutArr[tokenInCtaIdx] = fabsf(out);
}

// Compute the absolute max
#if CUDA_VERSION >= 12090
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum<>{});
#else
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cub::Max{});
#endif
if (threadIdx.x == 0) {
  s_scaleOut = aMax / E4m3MaxVal;
  int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
  params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);

#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
  if (threadIdx.x == 0) {
    auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
    if (tokenIdx >= params.numTokens) {
      break;
    }
    int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
    if (permutedIdx == -1) {
      continue;
    }
    s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
    int const scaleOut_idx =
        permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
    params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
  }
}
__syncthreads();
float const scaleOut = s_scaleOut;
__syncthreads();
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = (Type)(out / scaleOut);

#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
  auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
  if (tokenIdx >= params.numTokens) {
    break;
  }
  int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
  if (permutedIdx == -1) {
    continue;
  }
  float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
  int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
  params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
}
```
Fix uninitialized per-token buffers
When permutedIdx == -1 we continue before populating scale1/scale2/dataX*/out/absOut for that slot. The second pass still multiplies those arrays and feeds them into BlockReduce, so threads read stale values from the previous hidden block, producing bogus aMaxArr and corrupt outDqSfsPtr/outPtr for the remaining tokens in the CTA. Please zero/mark invalid tokens before skipping the load and short-circuit the later math.
```diff
@@
-      float outArr[NumTokensPerCta];
-      float absOutArr[NumTokensPerCta];
-      int permutedIdxArr[NumTokensPerCta];
+      float outArr[NumTokensPerCta];
+      float absOutArr[NumTokensPerCta];
+      bool validTokenArr[NumTokensPerCta];
+      int permutedIdxArr[NumTokensPerCta];
@@
-          int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
-          permutedIdxArr[tokenInCtaIdx] = permutedIdx;
-          if (permutedIdx == -1) {
-            continue;
-          }
+          int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
+          permutedIdxArr[tokenInCtaIdx] = permutedIdx;
+          validTokenArr[tokenInCtaIdx] = (permutedIdx != -1);
+          if (!validTokenArr[tokenInCtaIdx]) {
+            scale1Arr[tokenInCtaIdx] = 0.f;
+            scale2Arr[tokenInCtaIdx] = 0.f;
+            dataX1Arr[tokenInCtaIdx] = 0.f;
+            dataX2Arr[tokenInCtaIdx] = 0.f;
+            continue;
+          }
@@
-        for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
-          float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
-          float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
+        for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
+          if (!validTokenArr[tokenInCtaIdx]) {
+            outArr[tokenInCtaIdx] = 0.f;
+            absOutArr[tokenInCtaIdx] = 0.f;
+            continue;
+          }
+          float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
+          float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:

```cuda
float scale1Arr[NumTokensPerCta];
float scale2Arr[NumTokensPerCta];
float dataX1Arr[NumTokensPerCta];
float dataX2Arr[NumTokensPerCta];
float outArr[NumTokensPerCta];
float absOutArr[NumTokensPerCta];
bool validTokenArr[NumTokensPerCta];
int permutedIdxArr[NumTokensPerCta];
// Loop over tokens
for (int k = blockIdx.z; k < params.topK; k += gridDim.z) {
  for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens;
       tokenCtaIdx += gridDim.y * NumTokensPerCta) {
    for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
         hiddenIdx += blockDim.x * gridDim.x) {
#pragma unroll
      for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
        int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
        if (tokenIdx >= params.numTokens) {
          break;
        }
        int const expandedIdx = tokenIdx * params.topK + k;
        int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
        permutedIdxArr[tokenInCtaIdx] = permutedIdx;
        validTokenArr[tokenInCtaIdx] = (permutedIdx != -1);
        if (!validTokenArr[tokenInCtaIdx]) {
          scale1Arr[tokenInCtaIdx] = 0.f;
          scale2Arr[tokenInCtaIdx] = 0.f;
          dataX1Arr[tokenInCtaIdx] = 0.f;
          dataX2Arr[tokenInCtaIdx] = 0.f;
          continue;
        }
        // Process blocks for this CTA
        int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
        int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
        int const scale2Idx = permutedIdx + totalNumPaddedTokens *
                                                ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
        scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
        scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
        dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
        dataX2Arr[tokenInCtaIdx] =
            static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
      }
#pragma unroll
      for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
        if (!validTokenArr[tokenInCtaIdx]) {
          outArr[tokenInCtaIdx] = 0.f;
          absOutArr[tokenInCtaIdx] = 0.f;
          continue;
        }
        float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
        float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
        float act = silu(x2);
        float out = act * x1;
        outArr[tokenInCtaIdx] = out;
        absOutArr[tokenInCtaIdx] = fabsf(out);
      }
      auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
      auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
      auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);
#pragma unroll
      for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
        if (threadIdx.x == 0) {
          auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
          if (tokenIdx >= params.numTokens) {
            break;
          }
          int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
          if (permutedIdx == -1) {
            continue;
          }
          s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
          int const scaleOut_idx =
              permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
          params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
        }
      }
      __syncthreads();
#pragma unroll
      for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
        auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
        if (tokenIdx >= params.numTokens) {
          break;
        }
        int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
        if (permutedIdx == -1) {
          continue;
        }
        float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
        int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
        params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
      }
```
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_dev_kernel.cu around lines 222–310, the per-token
scratch arrays are left uninitialized when permutedIdx == -1, which lets stale
values flow into the later math and BlockReduce; fix by marking/clearing those
slots when skipping loads and short-circuiting later computation: when
permutedIdx == -1 set scale1Arr/scale2Arr/dataX1Arr/dataX2Arr/outArr/absOutArr
entries to 0 (and keep permutedIdxArr == -1), then in the subsequent math loops
check permutedIdxArr[token]==-1 and either skip computing x1/x2/act/out (leaving
zeros) or continue, ensure BlockReduce sees zeros for invalid tokens, and avoid
writing to params.outDqSfsPtr and params.outPtr for permutedIdx == -1 (i.e.,
gate those writes with the same check).
/bot run
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>

Force-pushed from 99abadc to 161bf9a (Compare)
/bot kill

/bot run

@nekorobov is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run
Actionable comments posted: 0
🧹 Nitpick comments (2)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2)
51-51: Consider conditional creation of routing_bias.

The `routing_bias` tensor is created unconditionally but only used in the block-scale path (line 102). The per-tensor path explicitly passes `None` (line 129). Consider creating this tensor only when `is_block_scale` is True to avoid unnecessary allocation. Apply this diff to conditionally create routing_bias:

```diff
- routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
+ routing_bias = (
+     torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
+     if quant_mode == "Fp8-Block"
+     else None
+ )
```
59-81: Consider using actual block quantization for realistic benchmarking.

The comment indicates that block-scale quantization is "too slow" for benchmarking, so per-tensor quantization is used and then broadcast to block shape. This approach creates artificial block scales where all blocks have identical values, which may not accurately represent the performance characteristics of real block-scale quantization in production.

For more realistic benchmarking results, consider either:

- Implementing optimized block quantization for the benchmark
- Documenting that this benchmark measures kernel performance in isolation, not end-to-end quantization overhead
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (3)
- flashinfer/fused_moe/core.py (4): trtllm_fp8_block_scale_moe (1761-1843), WeightLayout (161-168), RoutingMethodType (58-72), trtllm_fp8_per_tensor_scale_moe (1680-1758)
- csrc/trtllm_fused_moe_kernel_launcher.cu (4): trtllm_fp8_block_scale_moe (709-760), trtllm_fp8_block_scale_moe (709-716), trtllm_fp8_per_tensor_scale_moe (352-412), trtllm_fp8_per_tensor_scale_moe (352-360)
- include/flashinfer/trtllm/fused_moe/runner.h (5): hidden_size (265-265), num_experts (263-263), intermediate_size (275-275), top_k (270-270), RoutingMethodType (37-136)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
- 100-125: Do not assign a lambda expression, use a def — rewrite `fn` as a `def` (E731)
- 127-149: Do not assign a lambda expression, use a def — rewrite `fn` as a `def` (E731)
🔇 Additional comments (8)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (8)
- 14-15: LGTM! The new imports are correctly added to support the FP8-Block quantization path.
- 26-26: LGTM! The simplified max calculation is cleaner and functionally equivalent.
- 34-34: LGTM! The quant_mode parameter with Literal type hint provides clear API documentation.
- 83-97: LGTM! The conditional scale tensor setup correctly handles the different API requirements for per-tensor vs block-scale kernels.
- 99-149: Document the different routing configurations. The block-scale path uses `DeepSeekV3` routing with group parameters (n_group=8, topk_group=4) and routed_scaling_factor=2.5, while the per-tensor path uses `TopK` routing with routed_scaling_factor=1.0. These different configurations mean the two quantization modes are benchmarked under different routing scenarios, which may not provide a direct performance comparison. If this is intentional to reflect typical usage patterns for each quantization mode, consider adding a comment explaining the rationale; otherwise, consider aligning the routing configurations for a more apples-to-apples comparison.
- 324-330: LGTM! The CLI choices are correctly extended to support the new FP8 quantization modes.
- 355-366: LGTM! The conditional dispatch correctly routes FP8 quantization modes to the appropriate benchmark function.
- 195-195: LGTM! The addition of the `is_sf_swizzled_layout=False` parameter propagates the new API requirement without changing existing behavior.
jiahanc left a comment:

LGTM, thanks for the work!

IwakuraRein left a comment:

Thx for your contribution
[FAILED] Pipeline #38094149: 13/17 passed
📌 Description

- Small optimization to the activation kernel for block-FP8 MoE for large batch size.

| BS | Baseline, us | Optimized, us |
| --- | --- | --- |
| 1 | 2.4 | 2.1 |
| 32 | 3.5 | 2.6 |
| 256 | 21.7 | 8.7 |
| 1024 | 84.4 | 23.8 |
| 4096 | 333 | 87.0 |
| 16384 | 1330 | 365 |

- Adding micro-benchmark for DS FP8 implemented by @IwakuraRein.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes

Summary by CodeRabbit

- New Features
  - Improved Mixture-of-Experts inference with configurable multi-token batching per GPU core for higher throughput.
  - Expanded FP8 quantization with a new block-scale mode and dynamic, hardware-aware kernel scheduling for better utilization and numerical stability.
  - Vectorized max-reduction and per-block scaling to accelerate reductions and improve output scaling precision.
  - Autotuner/CLI now exposes the FP8 block quantization option for tuning.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>