
[NVIDIA] Bugfix NVFP4 DGX Spark and RTX50 #38423

Merged
vllm-bot merged 18 commits into vllm-project:main from johnnynunez:main
Mar 30, 2026

Conversation

@johnnynunez
Contributor

@johnnynunez johnnynunez commented Mar 28, 2026

Summary

Fix cudaErrorIllegalInstruction when running NVFP4 models (e.g. nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4) on SM12x GPUs (RTX 50 series SM120, DGX Spark SM121).

Root causes

  1. CUTLASS v4.2.2 lacks SM12x NVFP4 tile constraints — The bundled CUTLASS was missing SM120f family-level compilation support for NVFP4/MX Grouped GEMM and SM121-specific tile configurations (DGX Spark). This caused IllegalInstruction during decode when small-M tile variants were selected. Related upstream: NVIDIA/cutlass#3038.

  2. FlashInfer 0.6.6 bundles CUTLASS 4.2.1 — The FlashInfer CUTLASS MoE backend failed on SM12x with `Failed to initialize cutlass TMA WS grouped gemm` due to the same missing tile constraints. Fixed upstream in flashinfer-ai/flashinfer#2798.

  3. cutlass_scaled_mm_supports_fp4() reported false availability — It checked only the CUDA runtime version (>= 12080), not whether the SM-specific kernel was actually compiled. On a build with only ENABLE_NVFP4_SM100, it incorrectly reported CUTLASS as available for SM12x, then failed at dispatch.

  4. Quantization kernels had no SM runtime guard — The scaled_fp4_quant, silu_and_mul_nvfp4_quant, and expert quant entry points dispatched to _sm1xxa kernels if any SM1xx was compiled, with no runtime check. If only SM100 SASS existed, CUDA would JIT-compile SM100 PTX for SM120 (different major arch), producing illegal instructions asynchronously — surfacing later at synchronize() as an opaque CUDA error.

  5. FlashInfer CUTLASS backend bypassed quant kernel checks — select_nvfp4_linear_backend() selected FlashInfer CUTLASS solely on has_device_capability(100), without verifying that the vLLM quantization kernels (used by all non-Marlin backends) were compiled for the current SM. See the guard sketch below.
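
To make causes 3 to 5 concrete, here is a minimal sketch of the guard pattern (illustrative only; the macro names come from this PR's description, and the function name is hypothetical — the actual vLLM code may differ):

```cpp
#include <cuda_runtime.h>

// Report NVFP4 kernel availability only when a kernel for the *current*
// SM range was actually compiled, instead of gating on the CUDA runtime
// version alone.
static bool nvfp4_compiled_for_sm(int device_id) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) return false;
  const int sm = prop.major * 10 + prop.minor;
#if defined(ENABLE_NVFP4_SM100)
  if (sm >= 100 && sm < 120) return true;  // SM100-family builds
#endif
#if defined(ENABLE_NVFP4_SM120)
  if (sm >= 120 && sm < 130) return true;  // RTX 50 (SM120), DGX Spark (SM121)
#endif
  // Never fall through: SM100 PTX must not be JIT'd onto a different
  // major architecture, which is what produced the illegal instructions.
  return false;
}
```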

Changes

| File | Change |
|:--|:--|
| CMakeLists.txt | Bump CUTLASS from v4.2.2 to v4.4.2 — enables SM120f (family) compilation for NVFP4/MX Grouped GEMM, covering RTX 50 (SM120) and DGX Spark (SM121) |
| docker/Dockerfile | Bump FlashInfer from 0.6.6 to 0.6.7 (includes CUTLASS 4.4.2, fixes TMA grouped GEMM on SM12x) |
| docker/Dockerfile.nightly_torch | Same FlashInfer bump (source build) |
| docker/versions.json | FLASHINFER_VERSION: 0.6.6 → 0.6.7 |
| nvfp4_scaled_mm_entry.cu | cutlass_scaled_mm_supports_fp4() now checks the compile-time ENABLE_NVFP4_SM100/ENABLE_NVFP4_SM120 guards per SM range instead of a blanket >= 100 check |
| nvfp4_quant_entry.cu | Added an nvfp4_quant_sm_supported() runtime guard to all four quant entry points (scaled_fp4_quant, scaled_fp4_experts_quant, silu_and_mul_nvfp4_quant, silu_and_mul_scaled_fp4_experts_quant) |
| nvfp4_utils.py | select_nvfp4_linear_backend() gates FlashInfer CUTLASS on cutlass_fp4_supported() and adds a validation assert for all FlashInfer backends |

What is NOT changed

Marlin remains a valid fallback on SM12x. Marlin FP4 uses weight-only dequantization to BF16 — it does not use native FP4 tensor core instructions and works correctly on all Blackwell architectures including DGX Spark. Benchmarks confirm Marlin is stable on SM121 (~558 tok/s, on par with vLLM CUTLASS at ~562 tok/s). The Marlin path (apply_fp4_marlin_linear) bypasses the vLLM quant kernels entirely, so the SM guards in nvfp4_quant_entry.cu do not affect it.
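
To illustrate what "weight-only dequantization" means here, a conceptual sketch of decoding packed FP4 (E2M1) weights before a plain BF16 GEMM (hypothetical helper, not Marlin's actual kernel code; the nibble order is an assumption):

```cpp
#include <cstdint>

// FP4 E2M1 magnitudes for codes 0..7; codes 8..15 are their negatives.
static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

// Decode one packed byte (two 4-bit weights, low nibble first) and apply
// the per-block scale. The subsequent GEMM runs on the dequantized
// values, so no native FP4 tensor-core instruction is ever issued;
// that is why this path is immune to the missing-SM12x-kernel problem.
inline void dequant_fp4_pair(uint8_t packed, float scale, float out[2]) {
  for (int i = 0; i < 2; ++i) {
    uint8_t code = (packed >> (4 * i)) & 0xF;
    float mag = kE2M1[code & 0x7];
    out[i] = ((code & 0x8) ? -mag : mag) * scale;
  }
}
```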

Behavior on SM12x after this PR

| Scenario | Before | After |
|:--|:--|:--|
| Build includes ENABLE_NVFP4_SM120 + CUTLASS v4.4.2 | IllegalInstruction | Native CUTLASS backend selected, works correctly |
| Build lacks ENABLE_NVFP4_SM120 | IllegalInstruction (SM100 PTX JIT to SM120) | Native CUTLASS correctly reports unavailable; Marlin selected as fallback — works correctly |
| FlashInfer CUTLASS MoE on SM12x | `Failed to initialize cutlass TMA WS grouped gemm` (CUTLASS 4.2.1 in FlashInfer 0.6.6) | Works correctly with FlashInfer 0.6.7 (CUTLASS 4.4.2) |

Follow-up: FlashInfer 0.6.8

flashinfer-ai/flashinfer#2738 (merged March 28, 2026) adds native NVFP4 and MXFP4 group GEMM support for SM120/SM121 (RTX 50 / DGX Spark) directly in FlashInfer. This will land in FlashInfer 0.6.8. Once released, FLASHINFER_VERSION should be bumped in docker/Dockerfile, docker/Dockerfile.nightly_torch, and docker/versions.json to unlock FlashInfer's own SM12x NVFP4/MXFP4 kernels (including GDC unguarding and PDL group GEMM fixes). TODO comments have been added to both Dockerfiles tracking this.

Test plan

  • Build with CUDA_ARCHS="12.0a;12.1a" on DGX Spark (SM121), verify NVFP4 model serves with vLLM CUTLASS backend (VLLM_NVFP4_GEMM_BACKEND=cutlass --moe-backend=cutlass)
  • Verify FlashInfer CUTLASS MoE on SM12x no longer hits TMA init error
  • Build with CUDA_ARCHS="12.0a;12.1a", verify Marlin fallback still works (VLLM_NVFP4_GEMM_BACKEND=marlin --moe-backend=marlin)
  • Build with CUDA_ARCHS="10.0a" only, verify Marlin fallback on SM12x (no IllegalInstruction)
  • Verify SM100 (B200) still works with native CUTLASS (no regression from CUTLASS bump)
  • Verify SM89/SM90 still works (pre-Blackwell unaffected)
  • Run tests/models/quantization/test_nvfp4.py on SM120+
  • Docker build completes with FlashInfer 0.6.7 for both Dockerfile and Dockerfile.nightly_torch


@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@johnnynunez johnnynunez changed the title from "fix NVFP4 DGX Spark and RTX50" to "[NVIDIA] Bugfix NVFP4 DGX Spark and RTX50" on Mar 28, 2026
@mergify mergify Bot added the ci/build, nvidia, and bug (Something isn't working) labels Mar 28, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the CUTLASS revision to v4.4.2 and upgrades FlashInfer to version 0.6.7 across the Dockerfiles and requirement files. It also introduces runtime checks to verify that NVFP4 quantization kernels are compiled for the current GPU's SM version (SM100 or SM120) before use, preventing invalid backend selection or runtime failures. I have no feedback to provide.

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
@johnnynunez
Contributor Author

johnnynunez commented Mar 28, 2026

@johnnynunez
Contributor Author

Could a maintainer please add the ready label so CI can run? I have 3 merged PRs but need 4 to bypass the label requirement. Thank you!

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Mar 28, 2026
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed ready-run-all-tests Trigger CI with all tests for wide-ranging PRs labels Mar 28, 2026
@mergify
Contributor

mergify Bot commented Mar 28, 2026

Hi @johnnynunez, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

CUTLASS v4.4.2 added ArchTag to DispatchPolicy in
sm90_gemm_tma_warpspecialized_cooperative.hpp to distinguish SM90 from
SM120 kernel paths. Machete's custom MacheteCollectiveMma defines its
own DispatchPolicy but was missing this field, causing all 18 Machete
template instantiations to fail with "has no member ArchTag".
Also reformats nvfp4_scaled_mm_entry.cu to satisfy pre-commit linter.

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
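
An illustrative sketch of the kind of fix the commit above describes (hypothetical struct names; the real MacheteCollectiveMma dispatch policy is considerably more involved):

```cpp
#include "cutlass/arch/arch.h"

// Before: a custom dispatch policy with no ArchTag member. CUTLASS
// v4.4.2 now reads Policy::ArchTag to tell the SM90 kernel path apart
// from SM120, so this fails with "has no member ArchTag".
struct MacheteDispatchPolicyBefore {
  constexpr static int Stages = 4;
};

// After: declare which architecture the policy targets.
struct MacheteDispatchPolicyAfter {
  constexpr static int Stages = 4;
  using ArchTag = cutlass::arch::Sm90;  // SM90 warp-specialized path
};
```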
@eugr

eugr commented Mar 29, 2026

Getting consistent IllegalInstruction crashes with this PR.

Built FlashInfer from main with FLASHINFER_CUDA_ARCH_LIST=12.1a, and vLLM from main with this PR applied, with TORCH_CUDA_ARCH_LIST=12.1a.

```
Exception raised from currentStreamCaptureStatusMayInitCtx at /pytorch/c10/cuda/CUDAGraphsC10Utils.h:71 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xc8 (0xf152462b6778 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, unsigned int, bool) + 0x224 (0xf152463e4714 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0xf1d388 (0xf15246f8d388 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x477e40 (0xf15246297e40 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #4: c10::TensorImpl::~TensorImpl() + 0x14 (0xf15246256d84 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #5: <unknown function> + 0x5fa548 (0xf1526c7ea548 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xb46d1c (0xf1526cd36d1c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #7: VLLM::EngineCore() [0x524b64]
frame #8: _PyObject_ClearManagedDict + 0x10c (0x4fd240 in VLLM::EngineCore)
frame #9: VLLM::EngineCore() [0x527adc]
frame #10: VLLM::EngineCore() [0x5b3cac]
frame #11: VLLM::EngineCore() [0x5b2fec]
frame #12: VLLM::EngineCore() [0x58cf5c]
frame #13: _PyEval_EvalFrameDefault + 0x8fdc (0x56cf40 in VLLM::EngineCore)
frame #14: VLLM::EngineCore() [0x4c4d74]
frame #15: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #16: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #17: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #18: PyEval_EvalCode + 0x130 (0x562b54 in VLLM::EngineCore)
frame #19: VLLM::EngineCore() [0x55fd48]
frame #20: VLLM::EngineCore() [0x5045cc]
frame #21: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #22: VLLM::EngineCore() [0x4c4d74]
frame #23: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #24: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #25: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #26: PyEval_EvalCode + 0x130 (0x562b54 in VLLM::EngineCore)
frame #27: VLLM::EngineCore() [0x55fd48]
frame #28: VLLM::EngineCore() [0x5045cc]
frame #29: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #30: VLLM::EngineCore() [0x4c4d74]
frame #31: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #32: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #33: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #34: PyEval_EvalCode + 0x130 (0x562b54 in VLLM::EngineCore)
frame #35: VLLM::EngineCore() [0x55fd48]
frame #36: VLLM::EngineCore() [0x5045cc]
frame #37: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #38: VLLM::EngineCore() [0x4c4d74]
frame #39: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #40: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #41: VLLM::EngineCore() [0x560238]
frame #42: VLLM::EngineCore() [0x5045cc]
frame #43: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #44: VLLM::EngineCore() [0x4c4d74]
frame #45: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #46: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #47: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #48: VLLM::EngineCore() [0x6c14a8]
frame #49: Py_FinalizeEx + 0x58 (0x67b088 in VLLM::EngineCore)
frame #50: Py_Exit + 0x18 (0x67c518 in VLLM::EngineCore)
frame #51: VLLM::EngineCore() [0x6811d0]
frame #52: VLLM::EngineCore() [0x680f04]
frame #53: PyRun_SimpleStringFlags + 0x7c (0x67ef1c in VLLM::EngineCore)
frame #54: Py_RunMain + 0x390 (0x68b690 in VLLM::EngineCore)
frame #55: Py_BytesMain + 0x28 (0x68b198 in VLLM::EngineCore)
frame #56: <unknown function> + 0x284c4 (0xf152f23d84c4 in /usr/lib/aarch64-linux-gnu/libc.so.6)
frame #57: __libc_start_main + 0x98 (0xf152f23d8598 in /usr/lib/aarch64-linux-gnu/libc.so.6)
frame #58: _start + 0x30 (0x5f66f0 in VLLM::EngineCore)
```

@johnnynunez
Contributor Author

johnnynunez commented Mar 30, 2026

ready to merge! @mgoin

Now it is working perfectly and B200 accuracy tests passed for NVFP4

Nemotron Super NVFP4 - DGX Spark

```
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
vllm serve nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
  --kv-cache-dtype fp8 \
  --trust-remote-code \
  --gpu-memory-utilization 0.7 \
  --max-model-len 262144 \
  --max-num-seqs 10 \
  --enable-prefix-caching \
  --host 0.0.0.0 \
  --port 8000 \
  --enable-auto-tool-choice \
  --load-format fastsafetensors \
  --tool-call-parser qwen3_coder \
  --reasoning-parser nemotron_v3 \
  --mamba_ssm_cache_dtype float32
```

Results (Benchmark & Stress Test) --> uvx llama-benchy --base-url http://spark:8000/v1 --depth 0 4096 8192 16384 32768 65535 100000 200000:

Auto-detected HF model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 (served as: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4)
llama-benchy (0.3.5)
Date: 2026-03-30 01:35:34
Benchmarking model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 at http://localhost:8000/v1
Concurrency levels: [1]
Loading text from cache: /home/johnny/.cache/llama-benchy/cc6a0b5782734ee3b9069aa3b64cc62c.txt
Total tokens available in text corpus: 143827
Warming up...
Warmup (User only) complete. Delta: 16 tokens (Server: 38, Local: 22)
Warmup (System+Empty) complete. Delta: 16 tokens (Server: 38, Local: 22)

Running coherence test...
Coherence test PASSED.
Measuring latency using mode: api...
Average latency (api): 1.63 ms
Running test: pp=2048, tg=32, depth=0, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=4096, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=8192, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=16384, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=32768, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=65535, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=100000, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=200000, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Printing results in MD format:



| model                                          |             test |              t/s |     peak t/s |          ttfr (ms) |       est_ppt (ms) |      e2e_ttft (ms) |
|:-----------------------------------------------|-----------------:|-----------------:|-------------:|-------------------:|-------------------:|-------------------:|
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |           pp2048 | 1722.48 ± 394.11 |              |   1269.76 ± 345.98 |   1268.14 ± 345.98 |   1269.84 ± 345.98 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |             tg32 |     12.76 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   pp2048 @ d4096 |  1948.06 ± 80.28 |              |   3161.05 ± 134.07 |   3159.43 ± 134.07 |   3161.13 ± 134.05 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |     tg32 @ d4096 |     12.75 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   pp2048 @ d8192 |   1964.84 ± 4.14 |              |    5213.28 ± 10.99 |    5211.65 ± 10.99 |    5213.35 ± 10.97 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |     tg32 @ d8192 |     12.71 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d16384 |   1934.31 ± 5.53 |              |    9530.67 ± 27.20 |    9529.04 ± 27.20 |    9530.74 ± 27.22 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d16384 |     12.64 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d32768 |  1857.07 ± 14.17 |              |  18750.32 ± 143.56 |  18748.69 ± 143.56 |  18750.39 ± 143.57 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d32768 |     12.64 ± 0.02 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d65535 |   1759.29 ± 5.89 |              |  38416.91 ± 128.78 |  38415.28 ± 128.78 |  38416.98 ± 128.78 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d65535 |     12.64 ± 0.04 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d100000 |   1656.44 ± 4.33 |              |  61608.98 ± 160.90 |  61607.35 ± 160.90 |  61609.06 ± 160.91 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   tg32 @ d100000 |     12.69 ± 0.08 | 13.67 ± 0.47 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d200000 |   1397.08 ± 7.47 |              | 144626.89 ± 771.10 | 144625.26 ± 771.10 | 144626.94 ± 771.11 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   tg32 @ d200000 |     12.59 ± 0.12 | 14.00 ± 0.00 |                    |                    |                    |

llama-benchy (0.3.5)
date: 2026-03-30 01:35:34 | latency mode: api
(APIServer pid=33932) INFO 03-30 01:50:49 [loggers.py:259] Engine 000: Avg prompt throughput: 20205.7 tokens/s, Avg generation throughput: 3.2 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=33932) INFO 03-30 01:50:59 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

Comment on lines +402 to +405
```python
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
    e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
```
Member


@pavanimajety do you know if this is right? I thought we fixed this issue for trtllm MoE across the board

Contributor Author


the fix is done by @wzhao18

Contributor

@wzhao18 wzhao18 Mar 30, 2026


With the new FI version, there are various CI failures with accuracy collapse. I tracked the cause down to these lines.

To reproduce the issue, run gsm8k on nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 and nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 on current main with flashinfer 0.6.7.

Comment on lines +315 to +319
```python
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
    e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
```

Member


Ditto

Contributor Author


same @wzhao18

Signed-off-by: Johnny <johnnynuca14@gmail.com>
@johnnynunez johnnynunez requested a review from noooop as a code owner March 30, 2026 02:20
@mgoin mgoin mentioned this pull request Mar 30, 2026
Member

@yewentao256 yewentao256 left a comment


#38188
There is a FlashInfer version update PR here; not sure if we want to land it separately.

@wzhao18
Contributor

wzhao18 commented Mar 30, 2026

I see some eval tests still failing. Do we have clues on those?

@mgoin
Copy link
Copy Markdown
Member

mgoin commented Mar 30, 2026

  • Distributed DP Tests (2 GPUs) should be fine; this is a flaky test for FlashAttn (v1/distributed/test_eagle_dp.py::test_run_eagle_dp[FLASH_ATTN]). The same applies to Model Runner V2 Distributed (2 GPUs)

  • LM Eval Large Models (H200) is fine, resolved on main

  • GPQA Eval (GPT-OSS) (B200): I'm not sure about this timeout; it seems to be one request stuck running. This kernel might have been changed in the FlashInfer update (gpt-oss-20b-sm100-fi-mxfp4-mxfp8-trtllm)

  • Distributed Tests (2 GPUs) (H100): I'm not sure about this one; I haven't seen this failure in CI before: FAILED tests/v1/distributed/test_dbo.py::test_dbo_dp_ep_gsm8k[deepep_low_latency] - AssertionError: DBO+DP+EP accuracy too low (deepep_low_latency): 0.000 < 0.620

I retried both of the last two

@wzhao18
Copy link
Copy Markdown
Contributor

wzhao18 commented Mar 30, 2026

@mgoin Is this expected?

```
FAILED evals/gsm8k/test_gsm8k_correctness.py::test_gsm8k_correctness[Nemotron-3-Super-120B-A12B-BF16] - AssertionError: GSM8K metric too low: 0.7854 < 0.9300 - 0.0800 = 0.8500
assert np.float64(0.7854435178165277) >= (0.93 - 0.08)
```

@mgoin
Copy link
Copy Markdown
Member

mgoin commented Mar 30, 2026

@wzhao18 Yes it is resolved by #38556

@vllm-bot vllm-bot merged commit b4a2f3a into vllm-project:main Mar 30, 2026
176 of 180 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Mar 30, 2026
@wzhao18
Contributor

wzhao18 commented Mar 30, 2026

@mgoin Thanks!

neweyes pushed a commit to neweyes/vllm that referenced this pull request Mar 31, 2026
Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: Johnny <johnnynuca14@gmail.com>
Signed-off-by: neweyes <328719365@qq.com>
aleozlx pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Apr 1, 2026
…es on SM12x (#2913)

### Summary

- Add missing `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` compile flag to all
CUTLASS fused MoE JIT modules (SM100/SM103/SM120) and
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to SM90 modules
- Sync nv_internal `grid_dependency_control.h` with upstream CUTLASS to
support SM100/SM103/SM110/SM120/SM121 GDC
- Add `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to FP8 blockscale GEMM SM90
module

### Problem

Random `cudaErrorIllegalInstruction` crashes on DGX Spark (SM121) and
RTX 50-series (SM120) when running NVFP4 MoE models (e.g., Nemotron,
Qwen3.5-122B) under load. The crashes are intermittent and worsen with
longer context lengths and higher concurrency.

**Root cause:** PR #2780 fixed the missing GDC compile flags for GEMM
modules (`flashinfer/jit/gemm/core.py`), but the **CUTLASS fused MoE
modules** in `flashinfer/jit/fused_moe.py` and the **FP8 blockscale GEMM
module** were not fixed. This is the exact same class of bug as #2708.

Without `-DCUTLASS_ENABLE_GDC_FOR_SM100=1`, CUTLASS's
`grid_dependency_control.h` compiles `wait_on_dependent_grids()` and
`launch_dependent_grids()` as **empty no-ops**:

```cpp
CUTLASS_DEVICE void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))   // ← not defined without the flag
  asm volatile("griddepcontrol.wait;");
#endif
}
```

Meanwhile, the host-side code still sets
`programmaticStreamSerializationAllowed = true` (PDL enabled) via
`device_support_pdl()` which returns `True` for all `major >= 9`,
including SM12x. This means:

1. **Host enables PDL** → CUDA runtime overlaps consecutive kernels
2. **Device GDC barriers are no-ops** → No synchronization between
overlapping kernels
3. **Race condition** → Dependent kernel reads stale global memory →
corruption → `cudaErrorIllegalInstruction`

The crash is random because it depends on exact kernel scheduling
timing, which varies per request.
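
To make the failure mode concrete, a minimal host-side sketch of a PDL-style launch using the CUDA launch-attribute API (illustrative only; not FlashInfer's actual launch code):

```cpp
#include <cuda_runtime.h>

__global__ void producer(float* buf) { buf[threadIdx.x] = 1.0f; }
__global__ void consumer(float* buf) {
  // In a real CUTLASS kernel, a griddepcontrol.wait barrier would order
  // this read after the producer's writes; without the GDC define, that
  // barrier compiles to a no-op and nothing enforces the ordering.
  buf[threadIdx.x] += 1.0f;
}

void launch_with_pdl(float* buf, cudaStream_t stream) {
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = 1;  // host enables PDL

  cudaLaunchConfig_t cfg{};
  cfg.gridDim = dim3(1);
  cfg.blockDim = dim3(32);
  cfg.stream = stream;
  cfg.attrs = &attr;
  cfg.numAttrs = 1;

  // With PDL on, the consumer may begin before the producer finishes;
  // correctness then rests entirely on device-side GDC barriers.
  cudaLaunchKernelEx(&cfg, producer, buf);
  cudaLaunchKernelEx(&cfg, consumer, buf);
}
```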

### Fix

**`flashinfer/jit/fused_moe.py`** — Added GDC flags to all CUTLASS fused
MoE modules:

| Module | Flag | Architectures Covered |
|---|---|---|
| `gen_cutlass_fused_moe_sm120_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM120, SM121 |
| `gen_cutlass_fused_moe_sm103_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM103, SM120, SM121 |
| `gen_cutlass_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100, SM110, SM120, SM121 |
| `gen_cutlass_fused_moe_sm90_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` | SM90 |
| `gen_trtllm_gen_fused_moe_sm100_module()` | `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` | SM100+, SM120, SM121 |

**`flashinfer/jit/gemm/fp8_blockscale.py`** — Added
`-DCUTLASS_ENABLE_GDC_FOR_SM90=1` to
`gen_fp8_blockscale_gemm_sm90_module()`.

**`csrc/nv_internal/.../grid_dependency_control.h`** — Synced with
upstream CUTLASS
(`3rdparty/cutlass/include/cutlass/arch/grid_dependency_control.h`) to
add SM100+ GDC support. Previously only handled SM90, so any nv_internal
TensorRT-LLM code compiled for SM12x would have GDC barriers silently
compiled as no-ops.

### Why `-DCUTLASS_ENABLE_GDC_FOR_SM100=1` covers SM12x

CUTLASS uses a single flag for the entire Blackwell family. From
`grid_dependency_control.h`:

```cpp
#if(CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
    ((__CUDA_ARCH__ == 1000 && ...) ||   // SM100
     (__CUDA_ARCH__ == 1030 && ...) ||   // SM103
     (__CUDA_ARCH__ == 1100 && ...) ||   // SM110
     (__CUDA_ARCH__ == 1200 && ...) ||   // SM120 (RTX 50-series)
     (__CUDA_ARCH__ == 1210 && ...)))    // SM121 (DGX Spark)
#define CUTLASS_GDC_ENABLED
```

### Why SM90 GDC flag was NOT added to SM100+ modules

PR #2716 attempted to add both `-DCUTLASS_ENABLE_GDC_FOR_SM90=1` and
`-DCUTLASS_ENABLE_GDC_FOR_SM100=1` to all modules. It broke AOT builds
because `sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp`
checks `CUTLASS_ENABLE_GDC_FOR_SM90` and calls
`scheduler.is_last_tile()` — a method not present on the SM120
scheduler. PR #2780 corrected this by using only the SM100 flag for
SM100+ modules. This PR follows the same approach.

### Related

- #2708 — Original issue: missing GDC flags cause PDL race condition
- #2716 — First fix attempt (reverted — broke AOT)
- #2780 — Corrected fix for GEMM modules only
- [vllm-project/vllm#38423](vllm-project/vllm#38423) — NVFP4 bugfix on DGX Spark
- [NVIDIA/cutlass#3121](NVIDIA/cutlass#3121) — K=64 block-scaled GEMM tiles (separate issue)

### Test plan

- [x] Clear JIT cache: `rm -rf ~/.cache/flashinfer/`
- [x] Run NVFP4 MoE model on SM121 (DGX Spark) with 128K context under
load — verify no `cudaErrorIllegalInstruction`
- [x] Run NVFP4 MoE model on SM120 (RTX 50-series) with concurrent
requests — verify no NaN/garbage output
- [x] Verify `CUDA_LAUNCH_BLOCKING=1` workaround is no longer needed
- [x] AOT build with `FLASHINFER_CUDA_ARCH_LIST="12.1a"` completes
without errors
- [x] SM90 (Hopper) fused MoE tests pass: `pytest tests/moe/`
- [x] SM100 GEMM tests still pass (no regression from existing GDC
flags)


puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: Johnny <johnnynuca14@gmail.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: Johnny <johnnynuca14@gmail.com>