
Fix silent bug with FP8 per tensor non-gated MoE#2882

Merged
aleozlx merged 3 commits into flashinfer-ai:main from danisereb:fix_fp8_non_gated
Apr 14, 2026

Conversation

@danisereb
Contributor

@danisereb danisereb commented Mar 24, 2026

📌 Description

This PR fixes a silent bug that forces the fallback tactic for TRTLLM MoE with FP8 per-tensor quantization and non-gated models (Nemotron).

[Autotuner]: Failed to get valid tactics for ... Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6

There is no crash, only a performance bug (the fallback tactic is always selected).

The error raised from the C++ code is swallowed by this try/except in the function get_trtllm_moe_sm100_module:
https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/fused_moe/core.py#L924

Code:

            if instance_key not in MoERunner.valid_tactics_dict:
                try:
                    valid_tactics = moe_op.trtllm_get_valid_moe_configs(*instance_key)
                except Exception as e:
                    logger.debug(
                        f"[Autotuner]: Failed to get valid tactics for {instance_key}. Error occurred: {e}"
                    )
                    return []
                MoERunner.valid_tactics_dict[instance_key] = valid_tactics
            return MoERunner.valid_tactics_dict[instance_key]

The error only appears when FLASHINFER_LOGGING_LEVEL=DEBUG is enabled, because it is reported via logger.debug.

This bug does not exist in flashinfer version 0.6.6.

With version 0.6.6

commit 70b142b75b46aa56e7f675a8e6ec1a977352c91f

# from flashinfer clone path
git checkout v0.6.6

uv pip install -v .

flashinfer clear-cache

export FLASHINFER_LOGGING_LEVEL=DEBUG

python -m pytest --maxfail 1 -sv --tb short \
tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing  \
-k "Relu2 and nemotron_3_dummy and FP8_PerTensor"

Output:

tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing[Relu2-Shuffled_MajorK-nemotron_3_dummy-FP8_PerTensor-2944-1024-8] 2026-03-24 12:42:51,264 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2026-03-24 12:42:51,265 - DEBUG - cubin_loader.py:182 - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/checksums.txt
2026-03-24 12:42:51,266 - DEBUG - cubin_loader.py:182 - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/include/flashinferMetaInfo.h
2026-03-24 12:42:51,270 - DEBUG - cubin_loader.py:182 - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h
2026-03-24 12:42:51,271 - DEBUG - cubin_loader.py:182 - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
...
2026-03-24 12:59:50,604 - DEBUG - autotuner.py:750 - flashinfer.jit: [Autotuner]: generated profile: OptimizationProfile(shapes=[[DynamicDim(min=1, opt=1, max=2), StaticDim(val=1024)], [DynamicDim(min=1, opt=1, max=2), StaticDim(val=512)], [DynamicDim(min=1, opt=1, max=2), StaticDim(val=22)], [DynamicDim(min=1, opt=1, max=2), StaticDim(val=22)], [DynamicDim(min=1, opt=1, max=2), StaticDim(val=1024)]], tensor_initializers=[<function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d260>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155eb60>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ed40>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d120>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ea20>])
2026-03-24 12:59:50,604 - DEBUG - autotuner.py:750 - flashinfer.jit: [Autotuner]: generated profile: OptimizationProfile(shapes=[[DynamicDim(min=2, opt=2, max=4), StaticDim(val=1024)], [DynamicDim(min=2, opt=2, max=4), StaticDim(val=512)], [DynamicDim(min=2, opt=2, max=4), StaticDim(val=22)], [DynamicDim(min=2, opt=2, max=4), StaticDim(val=22)], [DynamicDim(min=2, opt=2, max=4), StaticDim(val=1024)]], tensor_initializers=[<function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d260>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155eb60>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ed40>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d120>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ea20>])
2026-03-24 12:59:50,604 - DEBUG - autotuner.py:750 - flashinfer.jit: [Autotuner]: generated profile: OptimizationProfile(shapes=[[DynamicDim(min=4, opt=4, max=8), StaticDim(val=1024)], [DynamicDim(min=4, opt=4, max=8), StaticDim(val=512)], [DynamicDim(min=4, opt=4, max=8), StaticDim(val=22)], [DynamicDim(min=4, opt=4, max=8), StaticDim(val=22)], [DynamicDim(min=4, opt=4, max=8), StaticDim(val=1024)]], tensor_initializers=[<function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d260>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155eb60>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ed40>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d120>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ea20>])
...
2026-03-24 13:00:09,444 - DEBUG - autotuner.py:750 - flashinfer.jit: [Autotuner]: generated profile: OptimizationProfile(shapes=[[DynamicDim(min=2048, opt=2048, max=4096), StaticDim(val=1024)], [DynamicDim(min=2048, opt=2048, max=4096), StaticDim(val=512)], [DynamicDim(min=2048, opt=2048, max=4096), StaticDim(val=22)], [DynamicDim(min=2048, opt=2048, max=4096), StaticDim(val=22)], [DynamicDim(min=2048, opt=2048, max=4096), StaticDim(val=1024)]], tensor_initializers=[<function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d260>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155eb60>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ed40>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d120>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ea20>])
2026-03-24 13:00:09,444 - DEBUG - autotuner.py:750 - flashinfer.jit: [Autotuner]: generated profile: OptimizationProfile(shapes=[[DynamicDim(min=4096, opt=4096, max=inf), StaticDim(val=1024)], [DynamicDim(min=4096, opt=4096, max=inf), StaticDim(val=512)], [DynamicDim(min=4096, opt=4096, max=inf), StaticDim(val=22)], [DynamicDim(min=4096, opt=4096, max=inf), StaticDim(val=22)], [DynamicDim(min=4096, opt=4096, max=inf), StaticDim(val=1024)]], tensor_initializers=[<function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d260>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155eb60>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ed40>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155d120>, <function get_trtllm_moe_sm100_module.<locals>.MoERunner.<lambda> at 0xfffbd155ea20>])
2026-03-24 13:00:09,445 - INFO - autotuner.py:268 - flashinfer.jit: [Autotuner]: Autotuning process ends
PASSED
...
==================================================================================== 3 passed, 51 skipped, 6426 deselected in 1043.61s (0:17:23) =====================================================================================

With main branch

commit b8931925fbda0b4a77a79ac4b6577a1da235f605 (currently tagged as v0.6.7)

# from flashinfer clone path
git checkout v0.6.7

uv pip install -v .

flashinfer clear-cache

export FLASHINFER_LOGGING_LEVEL=DEBUG

python -m pytest --maxfail 1 -sv --tb short \
tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing  \
-k "Relu2 and nemotron_3_dummy and FP8_PerTensor"

Output:

tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing[Relu2-Shuffled_MajorK-nemotron_3_dummy-FP8_PerTensor-2944-1024-8] 2026-03-24 12:28:10,611 - INFO - autotuner.py:446 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
...
Traceback (most recent call last):
  File "<unknown>", line 0, in __tvm_ffi_trtllm_get_valid_moe_configs
  File "/opt/flashinfer/flashinfer/data/csrc/trtllm_fused_moe_kernel_launcher.cu", line 2234, in tvm::ffi::Array<tvm::ffi::Array<long int> > flashinfer::trtllm_get_valid_moe_configs(int64_t, int64_t, Fp8QuantizationType, int64_t, int64_t, int64_t, int64_t, int64_t, bool, int64_t, int64_t)
NotImplementedError: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:01,582 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for (<DtypeTrtllmGen.E4m3: 1050629>, <DtypeTrtllmGen.E4m3: 1050629>, <Fp8QuantizationType.NoneFp8: 0>, 22, 1024, 2944, 512, <ActivationType.Relu2: 6>, True, <WeightLayout.MajorK: 0>, 1). Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.
Traceback (most recent call last):
  File "<unknown>", line 0, in __tvm_ffi_trtllm_get_valid_moe_configs
  File "/opt/flashinfer/flashinfer/data/csrc/trtllm_fused_moe_kernel_launcher.cu", line 2234, in tvm::ffi::Array<tvm::ffi::Array<long int> > flashinfer::trtllm_get_valid_moe_configs(int64_t, int64_t, Fp8QuantizationType, int64_t, int64_t, int64_t, int64_t, int64_t, bool, int64_t, int64_t)
NotImplementedError: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:01,583 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for (<DtypeTrtllmGen.E4m3: 1050629>, <DtypeTrtllmGen.E4m3: 1050629>, <Fp8QuantizationType.NoneFp8: 0>, 22, 1024, 2944, 512, <ActivationType.Relu2: 6>, True, <WeightLayout.MajorK: 0>, 2). Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.
...
Traceback (most recent call last):
  File "<unknown>", line 0, in __tvm_ffi_trtllm_get_valid_moe_configs
  File "/opt/flashinfer/flashinfer/data/csrc/trtllm_fused_moe_kernel_launcher.cu", line 2234, in tvm::ffi::Array<tvm::ffi::Array<long int> > flashinfer::trtllm_get_valid_moe_configs(int64_t, int64_t, Fp8QuantizationType, int64_t, int64_t, int64_t, int64_t, int64_t, bool, int64_t, int64_t)
NotImplementedError: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:15,119 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for (<DtypeTrtllmGen.E4m3: 1050629>, <DtypeTrtllmGen.E4m3: 1050629>, <Fp8QuantizationType.NoneFp8: 0>, 22, 1024, 2944, 512, <ActivationType.Relu2: 6>, True, <WeightLayout.MajorK: 0>, 512). Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.
Traceback (most recent call last):
  File "<unknown>", line 0, in __tvm_ffi_trtllm_get_valid_moe_configs
  File "/opt/flashinfer/flashinfer/data/csrc/trtllm_fused_moe_kernel_launcher.cu", line 2234, in tvm::ffi::Array<tvm::ffi::Array<long int> > flashinfer::trtllm_get_valid_moe_configs(int64_t, int64_t, Fp8QuantizationType, int64_t, int64_t, int64_t, int64_t, int64_t, bool, int64_t, int64_t)
NotImplementedError: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:15,119 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for (<DtypeTrtllmGen.E4m3: 1050629>, <DtypeTrtllmGen.E4m3: 1050629>, <Fp8QuantizationType.NoneFp8: 0>, 22, 1024, 2944, 512, <ActivationType.Relu2: 6>, True, <WeightLayout.MajorK: 0>, 1024). Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.
Traceback (most recent call last):
  File "<unknown>", line 0, in __tvm_ffi_trtllm_get_valid_moe_configs
  File "/opt/flashinfer/flashinfer/data/csrc/trtllm_fused_moe_kernel_launcher.cu", line 2234, in tvm::ffi::Array<tvm::ffi::Array<long int> > flashinfer::trtllm_get_valid_moe_configs(int64_t, int64_t, Fp8QuantizationType, int64_t, int64_t, int64_t, int64_t, int64_t, bool, int64_t, int64_t)
NotImplementedError: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:15,119 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for (<DtypeTrtllmGen.E4m3: 1050629>, <DtypeTrtllmGen.E4m3: 1050629>, <Fp8QuantizationType.NoneFp8: 0>, 22, 1024, 2944, 512, <ActivationType.Relu2: 6>, True, <WeightLayout.MajorK: 0>, 2048). Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.
Traceback (most recent call last):
  File "<unknown>", line 0, in __tvm_ffi_trtllm_get_valid_moe_configs
  File "/opt/flashinfer/flashinfer/data/csrc/trtllm_fused_moe_kernel_launcher.cu", line 2234, in tvm::ffi::Array<tvm::ffi::Array<long int> > flashinfer::trtllm_get_valid_moe_configs(int64_t, int64_t, Fp8QuantizationType, int64_t, int64_t, int64_t, int64_t, int64_t, bool, int64_t, int64_t)
NotImplementedError: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:15,120 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for (<DtypeTrtllmGen.E4m3: 1050629>, <DtypeTrtllmGen.E4m3: 1050629>, <Fp8QuantizationType.NoneFp8: 0>, 22, 1024, 2944, 512, <ActivationType.Relu2: 6>, True, <WeightLayout.MajorK: 0>, 4096). Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.
2026-03-24 12:36:15,137 - DEBUG - cubin_loader.py:182 - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/Bmm_E4m3_E4m3E4m3_Fp32_t128x64x256u2_s5_et128x64_m256x64x32_cga2x1x1_16dp256b_rM_TN_transOut_tokSfB_schPd2x1x2x3_relu2_bN_tma_tmaSf_rgTma_clmp_lbW8_dynB_sm100f.cubin
2026-03-24 12:36:15,138 - DEBUG - cubin_loader.py:182 - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/Bmm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s8_et128x64_m256x64x32_cga2x1x1_16dp256b_rM_TN_transOut_schPd2x1x2x3_bN_rgTma_clmp_dynB_sm100f.cubin
2026-03-24 12:36:15,138 - INFO - autotuner.py:455 - flashinfer.jit: [Autotuner]: Autotuning process ends
PASSED
...
===================================================================================== 3 passed, 51 skipped, 6426 deselected in 489.92s (0:08:09) =====================================================================================

In this case we can see the error removed by this PR:
2026-03-24 12:36:01,582 - DEBUG - core.py:925 - flashinfer.jit: [Autotuner]: Failed to get valid tactics for ... Error occurred: FP8 per-tensor currently supports gated activations only, got act_type=6.

But the test does not fail.

With this fix we are back to the 0.6.6 behavior (no "Failed to get valid tactics" messages).

Notes:
  • All of the above was tested on a GB200 GPU.
  • The artifacts path for TRTLLM_GEN_BMM did not change between 0.6.6 and 0.6.7: b55211623be7f5697c5262ffd8361fc06c147bc9.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • FP8 per-tensor quantization now accepts a broader range of activation types (no longer rejects certain non-gated activations), improving compatibility and reducing errors during model quantization.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 24, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between acca1e8 and 5091519.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
💤 Files with no reviewable changes (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu

📝 Walkthrough


Removed the activation-type validation guard in the FP8 per-tensor branch of trtllm_get_valid_moe_configs, so the code now forwards all activation types to Fp8PerTensorLauncher::getValidConfigs(...) for the PerTensorFp8 / NoneFp8 (E4m3) path.
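For context, the following is an illustrative C++ sketch of the kind of guard that was removed and of what the branch does now. It is not the actual code from csrc/trtllm_fused_moe_kernel_launcher.cu; the real launcher reports errors through tvm::ffi, and the enum values and helper bodies here (other than the names quoted in the walkthrough) are assumptions:

    // Illustrative sketch only; not the real flashinfer launcher code.
    #include <vector>

    // Relu2 = 6 matches the act_type=6 in the logs above; the SwiGlu value is assumed.
    enum class ActivationType { SwiGlu = 3, Relu2 = 6 };

    inline bool isGatedActivation(ActivationType t) {
      // Gated (GLU-style) activations combine a gate and an up projection.
      return t == ActivationType::SwiGlu;
    }

    std::vector<std::vector<long>> getValidFp8PerTensorConfigs(ActivationType act_type) {
      // Pre-fix guard (removed by this PR): rejecting non-gated activations here
      // made the tactic query throw, which core.py swallowed at DEBUG level,
      // leaving the autotuner with an empty tactic list and forcing the fallback.
      // if (!isGatedActivation(act_type)) {
      //   throw std::runtime_error(
      //       "FP8 per-tensor currently supports gated activations only");
      // }

      // Post-fix: always ask the FP8 per-tensor launcher for its valid configs,
      // e.g. return Fp8PerTensorLauncher::getValidConfigs(...);
      return {};  // placeholder so this sketch is self-contained
    }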

Changes

Cohort / File(s): FP8 Per-Tensor MoE Validation (csrc/trtllm_fused_moe_kernel_launcher.cu)
Summary: Removed the isGatedActivation(activation_type) check and associated NotImplementedError in trtllm_get_valid_moe_configs; always dispatches to Fp8PerTensorLauncher::getValidConfigs(...) for the PerTensorFp8/NoneFp8 (E4m3) branch.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • cyx-6
  • yongwww
  • samuellees
  • bkryu
  • nv-yunzheq
  • jimmyzho


🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: fixing a silent bug with FP8 per-tensor support for non-gated MoE models.
Description check ✅ Passed The description comprehensively documents the bug, its impact, reproduction steps, comparison with v0.6.6, and test results; all required template sections are addressed.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a silent bug by lifting an unnecessary constraint within the FP8 per-tensor quantization logic for Mixture of Experts (MoE) models. Previously, the system incorrectly enforced that FP8 per-tensor quantization could only be applied to gated activations, which prevented its use in non-gated MoE setups. The change expands the compatibility of FP8 per-tensor quantization, allowing for greater flexibility and efficiency in MoE deployments.

Highlights

  • FP8 Per-Tensor Quantization: Removed a previous restriction that limited FP8 per-tensor quantization to only gated activations, thereby enabling its use with non-gated Mixture of Experts (MoE) configurations.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request removes the restriction on using non-gated activations for FP8 per-tensor MoE configurations. However, this change exposes an existing bug in Fp8PerTensorLauncher::prepare_moe where buffer allocations for gemm1_output and gemm1_output_scale are hardcoded with a multiplier of 2. This hardcoded value is only correct for gated activations and will lead to inefficient memory allocation for non-gated activations. It is recommended to update these allocations to use intermediate_size_factor for correct handling of both gated and non-gated activations.
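A minimal sketch of the sizing change this review is suggesting follows; the function, parameter, and helper names (allocGemm1Output, num_tokens, elem_bytes) are hypothetical stand-ins, not the actual members of Fp8PerTensorLauncher::prepare_moe:

    // Hypothetical sketch of the suggested allocation; names are stand-ins.
    #include <cstdint>
    #include <vector>

    enum class ActivationType { SwiGlu = 3, Relu2 = 6 };  // values assumed for illustration

    inline bool isGatedActivation(ActivationType t) { return t == ActivationType::SwiGlu; }

    // Gated activations need both the gate and up projections in GEMM1's output
    // (row width 2 * intermediate_size); non-gated activations such as Relu2 need
    // only intermediate_size, so a hard-coded factor of 2 over-allocates for them.
    std::vector<uint8_t> allocGemm1Output(int64_t num_tokens, int64_t intermediate_size,
                                          ActivationType act_type, size_t elem_bytes) {
      const int64_t intermediate_size_factor = isGatedActivation(act_type) ? 2 : 1;
      // Before (as flagged): num_tokens * 2 * intermediate_size regardless of act_type.
      // Suggested: derive the buffer width from the activation type instead.
      return std::vector<uint8_t>(
          static_cast<size_t>(num_tokens * intermediate_size_factor * intermediate_size) * elem_bytes);
    }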

Comment thread csrc/trtllm_fused_moe_kernel_launcher.cu
@aleozlx
Collaborator

aleozlx commented Mar 24, 2026

/bot run

@aleozlx aleozlx enabled auto-merge (squash) March 24, 2026 20:37
@flashinfer-bot
Collaborator

GitLab MR !459 has been created, and the CI pipeline #46907315 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx added the v0.6.8 release blocker label for 0.6.8 label Mar 24, 2026
@flashinfer-bot
Collaborator

[FAILED] Pipeline #46907315: 12/20 passed

TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/flashinfer that referenced this pull request Mar 25, 2026
Cherry-picked from upstream flashinfer-ai#2882 (3b0244b).
Removes incorrect guard that rejected non-gated activations (Relu2/Nemotron)
for FP8 per-tensor, silently forcing fallback to slower tactic.
@aleozlx
Collaborator

aleozlx commented Apr 13, 2026

CI was blocked on a known, unrelated nvshmem compilation error. Restarted CI and waiting for auto-merge.

@aleozlx aleozlx mentioned this pull request Apr 14, 2026
@aleozlx aleozlx merged commit 57c37bf into flashinfer-ai:main Apr 14, 2026
29 of 30 checks passed

Labels

op: moe run-ci v0.6.8 release blocker label for 0.6.8
