fix(jit): GEMM kernels produce NaN under concurrency — missing GDC flags cause PDL synchronization barriers to compile as no-ops #2716
Conversation
…MM kernels

All CUTLASS GEMM templates use `enablePDL=true` (Programmatic Dependent Launch), but the JIT compilation was missing the `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` flags. Without these flags, `wait_on_dependent_grids()` and `launch_dependent_grids()` in CUTLASS `grid_dependency_control.h` compile as empty no-ops, eliminating the synchronization barriers needed for safe PDL execution.

This causes a race condition where dependent kernels read stale data from global memory before the previous kernel completes its writeback, resulting in intermittent output tile corruption (128-aligned NaN blocks in output tensors) under high concurrency on SM120 (Blackwell RTX PRO 6000 / RTX 5090) and potentially SM100.

The fix adds both GDC flags to all affected GEMM JIT modules:

- SM100 FP4 (`fp4_gemm_cutlass`)
- SM103 FP4 (`fp4_gemm_cutlass_sm103`)
- SM120 FP4 (`fp4_gemm_cutlass_sm120`)
- SM100 FP8 (`fp8_gemm_cutlass`)
- SM100 MXFP8 (`mxfp8_gemm_cutlass`)
- SM120 FP8 groupwise (`gemm_sm120`)

Note: the TGV GEMM module already had `-DCUTLASS_ENABLE_GDC_FOR_SM100` but was missing `-DCUTLASS_ENABLE_GDC_FOR_SM90` (not added here as it uses a different code path).

Fixes flashinfer-ai#2708

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request correctly addresses a race condition in CUTLASS GEMM kernels by adding the missing CUTLASS_ENABLE_GDC_FOR_SM100=1 and CUTLASS_ENABLE_GDC_FOR_SM90=1 compile flags. This is a crucial fix for enabling proper synchronization with Programmatic Dependent Launch (PDL). The changes are consistently applied across all the specified GEMM JIT modules. My feedback focuses on improving code maintainability by refactoring the duplicated compile flags into a shared constant, which will make future updates easier.
```
+ [
      "-DENABLE_BF16",
      "-DENABLE_FP4",
      "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
      "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
  ],
```
To improve maintainability and reduce code duplication, consider defining these new GDC (Grid Dependency Control) compile flags as a constant list at the module level. This constant can then be reused across all the JIT spec generation functions that require these flags, as this pattern is repeated in 5 other places in this file.
For example, you could define the following at the top of `flashinfer/jit/gemm/core.py`:

```python
GDC_COMPILE_FLAGS = [
    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
]
```

And then use it as suggested below.
```python
+ [
      "-DENABLE_BF16",
      "-DENABLE_FP4",
  ] + GDC_COMPILE_FLAGS
```

The other hunks in the diff apply the same pair of flags at the remaining call sites:

```
+ [
      "-DENABLE_BF16",
      "-DENABLE_FP4",
      "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
      "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
  ],
```

```
+ [
      "-DENABLE_BF16",
      "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
      "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
  ],
```

```
  extra_cuda_cflags=nvcc_flags
  + [
      "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
      "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
  ],
```
🧹 Nitpick comments (1)

flashinfer/jit/gemm/core.py (1)

94-95: Centralize the shared GDC flag pair. Repeating the same two literals across six generators makes this easy to miss again. A module-level constant keeps future CUTLASS JIT entries aligned.

♻️ Refactor sketch

```diff
+CUTLASS_GDC_NVCC_FLAGS = [
+    "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
+    "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
+]
+
 def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec:
     ...
     return gen_jit_spec(
         "fp4_gemm_cutlass",
         source_paths,
         extra_cuda_cflags=nvcc_flags
         + [
             "-DENABLE_BF16",
             "-DENABLE_FP4",
-            "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
-            "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",
-        ],
+        ]
+        + CUTLASS_GDC_NVCC_FLAGS,
         extra_cflags=[
             "-DFAST_BUILD",
         ],
     )
```

Apply the same replacement in the other five generators.

Also applies to: 163-164, 213-214, 265-266, 360-361, 529-533
Reported! Thanks
This PR seems to affect the AOT package for the 0.6.6 release
btw have we filed an issue with cutlass?
@depaulmillz's PR is the correct one; he is from the cutlass team... I pinged him because the vLLM team pinged us and it was related to cutlass
@voipmonitor - Looks like this PR causes flashinfer-jit-cache compilation to fail if the build doesn't include SM90 arch, e.g. if we build with FLASHINFER_CUDA_ARCH_LIST=12.1a: lots of errors like this. @johnnynunez, @yzh119 - FYI |
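One hypothetical way to avoid this class of breakage is to derive the GDC flags from the requested arch list, so an SM90-only flag is never passed to builds that do not target SM90. The helper below is a sketch under that assumption, not FlashInfer's actual API:

```python
# Hypothetical helper (not FlashInfer's real API): pick GDC flags from a
# FLASHINFER_CUDA_ARCH_LIST-style string, so e.g. a "12.1a"-only build
# never gets -DCUTLASS_ENABLE_GDC_FOR_SM90=1.
def gdc_flags_for_arch_list(arch_list: str) -> list:
    archs = arch_list.replace(",", " ").split()
    flags = []
    if any(a.startswith("9.0") for a in archs):
        flags.append("-DCUTLASS_ENABLE_GDC_FOR_SM90=1")
    # Any compute capability >= 10.x belongs to the SM100 family gate.
    if any(a.split(".")[0].isdigit() and int(a.split(".")[0]) >= 10 for a in archs):
        flags.append("-DCUTLASS_ENABLE_GDC_FOR_SM100=1")
    return flags

print(gdc_flags_for_arch_list("12.1a"))
# ['-DCUTLASS_ENABLE_GDC_FOR_SM100=1']
print(gdc_flags_for_arch_list("7.5 8.0 8.9 9.0a 10.0a 12.0a"))
# ['-DCUTLASS_ENABLE_GDC_FOR_SM90=1', '-DCUTLASS_ENABLE_GDC_FOR_SM100=1']
```

A per-arch gate like this would have kept the SM90 flag out of the failing `FLASHINFER_CUDA_ARCH_LIST=12.1a` build while still fixing the PDL race on Hopper builds.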
…g GDC flags cause PDL synchronization barriers to compile as no-ops" (#2737)

Proposing to revert #2716 in order to unblock the 0.6.6 release.

#2716 seems to have broken AOT packages: https://github.com/flashinfer-ai/flashinfer/actions/runs/22870567870/job/66353637447?pr=2730

## Summary by CodeRabbit

* **Bug Fixes**
  * Removed legacy GPU compilation flags related to GDC enablement for certain GPU tiers during JIT GEMM generation, reducing extra compile flags and build noise; GDC-related flags for the latest GPU tier remain enabled where still applicable.

Co-authored-by: yzh119 <zihaoy@nvidia.com>
Re-applies the fix from flashinfer-ai#2716 (reverted in flashinfer-ai#2737) using only `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`, without `-DCUTLASS_ENABLE_GDC_FOR_SM90=1`.

The SM90 flag was the cause of the AOT build failure: it triggers a direct `#ifdef` in `sm90_gemm_tma_warpspecialized_cooperative.hpp` (line 794) that calls `scheduler.is_last_tile()` — but SM100+/SM120 schedulers (`PersistentTileSchedulerSm100StreamK`) don't have that method.

The SM100 flag alone is sufficient because CUTLASS 4.2.1's `grid_dependency_control.h` defines `CUTLASS_GDC_ENABLED` for the entire SM100 family (SM100, SM101, SM103, SM120, SM121) when `CUTLASS_ENABLE_GDC_FOR_SM100` is set.

All affected GEMM kernels use `enablePDL=true`, so the device-side GDC barriers (`griddepcontrol.wait` / `griddepcontrol.launch_dependents`) must be compiled in — otherwise PDL enables host-side kernel overlap but device-side synchronization is compiled out as no-ops, causing a race condition (NaN/garbage in output tiles under concurrency).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
## Summary

Re-applies #2716 (reverted in #2737) with the fix for the AOT build failure. **Only `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`** is added. The `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` flag that broke AOT builds is intentionally omitted.

## Why the original PR broke AOT

`sm90_gemm_tma_warpspecialized_cooperative.hpp:794` has a direct `#ifdef CUTLASS_ENABLE_GDC_FOR_SM90` guard (not `CUTLASS_GDC_ENABLED`) that calls `scheduler.is_last_tile()`. When compiling SM120 kernels with that flag, the SM120 scheduler (`PersistentTileSchedulerSm100StreamK`) doesn't have `is_last_tile()` → compilation error.

## Why the SM100 flag alone is sufficient

CUTLASS 4.2.1 `grid_dependency_control.h` defines `CUTLASS_GDC_ENABLED` for the entire SM100 family (SM100/101/103/120/121) when `CUTLASS_ENABLE_GDC_FOR_SM100` is set. This enables the `griddepcontrol.wait` and `griddepcontrol.launch_dependents` device-side barriers for all affected architectures.

## Why this is needed

All affected GEMM kernels hardcode `enablePDL=true`, which enables host-side kernel overlap. Without the GDC compile flag, the device-side synchronization barriers compile as no-ops → race condition → NaN/garbage output tiles under concurrency.

## Affected modules

- `fp4_gemm_cutlass` (SM100)
- `fp4_gemm_cutlass_sm103` (SM103)
- `fp4_gemm_cutlass_sm120` (SM120)
- `fp8_gemm_cutlass` (SM100)
- `mxfp8_gemm_cutlass` (SM100)
- `gemm_sm120` (SM120 FP8 groupwise)

(`tgv_gemm` already had the SM100 flag.)

## Test plan

- [ ] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` (the exact config that broke before)
- [ ] AOT build with the full arch list `"7.5 8.0 8.9 9.0a 10.0a 12.0a"`
- [ ] FP4 GEMM correctness under concurrent streams on SM120

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

* **Chores**
  * Updated CUDA compilation configurations for matrix multiplication kernels across multiple data format variants (FP4, FP8, MXFP8, BF16) supporting additional GPU architectures.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
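The "SM100 flag alone is sufficient" argument can be modeled as a small Python predicate mirroring the preprocessor gate in `grid_dependency_control.h`. This is a sketch of the gating logic, not the header itself, and the exact family membership depends on the CUTLASS version:

```python
# Sketch of the CUTLASS preprocessor gating: CUTLASS_GDC_ENABLED is
# defined only when the matching per-family enable flag is passed at
# compile time. Family set follows the description above
# (SM100/101/103/120/121); exact membership varies by CUTLASS version.
SM100_FAMILY = {1000, 1010, 1030, 1200, 1210}

def gdc_defined(cuda_arch, compile_flags):
    """Would CUTLASS_GDC_ENABLED be defined for this __CUDA_ARCH__?"""
    if cuda_arch == 900:
        return "CUTLASS_ENABLE_GDC_FOR_SM90" in compile_flags
    if cuda_arch in SM100_FAMILY:
        return "CUTLASS_ENABLE_GDC_FOR_SM100" in compile_flags
    return False

# SM120 (RTX 50-series) gets device barriers from the SM100 flag alone:
print(gdc_defined(1200, {"CUTLASS_ENABLE_GDC_FOR_SM100"}))  # True
# Without the flag, the barriers compile out to no-ops:
print(gdc_defined(1200, set()))                             # False
```

Note how the SM90 flag is irrelevant to the Blackwell family here, which is why dropping it fixes the AOT build without reintroducing the race.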
…es on SM12x (#2913)

### Summary

- Add the missing `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` compile flag to all CUTLASS fused MoE JIT modules (SM100/SM103/SM120) and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to SM90 modules
- Sync nv_internal `grid_dependency_control.h` with upstream CUTLASS to support SM100/SM103/SM110/SM120/SM121 GDC
- Add `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to the FP8 blockscale GEMM SM90 module

### Problem

Random `cudaErrorIllegalInstruction` crashes on DGX Spark (SM121) and RTX 50-series (SM120) when running NVFP4 MoE models (e.g., Nemotron, Qwen3.5-122B) under load. The crashes are intermittent and worsen with longer context lengths and higher concurrency.

**Root cause:** PR #2780 fixed the missing GDC compile flags for GEMM modules (`flashinfer/jit/gemm/core.py`), but the **CUTLASS fused MoE modules** in `flashinfer/jit/fused_moe.py` and the **FP8 blockscale GEMM module** were not fixed. This is the exact same class of bug as #2708.

Without `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`, CUTLASS's `grid_dependency_control.h` compiles `wait_on_dependent_grids()` and `launch_dependent_grids()` as **empty no-ops**:

```cpp
CUTLASS_DEVICE
void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))  // ← not defined without the flag
  asm volatile("griddepcontrol.wait;");
#endif
}
```

Meanwhile, the host-side code still sets `programmaticStreamSerializationAllowed = true` (PDL enabled) via `device_support_pdl()`, which returns `True` for all `major >= 9`, including SM12x. This means:

1. **Host enables PDL** → CUDA runtime overlaps consecutive kernels
2. **Device GDC barriers are no-ops** → No synchronization between overlapping kernels
3. **Race condition** → Dependent kernel reads stale global memory → corruption → `cudaErrorIllegalInstruction`

The crash is random because it depends on exact kernel scheduling timing, which varies per request.

### Fix

**`flashinfer/jit/fused_moe.py`** — Added GDC flags to all CUTLASS fused MoE modules:

| Module | Flag | Architectures Covered |
|---|---|---|
| `gen_cutlass_fused_moe_sm120_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM120, SM121 |
| `gen_cutlass_fused_moe_sm103_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM103, SM120, SM121 |
| `gen_cutlass_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100, SM110, SM120, SM121 |
| `gen_cutlass_fused_moe_sm90_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` | SM90 |
| `gen_trtllm_gen_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100+, SM120, SM121 |

**`flashinfer/jit/gemm/fp8_blockscale.py`** — Added `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to `gen_fp8_blockscale_gemm_sm90_module()`.

**`csrc/nv_internal/.../grid_dependency_control.h`** — Synced with upstream CUTLASS (`3rdparty/cutlass/include/cutlass/arch/grid_dependency_control.h`) to add SM100+ GDC support. Previously it only handled SM90, so any nv_internal TensorRT-LLM code compiled for SM12x would have GDC barriers silently compiled as no-ops.

### Why `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` covers SM12x

CUTLASS uses a single flag for the entire Blackwell family. From `grid_dependency_control.h`:

```cpp
#if (CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
     ((__CUDA_ARCH__ == 1000 && ...) ||  // SM100
      (__CUDA_ARCH__ == 1030 && ...) ||  // SM103
      (__CUDA_ARCH__ == 1100 && ...) ||  // SM110
      (__CUDA_ARCH__ == 1200 && ...) ||  // SM120 (RTX 50-series)
      (__CUDA_ARCH__ == 1210 && ...)))   // SM121 (DGX Spark)
#define CUTLASS_GDC_ENABLED
```

### Why the SM90 GDC flag was NOT added to SM100+ modules

PR #2716 attempted to add both `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` and `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` to all modules. It broke AOT builds because `sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp` checks `CUTLASS_ENABLE_GDC_FOR_SM90` and calls `scheduler.is_last_tile()` — a method not present on the SM120 scheduler. PR #2780 corrected this by using only the SM100 flag for SM100+ modules. This PR follows the same approach.

### Related

- #2708 — Original issue: missing GDC flags cause PDL race condition
- #2716 — First fix attempt (reverted — broke AOT)
- #2780 — Corrected fix for GEMM modules only
- vllm-project/vllm#38423 — NVFP4 bugfix on DGX Spark
- NVIDIA/cutlass#3121 — K=64 block-scaled GEMM tiles (separate issue)

### Test plan

- [x] Clear JIT cache: `rm -rf ~/.cache/flashinfer/`
- [x] Run NVFP4 MoE model on SM121 (DGX Spark) with 128K context under load — verify no `cudaErrorIllegalInstruction`
- [x] Run NVFP4 MoE model on SM120 (RTX 50-series) with concurrent requests — verify no NaN/garbage output
- [x] Verify the `CUDA_LAUNCH_BLOCKING=1` workaround is no longer needed
- [x] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` completes without errors
- [x] SM90 (Hopper) fused MoE tests pass: `pytest tests/moe/`
- [x] SM100 GEMM tests still pass (no regression from existing GDC flags)

## Summary by CodeRabbit

* **New Features**
  * Expanded GPU kernel compilation support: enabled additional optimizations for NVIDIA SM100 and SM90 GPUs, activating dependency-control optimizations where available.
  * Updated JIT/GEMM build configs to include these architecture-specific compile options, improving performance and compatibility on supported hardware.
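Since the same missing-flag bug has now recurred across three module families (GEMM, fused MoE, FP8 blockscale), a lightweight guard could scan each module's flag list. Everything below is an assumed sketch (illustrative data, not FlashInfer's real test suite or flag snapshots):

```python
# Assumed regression-check sketch (not FlashInfer's actual tests): every
# CUTLASS JIT module that hardcodes enablePDL=true must carry some GDC
# enable flag, otherwise device barriers compile out as no-ops.
PDL_MODULE_CFLAGS = {  # illustrative snapshot of extra_cuda_cflags
    "fp4_gemm_cutlass": ["-DENABLE_BF16", "-DENABLE_FP4",
                         "-DCUTLASS_ENABLE_GDC_FOR_SM100=1"],
    "gemm_sm120": ["-DCUTLASS_ENABLE_GDC_FOR_SM100=1"],
    "fp8_blockscale_gemm_sm90": ["-DCUTLASS_ENABLE_GDC_FOR_SM90=1"],
}

def modules_missing_gdc(module_cflags):
    """Return names of PDL modules whose flags lack any GDC enable define."""
    return sorted(
        name for name, flags in module_cflags.items()
        if not any(f.startswith("-DCUTLASS_ENABLE_GDC_FOR_SM") for f in flags)
    )

print(modules_missing_gdc(PDL_MODULE_CFLAGS))                      # []
print(modules_missing_gdc({"new_moe_module": ["-DENABLE_BF16"]}))  # ['new_moe_module']
```

A check like this would flag a newly added PDL module before it ships with no-op barriers.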
Summary
All CUTLASS GEMM templates use `enablePDL=true` (Programmatic Dependent Launch), but the JIT compilation is missing the `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` compile flags. Without these flags, `wait_on_dependent_grids()` and `launch_dependent_grids()` in CUTLASS `grid_dependency_control.h` compile as empty no-ops, eliminating the synchronization barriers needed for safe PDL execution.

Root Cause
In `cutlass/include/cutlass/arch/grid_dependency_control.h`:

The `CUTLASS_GDC_ENABLED` macro is only defined when `CUTLASS_ENABLE_GDC_FOR_SM100` is passed as a compile flag. Without it, PDL launches kernels with overlap enabled at the host level (`cudaLaunchAttributeProgrammaticStreamSerialization`), but the device-side synchronization barriers are compiled out — creating a race condition.

Symptoms
On SM120 (Blackwell RTX PRO 6000 / RTX 5090) with high concurrency (64+ simultaneous requests in SGLang with TP=8):

- intermittent output tile corruption: 128-aligned NaN blocks in output tensors
- `CUDA_LAUNCH_BLOCKING=1` eliminates the bug (confirms a race condition)

Fix
Add `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` and `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to all affected GEMM JIT modules:

- `fp4_gemm_cutlass` (SM100)
- `fp4_gemm_cutlass_sm103` (SM103)
- `fp4_gemm_cutlass_sm120` (SM120)
- `fp8_gemm_cutlass` (SM100)
- `mxfp8_gemm_cutlass` (SM100)
- `gemm_sm120` (SM120 FP8 groupwise)

The `tgv_gemm` module already had `-DCUTLASS_ENABLE_GDC_FOR_SM100`.

Note: `-DCUTLASS_ENABLE_GDC_FOR_SM90` is needed because the SM120 CUTLASS kernel (`sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp`) guards `launch_dependent_grids()` with `#ifdef CUTLASS_ENABLE_GDC_FOR_SM90` instead of `SM100` (an upstream CUTLASS bug).

Verification
- `CUDA_LAUNCH_BLOCKING=1`

Environment

Related