
fix(jit): GEMM kernels produce NaN under concurrency — missing GDC flags cause PDL synchronization barriers to compile as no-ops #2716

Merged

yzh119 merged 1 commit into flashinfer-ai:main from voipmonitor:fix/sm120-fp4-pdl-gdc-flags on Mar 9, 2026

Conversation

@voipmonitor (Contributor) commented Mar 7, 2026

Summary

All CUTLASS GEMM templates use enablePDL=true (Programmatic Dependent Launch), but the JIT compilation is missing -DCUTLASS_ENABLE_GDC_FOR_SM100=1 and -DCUTLASS_ENABLE_GDC_FOR_SM90=1 compile flags. Without these flags, wait_on_dependent_grids() and launch_dependent_grids() in CUTLASS grid_dependency_control.h compile as empty no-ops, eliminating the synchronization barriers needed for safe PDL execution.

Root Cause

In cutlass/include/cutlass/arch/grid_dependency_control.h:

CUTLASS_DEVICE void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))  // only defined when CUTLASS_ENABLE_GDC_FOR_SM100 is set
  asm volatile("griddepcontrol.wait;");
#endif
}

The CUTLASS_GDC_ENABLED macro is only defined when CUTLASS_ENABLE_GDC_FOR_SM100 is passed as a compile flag. Without it, PDL launches kernels with overlap enabled at the host level (cudaLaunchAttributeProgrammaticStreamSerialization), but the device-side synchronization barriers are compiled out — creating a race condition.
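Because the gate is purely compile-time, the regression is silent: the kernel still builds and runs, only the barrier instruction disappears. A minimal sketch of a guard that could catch a missing flag before building — helper and constant names here are hypothetical, not FlashInfer API:

```python
# Hypothetical guard (not FlashInfer API): verify that a JIT module's
# nvcc flag list actually defines the GDC macro before a PDL-enabled
# kernel is built. `extra_cuda_cflags` mirrors the argument name used
# by flashinfer's gen_jit_spec.

REQUIRED_GDC_FLAG = "-DCUTLASS_ENABLE_GDC_FOR_SM100=1"

def gdc_barriers_compiled_in(extra_cuda_cflags):
    """True only if CUTLASS will emit griddepcontrol.wait / launch_dependents."""
    return REQUIRED_GDC_FLAG in extra_cuda_cflags

# Without the flag, PDL overlap is still requested on the host,
# but the device-side wait compiles to an empty function.
assert not gdc_barriers_compiled_in(["-DENABLE_BF16", "-DENABLE_FP4"])
assert gdc_barriers_compiled_in(["-DENABLE_BF16", REQUIRED_GDC_FLAG])
```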

Symptoms

On SM120 (Blackwell RTX PRO 6000 / RTX 5090) with high concurrency (64+ simultaneous requests in SGLang with TP=8):

  • CUTLASS FP4 GEMM intermittently fails to write output tiles
  • Unwritten tiles contain uninitialized memory (NaN/garbage)
  • NaN blocks are always contiguous and 128-aligned, matching CTA tile boundaries
  • CUDA_LAUNCH_BLOCKING=1 eliminates the bug (confirms race condition)
  • cudnn backend is unaffected (does not use CUTLASS PDL)
  • Retry with identical inputs produces correct output

Fix

Add -DCUTLASS_ENABLE_GDC_FOR_SM100=1 and -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to all affected GEMM JIT modules:

  • fp4_gemm_cutlass (SM100)
  • fp4_gemm_cutlass_sm103 (SM103)
  • fp4_gemm_cutlass_sm120 (SM120)
  • fp8_gemm_cutlass (SM100)
  • mxfp8_gemm_cutlass (SM100)
  • gemm_sm120 (SM120 FP8 groupwise)

The `tgv_gemm` module already had `-DCUTLASS_ENABLE_GDC_FOR_SM100`.

Note: `-DCUTLASS_ENABLE_GDC_FOR_SM90` is needed because the SM120 CUTLASS kernel (`sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp`) guards `launch_dependent_grids()` with `#ifdef CUTLASS_ENABLE_GDC_FOR_SM90` instead of the SM100 macro (upstream CUTLASS bug).
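As a sketch, the fix amounts to appending the same two defines to each affected module's `extra_cuda_cflags`; names below (the constant and the helper) are illustrative, while the actual change adds the literals to each generator in `flashinfer/jit/gemm/core.py`:

```python
# Sketch of the fix (constant and function names hypothetical): append
# the same two GDC defines to each affected module's extra_cuda_cflags.
GDC_FLAGS = [
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]

def build_fp4_cflags(base_flags):
    # Return a new list; don't mutate the caller's flags.
    return list(base_flags) + ["-DENABLE_BF16", "-DENABLE_FP4"] + GDC_FLAGS

flags = build_fp4_cflags(["-O3"])
assert flags[-2:] == GDC_FLAGS
```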

Verification

| Configuration | Result |
|---|---|
| PDL=true, no GDC flags (current) | NaN crash under high concurrency |
| PDL=false (workaround) | OK |
| PDL=true + GDC flags (this PR) | OK — tested with 64 concurrent requests, multiple SGLang restarts from JIT cache |
| CUDA_LAUNCH_BLOCKING=1 | OK (confirms race condition) |

Environment

  • Hardware: 8x NVIDIA RTX PRO 6000 Blackwell (SM120, 96GB)
  • FlashInfer 0.6.4, CUTLASS 4.4.1
  • SGLang with TP=8, EAGLE-v2, GLM-5-NVFP4-MTP model
  • PyTorch 2.12.0.dev, CUDA 12.8+

Summary by CodeRabbit

Release Notes

Chores

  • Updated CUDA compilation configuration for SM100 and SM90 GPU architectures, enhancing build optimization and extending hardware compatibility for GPU acceleration workloads.

…MM kernels

All CUTLASS GEMM templates use `enablePDL=true` (Programmatic Dependent
Launch), but the JIT compilation was missing `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`
and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` flags. Without these flags,
`wait_on_dependent_grids()` and `launch_dependent_grids()` in CUTLASS
`grid_dependency_control.h` compile as empty no-ops, eliminating the
synchronization barriers needed for safe PDL execution.

This causes a race condition where dependent kernels read stale data from
global memory before the previous kernel completes its writeback, resulting in
intermittent output tile corruption (128-aligned NaN blocks in output tensors)
under high concurrency on SM120 (Blackwell RTX PRO 6000 / RTX 5090) and
potentially SM100.

The fix adds both GDC flags to all affected GEMM JIT modules:
- SM100 FP4 (fp4_gemm_cutlass)
- SM103 FP4 (fp4_gemm_cutlass_sm103)
- SM120 FP4 (fp4_gemm_cutlass_sm120)
- SM100 FP8 (fp8_gemm_cutlass)
- SM100 MXFP8 (mxfp8_gemm_cutlass)
- SM120 FP8 groupwise (gemm_sm120)

Note: The TGV GEMM module already had `-DCUTLASS_ENABLE_GDC_FOR_SM100` but was
missing `-DCUTLASS_ENABLE_GDC_FOR_SM90` (not added here as it uses a different
code path).

Fixes flashinfer-ai#2708

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai Bot commented Mar 7, 2026

Walkthrough

This PR adds two CUDA preprocessor compiler flags (-DCUTLASS_ENABLE_GDC_FOR_SM100=1 and -DCUTLASS_ENABLE_GDC_FOR_SM90=1) to Cutlass GEMM module builds in the JIT compilation pipeline to enable GDC support for SM100 and SM90 GPU targets.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Cutlass GEMM GDC Flags (flashinfer/jit/gemm/core.py) | Added -DCUTLASS_ENABLE_GDC_FOR_SM100=1 and -DCUTLASS_ENABLE_GDC_FOR_SM90=1 preprocessor flags across multiple GEMM generation paths (including standard, FP8, and FP4 variants) to enable GDC compilation support for SM100 and SM90 architectures. |

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

Possibly related issues

  • Issue #2708: Directly addresses the missing GDC compile flags that caused PDL race condition and output-tile corruption on SM100/SM90 targets.

Possibly related PRs

  • PR #2549: Also modifies SM100 Cutlass GEMM build configuration (adds new SM100 mxfp8 generator alongside this PR's GDC flag additions).

Suggested labels

run-ci, v0.6.4

Suggested reviewers

  • yzh119
  • jimmyzho
  • nv-yunzheq


🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation / Resolution |
|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. Write docstrings for the functions missing them. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description check | ✅ Passed | The description is comprehensive and addresses the required template sections, including a clear summary, related issues, and detailed technical context about the problem and fix. |
| Title check | ✅ Passed | The title accurately describes the main change: adding missing GDC compile flags to fix NaN in GEMM kernels under concurrency. |


@gemini-code-assist Bot commented

Summary of Changes

This pull request resolves a significant race condition within CUTLASS GEMM kernels by integrating crucial compile flags into the JIT compilation process. The absence of these flags led to device-side synchronization barriers being compiled out, causing data corruption and incorrect outputs, particularly under high concurrency. By adding the CUTLASS_ENABLE_GDC_FOR_SM100 and CUTLASS_ENABLE_GDC_FOR_SM90 flags, the change ensures that Programmatic Dependent Launch (PDL) operations correctly synchronize, thereby stabilizing GEMM computations across various architectures.

Highlights

  • Fixes Race Condition: Addressed a critical race condition in CUTLASS GEMM kernels by enabling proper device-side synchronization for Programmatic Dependent Launch (PDL) operations.
  • Missing Compile Flags: Introduced -DCUTLASS_ENABLE_GDC_FOR_SM100=1 and -DCUTLASS_ENABLE_GDC_FOR_SM90=1 compile flags to various JIT modules, which were previously missing.
  • Affected Modules: Applied the necessary flags to fp4_gemm_cutlass (SM100, SM103, SM120), fp8_gemm_cutlass (SM100), mxfp8_gemm_cutlass (SM100), and gemm_sm120 modules.
  • Symptoms Addressed: Resolved intermittent failures, uninitialized memory (NaN/garbage) in output tiles, and contiguous 128-aligned NaN blocks observed under high concurrency on SM120 hardware.


Changelog
  • flashinfer/jit/gemm/core.py
    • Added DCUTLASS_ENABLE_GDC_FOR_SM100=1 and DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the gen_gemm_sm100_module_cutlass_fp4 function's extra_cuda_cflags.
    • Added DCUTLASS_ENABLE_GDC_FOR_SM100=1 and DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the gen_gemm_sm103_module_cutlass_fp4 function's extra_cuda_cflags.
    • Added DCUTLASS_ENABLE_GDC_FOR_SM100=1 and DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the gen_gemm_sm120_module_cutlass_fp4 function's extra_cuda_cflags.
    • Added DCUTLASS_ENABLE_GDC_FOR_SM100=1 and DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the gen_gemm_sm100_module_cutlass_fp8 function's extra_cuda_cflags.
    • Added DCUTLASS_ENABLE_GDC_FOR_SM100=1 and DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the gen_gemm_sm100_module_cutlass_mxfp8 function's extra_cuda_cflags.
    • Added DCUTLASS_ENABLE_GDC_FOR_SM100=1 and DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the gen_gemm_sm120_module function's extra_cuda_cflags.

@gemini-code-assist Bot left a comment

Code Review

This pull request correctly addresses a race condition in CUTLASS GEMM kernels by adding the missing CUTLASS_ENABLE_GDC_FOR_SM100=1 and CUTLASS_ENABLE_GDC_FOR_SM90=1 compile flags. This is a crucial fix for enabling proper synchronization with Programmatic Dependent Launch (PDL). The changes are consistently applied across all the specified GEMM JIT modules. My feedback focuses on improving code maintainability by refactoring the duplicated compile flags into a shared constant, which will make future updates easier.

Comment on lines 91 to 96
+ [
"-DENABLE_BF16",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],

medium

To improve maintainability and reduce code duplication, consider defining these new GDC (Grid Dependency Control) compile flags as a constant list at the module level. This constant can then be reused across all the JIT spec generation functions that require these flags, as this pattern is repeated in 5 other places in this file.

For example, you could define the following at the top of flashinfer/jit/gemm/core.py:

GDC_COMPILE_FLAGS = [
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]

And then use it as suggested below.

        + [
            "-DENABLE_BF16",
            "-DENABLE_FP4",
        ] + GDC_COMPILE_FLAGS

Comment on lines 160 to 165
+ [
"-DENABLE_BF16",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested in the previous comment.

        + [
            "-DENABLE_BF16",
            "-DENABLE_FP4",
        ] + GDC_COMPILE_FLAGS

Comment on lines 210 to 215
+ [
"-DENABLE_BF16",
"-DENABLE_FP4",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        + [
            "-DENABLE_BF16",
            "-DENABLE_FP4",
        ] + GDC_COMPILE_FLAGS

Comment on lines 263 to 267
+ [
"-DENABLE_BF16",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        + [
            "-DENABLE_BF16",
        ] + GDC_COMPILE_FLAGS

Comment on lines 358 to 362
+ [
"-DENABLE_BF16",
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        + [
            "-DENABLE_BF16",
        ] + GDC_COMPILE_FLAGS

Comment on lines +529 to +533
extra_cuda_cflags=nvcc_flags
+ [
"-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
"-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
],

medium

To avoid code duplication, please use the GDC_COMPILE_FLAGS constant as suggested.

        extra_cuda_cflags=nvcc_flags + GDC_COMPILE_FLAGS

@coderabbitai Bot left a comment

🧹 Nitpick comments (1)
flashinfer/jit/gemm/core.py (1)

94-95: Centralize the shared GDC flag pair.

Repeating the same two literals across six generators makes this easy to miss again. A module-level constant keeps future CUTLASS JIT entries aligned.

♻️ Refactor sketch
+CUTLASS_GDC_NVCC_FLAGS = [
+    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
+    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
+]
+
 def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec:
     ...
     return gen_jit_spec(
         "fp4_gemm_cutlass",
         source_paths,
         extra_cuda_cflags=nvcc_flags
         + [
             "-DENABLE_BF16",
             "-DENABLE_FP4",
-            "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
-            "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
-        ],
+        ]
+        + CUTLASS_GDC_NVCC_FLAGS,
         extra_cflags=[
             "-DFAST_BUILD",
         ],
     )

Apply the same replacement in the other five generators.

Also applies to: 163-164, 213-214, 265-266, 360-361, 529-533


📥 Commits

Reviewing files that changed from the base of the PR and between 44abf50 and 4a4c903.

📒 Files selected for processing (1)
  • flashinfer/jit/gemm/core.py

@johnnynunez (Contributor) commented

Reported! Thanks

@yzh119 yzh119 merged commit 4c4013b into flashinfer-ai:main Mar 9, 2026
20 checks passed
@aleozlx (Collaborator) commented Mar 10, 2026

this PR seems to affect the AOT package for the 0.6.6 release
https://github.com/flashinfer-ai/flashinfer/actions/runs/22870567870/job/66353637447?pr=2730

@aleozlx (Collaborator) commented Mar 10, 2026

btw have we filed an issue to cutlass?

@johnnynunez (Contributor) commented Mar 10, 2026

btw have we filed an issue to cutlass?

@depaulmillz's PR is the correct one; he is from the cutlass team. I pinged him because the vLLM team pinged us and it was related to cutlass

@eugr commented Mar 10, 2026

@voipmonitor - Looks like this PR causes flashinfer-jit-cache compilation to fail if the build doesn't include SM90 arch, e.g. if we build with FLASHINFER_CUDA_ARCH_LIST=12.1a:

Lots of errors like this:

/workspace/flashinfer/build/aot/generated/gen_gemm_sm120/gemm_groupwise_e5m2_f16_majorfalse_sm120.cu -o /workspace/flashinfer/build/aot/cached_ops/gemm_sm120/gemm_groupwise_e5m2_f16_majorfalse_sm120.cuda.o
2026-03-10T15:07:27.044869Z 01E 2026-03-10 08:07:27,044 - INFO - #15 975.0 /workspace/flashinfer/3rdparty/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp(794): error: class "cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100<cute::tuple<cute::_1, cute::_1, cute::_1>, 2U>" has no member "is_last_tile"
2026-03-10T15:07:27.044870Z 01E 2026-03-10 08:07:27,044 - INFO - #15 975.0           if (scheduler.is_last_tile(work_tile_info)) {
2026-03-10T15:07:27.044871Z 01E 2026-03-10 08:07:27,044 - INFO - #15 975.0                         ^
2026-03-10T15:07:27.044875Z 01E 2026-03-10 08:07:27,044 - INFO - #15 975.0           detected during:
2026-03-10T15:07:27.044876Z 01E 2026-03-10 08:07:27,044 - INFO - #15 975.0             instantiation of "void cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileSchedulerTag_, std::enable_if_t<std::is_base_of_v<cutlass::gemm::KernelTmaWarpSpecializedCooperative, CollectiveMainloop_::DispatchPolicy::Schedule>, void>>::operator()(const cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileSchedulerTag_, std::enable_if_t<std::is_base_of_v<cutlass::gemm::KernelTmaWarpSpecializedCooperative, CollectiveMainloop_::DispatchPolicy::Schedule>, void>>::Params &, char *) [with ProblemShape_=cute::tuple<int, int, int, int>, CollectiveMainloop_=cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm120TmaWarpSpecializedBlockwiseScaling<2, 2, cute::tuple<cute::_1, cute::_1, cute::_1>, cutlass::gemm::KernelTmaWarpSpecializedCooperativeBlockwiseScalingSm120<2>>, cute::tuple<cute::_128, cute::_128, cute::_128>, cutlass::float_e5m2_t, cute::tuple<cute::tuple<int64_t, cute::C<1>, int64_t>, cute::Layout<cute::tuple<cute::tuple<cute::C<1>, int32_t>, cute::tuple<cute::C<128>, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::C<0>, cute::C<1>>, cute::tuple<cute::_0, int32_t>, int32_t>>>, cutlass::float_e5m2_t, cute::tuple<cute::tuple<int64_t, cute::C<1>, int64_t>, cute::Layout<cute::tuple<cute::tuple<cute::C<128>, int32_t>, cute::tuple<cute::C<128>, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::C<0>, cute::C<1>>, cute::tuple<cute::_0, int32_t>, int32_t>>>, cute::TiledMMA<cute::MMA_Atom<cute::SM120_16x8x32_TN<cutlass::float_e5m2_t, cutlass::float_e5m2_t, float>>, cute::Layout<cute::tuple<cute::_4, cute::_2, cute::_1>, cute::tuple<cute::_1, cute::_4, cute::C<0>>>, cute::tuple<cute::C<128>, cute::C<32>, cute::_32>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<8>, cute::Layout<cute::tuple<cute::_8, cute::_128>, cute::tuple<cute::_128, 
cute::_1>>>, cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, uint8_t>, cute::identity, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<8>, cute::Layout<cute::tuple<cute::_8, cute::_128>, cute::tuple<cute::_128, cute::_1>>>, cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, uint8_t>, cute::identity>, CollectiveEpilogue_=cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<3, 2, 4, true, false>, cute::tuple<cute::_128, cute::_128, cute::_128>, cute::tuple<cute::C<64>, cute::C<32>>, cutlass::half_t, cute::tuple<int64_t, cute::C<1>, int64_t>, cutlass::half_t, cute::tuple<int64_t, cute::C<1>, int64_t>, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm120TmaWarpSpecialized<3, 2, 4, true, false>, cutlass::epilogue::fusion::LinearCombination<cutlass::half_t, float, cutlass::half_t, float, cutlass::FloatRoundStyle::round_to_nearest>, cute::tuple<cute::_128, cute::_128, cute::_128>, cute::tuple<cute::C<64>, cute::C<32>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<2, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::C<8>, cute::C<32>>, cute::tuple<cute::_32, cute::_1>>>, cute::SM75_U32x2_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<2, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::C<8>, cute::C<32>>, cute::tuple<cute::_32, cute::_1>>>, cute::SM90_U32x2_STSM_N, cute::Copy_Atom<cute::SM90_U32x2_STSM_N, cutlass::half_t>, void>, TileSchedulerTag_=void]" at line 123 of /workspace/flashinfer/3rdparty/cutlass/include/cutlass/device_kernel.h

@johnnynunez, @yzh119 - FYI

bkryu pushed a commit that referenced this pull request Mar 12, 2026
…g GDC flags cause PDL synchronization barriers to compile as no-ops" (#2737)

Proposing to revert #2716 in order to unblock
0.6.6 release

#2716 seems to have
broken AOT packages


https://github.com/flashinfer-ai/flashinfer/actions/runs/22870567870/job/66353637447?pr=2730


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Removed legacy GPU compilation flags related to GDC enablement for
certain GPU tiers during JIT GEMM generation, reducing extra compile
flags and build noise; GDC-related flags for the latest GPU tier remain
enabled where still applicable.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: yzh119 <zihaoy@nvidia.com>
voipmonitor added a commit to voipmonitor/flashinfer that referenced this pull request Mar 13, 2026
Re-applies the fix from flashinfer-ai#2716 (reverted in flashinfer-ai#2737) using only
-DCUTLASS_ENABLE_GDC_FOR_SM100=1, without -DCUTLASS_ENABLE_GDC_FOR_SM90=1.

The SM90 flag was the cause of the AOT build failure: it triggers a
direct #ifdef in sm90_gemm_tma_warpspecialized_cooperative.hpp (line 794)
that calls scheduler.is_last_tile() — but SM100+/SM120 schedulers
(PersistentTileSchedulerSm100StreamK) don't have that method.

The SM100 flag alone is sufficient because CUTLASS 4.2.1's
grid_dependency_control.h defines CUTLASS_GDC_ENABLED for the entire
SM100 family (SM100, SM101, SM103, SM120, SM121) when
CUTLASS_ENABLE_GDC_FOR_SM100 is set.

All affected GEMM kernels use enablePDL=true, so the device-side
GDC barriers (griddepcontrol.wait / griddepcontrol.launch_dependents)
must be compiled in — otherwise PDL enables host-side kernel overlap
but device-side synchronization is compiled out as no-ops, causing
a race condition (NaN/garbage in output tiles under concurrency).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai coderabbitai Bot mentioned this pull request Mar 16, 2026
5 tasks
aleozlx pushed a commit that referenced this pull request Mar 20, 2026
## Summary

Re-applies #2716 (reverted in #2737) with the fix for the AOT build
failure.

**Only `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`** is added. The
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` flag that broke AOT builds is
intentionally omitted.

## Why the original PR broke AOT

`sm90_gemm_tma_warpspecialized_cooperative.hpp:794` has a direct `#ifdef
CUTLASS_ENABLE_GDC_FOR_SM90` guard (not `CUTLASS_GDC_ENABLED`) that
calls `scheduler.is_last_tile()`. When compiling SM120 kernels with that
flag, the SM120 scheduler (`PersistentTileSchedulerSm100StreamK`)
doesn't have `is_last_tile()` → compilation error.

## Why SM100 flag alone is sufficient

CUTLASS 4.2.1 `grid_dependency_control.h` defines `CUTLASS_GDC_ENABLED`
for the entire SM100 family (SM100/101/103/120/121) when
`CUTLASS_ENABLE_GDC_FOR_SM100` is set. This enables
`griddepcontrol.wait` and `griddepcontrol.launch_dependents` device-side
barriers for all affected architectures.
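The family-wide gate described above can be modeled in a few lines; the set and function names below are illustrative simplifications of the CUTLASS preprocessor condition, not real API:

```python
# Simplified model (assumption) of the CUTLASS gate described above:
# the single CUTLASS_ENABLE_GDC_FOR_SM100 define enables GDC for the
# whole SM100 family, while SM90 is gated by its own define.
SM100_FAMILY = {100, 101, 103, 120, 121}

def cutlass_gdc_enabled(sm, sm100_flag_defined):
    return sm100_flag_defined and sm in SM100_FAMILY

assert cutlass_gdc_enabled(120, True)       # RTX 50-series: covered
assert not cutlass_gdc_enabled(120, False)  # flag missing: barriers no-op
assert not cutlass_gdc_enabled(90, True)    # SM90 needs its own define
```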

## Why this is needed

All affected GEMM kernels hardcode `enablePDL=true`, which enables
host-side kernel overlap. Without the GDC compile flag, the device-side
synchronization barriers compile as no-ops → race condition →
NaN/garbage output tiles under concurrency.

## Affected modules

- `fp4_gemm_cutlass` (SM100)
- `fp4_gemm_cutlass_sm103` (SM103)
- `fp4_gemm_cutlass_sm120` (SM120)
- `fp8_gemm_cutlass` (SM100)
- `mxfp8_gemm_cutlass` (SM100)
- `gemm_sm120` (SM120 FP8 groupwise)

(`tgv_gemm` already had the SM100 flag.)

## Test plan

- [ ] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` (the exact
config that broke before)
- [ ] AOT build with full arch list `"7.5 8.0 8.9 9.0a 10.0a 12.0a"`
- [ ] FP4 GEMM correctness under concurrent streams on SM120

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Updated CUDA compilation configurations for matrix multiplication
kernels across multiple data format variants (FP4, FP8, MXFP8, BF16)
supporting additional GPU architectures.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026
…g GDC flags cause PDL synchronization barriers to compile as no-ops" (#2737)

Proposing to revert flashinfer-ai/flashinfer#2716 in order to unblock
0.6.6 release

flashinfer-ai/flashinfer#2716 seems to have
broken AOT packages


https://github.com/flashinfer-ai/flashinfer/actions/runs/22870567870/job/66353637447?pr=2730


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Removed legacy GPU compilation flags related to GDC enablement for
certain GPU tiers during JIT GEMM generation, reducing extra compile
flags and build noise; GDC-related flags for the latest GPU tier remain
enabled where still applicable.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: yzh119 <zihaoy@nvidia.com>
aleozlx pushed a commit that referenced this pull request Apr 1, 2026
…es on SM12x (#2913)

### Summary

- Add missing `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` compile flag to all
CUTLASS fused MoE JIT modules (SM100/SM103/SM120) and
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to SM90 modules
- Sync nv_internal `grid_dependency_control.h` with upstream CUTLASS to
support SM100/SM103/SM110/SM120/SM121 GDC
- Add `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to FP8 blockscale GEMM SM90
module

### Problem

Random `cudaErrorIllegalInstruction` crashes on DGX Spark (SM121) and
RTX 50-series (SM120) when running NVFP4 MoE models (e.g., Nemotron,
Qwen3.5-122B) under load. The crashes are intermittent and worsen with
longer context lengths and higher concurrency.

**Root cause:** PR #2780 fixed the missing GDC compile flags for GEMM
modules (`flashinfer/jit/gemm/core.py`), but the **CUTLASS fused MoE
modules** in `flashinfer/jit/fused_moe.py` and the **FP8 blockscale GEMM
module** were not fixed. This is the exact same class of bug as #2708.

Without `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`, CUTLASS's
`grid_dependency_control.h` compiles `wait_on_dependent_grids()` and
`launch_dependent_grids()` as **empty no-ops**:

```cpp
CUTLASS_DEVICE void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))   // ← not defined without the flag
  asm volatile("griddepcontrol.wait;");
#endif
}
```

Meanwhile, the host-side code still sets
`programmaticStreamSerializationAllowed = true` (PDL enabled) via
`device_support_pdl()` which returns `True` for all `major >= 9`,
including SM12x. This means:

1. **Host enables PDL** → CUDA runtime overlaps consecutive kernels
2. **Device GDC barriers are no-ops** → No synchronization between
overlapping kernels
3. **Race condition** → Dependent kernel reads stale global memory →
corruption → `cudaErrorIllegalInstruction`

The crash is random because it depends on exact kernel scheduling
timing, which varies per request.
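The host/device mismatch in the three steps above can be sketched as follows; function names are hypothetical simplifications (the real capability check lives in FlashInfer's utilities):

```python
# Simplified model of the mismatch: the host decides to enable PDL
# from the compute capability alone, while the device-side barrier
# only exists if the GDC define reached nvcc. Names are illustrative.

def device_support_pdl(major):
    # Host side: PDL overlap is requested for every cc >= 9.x,
    # including SM12x parts.
    return major >= 9

def device_barrier_present(gdc_flag_set):
    # Device side: griddepcontrol.wait exists only when
    # CUTLASS_ENABLE_GDC_FOR_SM100 was defined at compile time.
    return gdc_flag_set

def race_possible(major, gdc_flag_set):
    return device_support_pdl(major) and not device_barrier_present(gdc_flag_set)

# SM12x built without the flag: kernels overlap with no barrier.
assert race_possible(12, gdc_flag_set=False)
# With the flag compiled in, the overlap is properly fenced.
assert not race_possible(12, gdc_flag_set=True)
```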

### Fix

**`flashinfer/jit/fused_moe.py`** — Added GDC flags to all CUTLASS fused
MoE modules:

| Module | Flag | Architectures Covered |
|---|---|---|
| `gen_cutlass_fused_moe_sm120_module()` |
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM120, SM121 |
| `gen_cutlass_fused_moe_sm103_module()` |
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM103, SM120, SM121 |
| `gen_cutlass_fused_moe_sm100_module()` |
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100, SM110, SM120, SM121 |
| `gen_cutlass_fused_moe_sm90_module()` |
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` | SM90 |
| `gen_trtllm_gen_fused_moe_sm100_module()` |
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100+, SM120, SM121 |

**`flashinfer/jit/gemm/fp8_blockscale.py`** — Added
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to
`gen_fp8_blockscale_gemm_sm90_module()`.

**`csrc/nv_internal/.../grid_dependency_control.h`** — Synced with
upstream CUTLASS
(`3rdparty/cutlass/include/cutlass/arch/grid_dependency_control.h`) to
add SM100+ GDC support. Previously only handled SM90, so any nv_internal
TensorRT-LLM code compiled for SM12x would have GDC barriers silently
compiled as no-ops.

### Why `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` covers SM12x

CUTLASS uses a single flag for the entire Blackwell family. From
`grid_dependency_control.h`:

```cpp
#if(CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
    ((__CUDA_ARCH__ == 1000 && ...) ||   // SM100
     (__CUDA_ARCH__ == 1030 && ...) ||   // SM103
     (__CUDA_ARCH__ == 1100 && ...) ||   // SM110
     (__CUDA_ARCH__ == 1200 && ...) ||   // SM120 (RTX 50-series)
     (__CUDA_ARCH__ == 1210 && ...)))    // SM121 (DGX Spark)
#define CUTLASS_GDC_ENABLED
```

### Why SM90 GDC flag was NOT added to SM100+ modules

PR #2716 attempted to add both `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` and
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` to all modules. It broke AOT builds
because `sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp`
checks `CUTLASS_ENABLE_GDC_FOR_SM90` and calls
`scheduler.is_last_tile()` — a method not present on the SM120
scheduler. PR #2780 corrected this by using only the SM100 flag for
SM100+ modules. This PR follows the same approach.

### Related

- #2708 — Original issue: missing GDC flags cause PDL race condition
- #2716 — First fix attempt (reverted — broke AOT)
- #2780 — Corrected fix for GEMM modules only
-
[vllm-project/vllm#38423](vllm-project/vllm#38423)
— NVFP4 bugfix on DGX Spark
- [NVIDIA/cutlass#3121](NVIDIA/cutlass#3121) —
K=64 block-scaled GEMM tiles (separate issue)

### Test plan

- [x] Clear JIT cache: `rm -rf ~/.cache/flashinfer/`
- [x] Run NVFP4 MoE model on SM121 (DGX Spark) with 128K context under
load — verify no `cudaErrorIllegalInstruction`
- [x] Run NVFP4 MoE model on SM120 (RTX 50-series) with concurrent
requests — verify no NaN/garbage output
- [x] Verify `CUDA_LAUNCH_BLOCKING=1` workaround is no longer needed
- [x] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` completes
without errors
- [x] SM90 (Hopper) fused MoE tests pass: `pytest tests/moe/`
- [x] SM100 GEMM tests still pass (no regression from existing GDC
flags)


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Expanded GPU kernel compilation support: enabled additional
optimizations for NVIDIA SM100 and SM90 GPUs, activating
dependency-control optimizations where available.
* Updated JIT/GEMM build configs to include these architecture-specific
compile options, improving performance and compatibility on supported
hardware.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->