[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths #35568

Merged
vllm-bot merged 5 commits into vllm-project:main from blake-snc:fix/marlin-sm12x-capability-check on May 15, 2026

@blake-snc blake-snc commented Feb 28, 2026

Contributor

Summary

SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) — both support native mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. However, SM121 is excluded from all Marlin and CUTLASS FP8 codepaths by exact-match arch guards (== 120, in [89, 120], enable_sm120_only).

This fixes 8 locations across codegen, runtime, dispatch, and tests using bounded SM12x family checks (arch // 10 == 12, major_capability == 12, enable_sm120_family, is_device_capability_family(120)):

Codegen (FP8 kernel template generation):

  • csrc/quantization/marlin/generate_kernels.py: arch in [89, 120] → arch == 89 or arch // 10 == 12
  • csrc/moe/marlin_moe_wna16/generate_kernels.py: same fix

Runtime (FP8 activation gate):

  • csrc/moe/marlin_moe_wna16/ops.cu: == 120 → major_capability == 12

CUTLASS FP8 dispatch (kernel wrapper):

  • csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh: enable_sm120_only → enable_sm120_family
  • csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh: same fix

Tests (FP8 test case generation):

  • tests/kernels/moe/test_moe.py: get_device_capability() not in [89, 120] → proper is_device_capability(89) / is_device_capability_family(120) API calls
  • tests/kernels/quantization/test_marlin_gemm.py: same fix
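
The old test guard compared a capability tuple against plain integers, so the membership test could never succeed. A minimal sketch of that failure mode, assuming a `DeviceCapability` NamedTuple with a `to_int()` helper; the class and the `old_skip` / `new_skip` functions are illustrative stand-ins, not the project's actual test code:

```python
from typing import NamedTuple


class DeviceCapability(NamedTuple):
    """Hypothetical stand-in for a (major, minor) capability tuple."""
    major: int
    minor: int

    def to_int(self) -> int:
        # (12, 1) -> 121, (8, 9) -> 89
        return self.major * 10 + self.minor


def old_skip(cap: DeviceCapability) -> bool:
    # A tuple never equals an int, so `not in` is True for every device.
    return cap not in [89, 120]


def new_skip(cap: DeviceCapability) -> bool:
    # Exact match for SM89, bounded family match for SM12x.
    code = cap.to_int()
    return not (code == 89 or code // 10 == 12)


for cap in (DeviceCapability(8, 9), DeviceCapability(9, 0),
            DeviceCapability(12, 0), DeviceCapability(12, 1)):
    print(cap.to_int(), "old_skip:", old_skip(cap), "new_skip:", new_skip(cap))
```

With the tuple-vs-int comparison, every device is skipped, including SM89 and SM120, which is why the tests needed the integer-based checks.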

Python-side FP8 input validation:

  • vllm/model_executor/layers/quantization/utils/marlin_utils.py: is_device_capability(120) → is_device_capability_family(120)

All checks use bounded SM12x family matching (covers SM120/SM121 but won't accidentally match future SM13x).
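
To make the difference between the guard styles concrete, here is a small sketch over integer arch codes. The three function names are made up for illustration; the expressions mirror the exact-match list, the open-ended `arch >= 120` form used in an earlier revision of this PR, and the bounded family check described above:

```python
def exact_list(arch: int) -> bool:
    # Original guard: only the literal codes 89 and 120 pass.
    return arch in [89, 120]


def open_ended(arch: int) -> bool:
    # Earlier revision: also admits every future arch >= 120.
    return arch == 89 or arch >= 120


def bounded_family(arch: int) -> bool:
    # Final guard: SM89 exactly, plus any SM12x (120..129).
    return arch == 89 or arch // 10 == 12


ARCHS = [89, 90, 120, 121, 130]
results = {a: (exact_list(a), open_ended(a), bounded_family(a)) for a in ARCHS}
for arch, (exact, open_, family) in results.items():
    print(f"SM{arch}: exact={exact} open_ended={open_} family={family}")
```

Only the bounded form admits SM121 while still rejecting SM90 and a hypothetical SM130.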

The enable_sm120_only → enable_sm120_family change in the CUTLASS dispatch headers also resolves the CUTLASS FP4 GEMM failure on SM121 reported in #30163 ("Failed to run cutlass FP4 gemm on sm120. Error: Error Internal"), since enable_sm120_only uses __CUDA_ARCH__ == 1200, which excludes SM121 (__CUDA_ARCH__ == 1210), while enable_sm120_family uses >= 1200 && < 1300.
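
The two compile-time guards can be modeled as predicates over the `__CUDA_ARCH__` value. This is only a Python model of the comparisons quoted above, not the CUTLASS wrapper code itself, and the function names are illustrative:

```python
def sm120_only(cuda_arch: int) -> bool:
    # Models the `__CUDA_ARCH__ == 1200` comparison.
    return cuda_arch == 1200


def sm120_family(cuda_arch: int) -> bool:
    # Models the `>= 1200 && < 1300` range comparison.
    return 1200 <= cuda_arch < 1300


for name, cuda_arch in [("SM120", 1200), ("SM121", 1210), ("SM130", 1300)]:
    print(name, "only:", sm120_only(cuda_arch), "family:", sm120_family(cuda_arch))
```

Under the exact-match guard the kernel body is compiled out for a value of 1210, which is consistent with the runtime failure described for SM121.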

Validation

Tested on DGX Spark (NVIDIA GB10, SM121a / capability 12.1):

Marlin FP4 GEMM (all 5 configs including N=100544): PASS
CUTLASS FP4 dispatch: cutlass_scaled_mm_supports_fp4(121) = True
Capability check logic:

SM89 (Ada):   allowed via exact match ✓
SM90 (Hopper): blocked ✓
SM120 (RTX 5090): allowed ✓
SM121 (DGX Spark): allowed ✓
SM130 (future): not matched ✓

Subsumes #35803. Fixes #35432. Fixes #30163. Relates to #30135.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Test plan

  • Validated on SM121a hardware (DGX Spark)
  • Marlin FP4 GEMM passes all 5 test configs
  • enable_sm120_family verified in common.hpp with correct >= 1200 && < 1300 range guard
  • is_device_capability_family(120) verified: uses to_int() // 10 == 120 // 10
  • Pre-commit hooks pass
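
The family comparison noted in the test plan reduces both sides to their major version before comparing. A minimal sketch of that arithmetic, using a hypothetical free function rather than the platform method itself:

```python
def is_capability_family(capability: int, family: int) -> bool:
    # Both the device code and the requested family are divided by 10,
    # so 120 and 121 both reduce to major version 12.
    return capability // 10 == family // 10


checks = {
    (120, 120): is_capability_family(120, 120),
    (121, 120): is_capability_family(121, 120),
    (130, 120): is_capability_family(130, 120),
    (89, 120): is_capability_family(89, 120),
}
for (cap, fam), ok in checks.items():
    print(f"capability={cap} family={fam} -> {ok}")
```

Because the argument is normalized as well, passing 120 or 121 as the family selects the same set of devices.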

🤖 Generated with Claude Code

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify Bot added the bug Something isn't working label Feb 28, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request updates the device capability check for Marlin W4A8-FP8 support to include newer GPU architectures. The check is changed from an exact match for compute capability 12.0 (is_device_capability(120)) to a check for 12.0 or higher (has_device_capability(120)). This is intended to enable support on devices such as Blackwell variants that report compute capabilities like 12.1. The error message is also updated to reflect this change, now indicating support for SM120+ devices.

@blake-snc blake-snc force-pushed the fix/marlin-sm12x-capability-check branch from 30d8763 to f4b19a7 Compare February 28, 2026 02:33
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 2, 2026
Change is_device_capability(120) to has_device_capability(120) so
SM121 (GB10) passes the >= comparison for Marlin W4A8-FP8 support.
is_device_capability checks for exact match only.

Ref: vllm-project#35568
blake-snc added a commit to blake-snc/vllm that referenced this pull request Mar 2, 2026
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120
(RTX 5090) but is excluded by exact-match arch guards throughout the
Marlin and CUTLASS FP8 codepaths. This fixes 8 locations:

- generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89
  or arch >= 120` so SM121 FP8 kernel templates are generated
- ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation
  gate
- scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only`
  → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121
- test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper
  `is_device_capability(89)` / `is_device_capability_family(120)` APIs
  instead of broken `get_device_capability() not in [89, 120]`
  (NamedTuple vs int comparison)
- marlin_utils.py: `is_device_capability(120)` →
  `is_device_capability_family(120)` for Python-side FP8 input check

Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in
marlin.cu.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc blake-snc requested a review from WoosukKwon as a code owner March 3, 2026 05:58
@mergify mergify Bot added the nvidia label Mar 3, 2026
@blake-snc blake-snc changed the title from "[Bugfix] Fix Marlin W4A8-FP8 check for SM121+ Blackwell variants" to "[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths" Mar 3, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
Cherry-pick upstream fixes for GB10 Spark (SM121):

- PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8
  kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py)
- PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4
  by using ReplicatedLinear with quant_config=None
- PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds
  on-the-fly FP8 dequantization in Triton kernels
- PR vllm-project#35936: tool_choice="required" falls back to tool_parser for
  non-JSON (XML) tool calls from Qwen3 models

Local patches:
- Patch FlashInfer TRTLLM JIT to compile for SM12x
  (supported_major_versions=[10] → [10, 12])
- Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
- Remove VLLM_TEST_FORCE_FP8_MARLIN=1 (CUTLASS FP8 now works on SM121
  via enable_sm120_family from PR vllm-project#35568)
- Make VLLM_USE_FLASHINFER_MOE_FP4 overridable (default still 0) so
  users can test FlashInfer TRTLLM MoE on SM121 after JIT patch
- Add auto-kill of existing vLLM server before launch (prevents GPU OOM
  on GB10 unified memory)
- Skip VLLM_TEST_FORCE_FP8_MARLIN in NVFP4 MoE oracle (not SM121-ready
  for that path)
@blake-snc blake-snc force-pushed the fix/marlin-sm12x-capability-check branch from 2cb48d7 to 8092825 Compare March 12, 2026 21:36
@blake-snc

Contributor Author

Updated — DCO sign-off has been added to all commits. Ready for review.

@blake-snc

Contributor Author

@scottgl9 I see you have cherry-picked a good bit of this PR - is there anything left in this PR worth keeping it open for from your end?

blake-snc and others added 4 commits April 9, 2026 10:58
…ariants)

`get_marlin_input_dtype()` uses `is_device_capability(120)`, which is an
exact match: SM121 devices (DGX Spark GB10) return capability
(12, 1) and fail the check, blocking Marlin W4A8-FP8 with a misleading
"only support SM89 or SM120" error.

Changed to `has_device_capability(120)` which uses >= comparison,
allowing SM120 and all Blackwell variants (SM121, SM121a, etc.) while
still correctly blocking SM90 (Hopper) where Marlin FP8 is slower than
W4A16.

The SM89 (Ada) check remains as `is_device_capability(89)` since there
are no Ada variants.

Validated on DGX Spark (NVIDIA GB10, SM121a / capability 12.1):
- Before: `is_device_capability(120)` → False → ValueError raised
- After:  `has_device_capability(120)` → True  → FP8 dtype returned
- SM90 still correctly blocked (has_device_capability(120) → False)
- SM89 still correctly allowed (is_device_capability(89) → True)

Fixes vllm-project#35432
Relates to vllm-project#30135

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120
(RTX 5090) but is excluded by exact-match arch guards throughout the
Marlin and CUTLASS FP8 codepaths. This fixes 8 locations:

- generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89
  or arch >= 120` so SM121 FP8 kernel templates are generated
- ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation
  gate
- scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only`
  → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121
- test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper
  `is_device_capability(89)` / `is_device_capability_family(120)` APIs
  instead of broken `get_device_capability() not in [89, 120]`
  (NamedTuple vs int comparison)
- marlin_utils.py: `is_device_capability(120)` →
  `is_device_capability_family(120)` for Python-side FP8 input check

Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in
marlin.cu.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Address review feedback: arch >= 120 would incorrectly match future
arch families (SM130+). Use arch // 10 == 12 for codegen and
major_capability == 12 for runtime to scope checks to the SM12x family.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
fused_marlin_moe() now requires vllm_config to be set (via a
CustomOp in its call chain). Add the default_vllm_config pytest
fixture to test_fused_marlin_moe, test_fused_marlin_moe_with_bias,
and test_fused_marlin_moe_non_gated — matching the pattern already
used by test_batched_fused_marlin_moe.

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
@blake-snc blake-snc force-pushed the fix/marlin-sm12x-capability-check branch from d25fee4 to 974e660 Compare April 9, 2026 18:02
SeraphimSerapis added a commit to SeraphimSerapis/spark-vllm-docker that referenced this pull request Apr 10, 2026
The PR #35568 (SM121/DGX Spark FP8 fix) was applied via curl+git-apply,
which breaks when upstream context lines shift — even though the target
code itself hasn't changed.

Replace with a sed-based shell script (patches/sm121-fp8-fix.sh) that
does exact string substitution, making it immune to upstream context
drift. The script applies the same 8 fixes from PR #35568:

- generate_kernels.py: arch in [89,120] -> arch == 89 or arch // 10 == 12
- ops.cu: == 120 -> == 12 (major capability family check)
- scaled_mm.cuh: enable_sm120_only -> enable_sm120_family
- marlin_utils.py: is_device_capability(120) -> is_device_capability_family(120)

Ref: vllm-project/vllm#35568
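
The exact-substitution idea in this commit message can be sketched in Python. This is not the referenced shell script: `replace_exact` is a hypothetical helper, the file contents are made up for the demo, and unlike a plain sed substitution it raises when the target string is not found exactly once:

```python
import tempfile
from pathlib import Path


def replace_exact(path: Path, old: str, new: str) -> None:
    """Replace exactly one occurrence of `old` in the file, or fail loudly."""
    text = path.read_text()
    count = text.count(old)
    if count != 1:
        raise ValueError(f"expected 1 match for {old!r} in {path}, found {count}")
    path.write_text(text.replace(old, new))


# Demo on a throwaway file; surrounding lines are free to change
# because only the target string itself is matched.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "generate_kernels.py"
    target.write_text("# unrelated context line\nif arch in [89, 120]:\n    pass\n")
    replace_exact(target, "arch in [89, 120]", "arch == 89 or arch // 10 == 12")
    patched = target.read_text()
    print(patched)
```

Matching on the target string alone avoids the context-line dependency of a unified diff, at the cost of breaking if upstream rewrites the guarded expression itself.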
@blake-snc

Contributor Author

The default_vllm_config fixture resolved the config error (7,282 of 7,304 tests now pass, up from 0). The remaining 22 failures are a pre-existing borderline tolerance issue at m=666, K=2048 on L4 — max error 0.04297 vs atol=0.04 on 1-4 elements out of 1.4M. Investigated the kernel: FP32 accumulation is used, partial block handling is correct, and the fixture doesn't affect Marlin dispatch. The error is inherent 4-bit quantization noise that varies by L4 instance (PR #39024 passed all MOE tests on a different build). Filed as #39549.

@mgoin, test_fused_marlin_moe_with_bias already uses @flaky(reruns=2) for the same kind of borderline behavior. Would it make sense to add that to the base test as well, or would you prefer a different approach?

@AshtonVaughan

Validated the family-check approach on RTX 5090 (SM 12.0). The author targets SM 12.1 (GB10); this confirms that the broader SM 12.x family logic also handles the 12.0 RTX 50-series side correctly.

5090 reports SM 12.0
is_device_capability(120):         False  (exact-match would still exclude 12.0!)
is_device_capability_family(120):  True   (correctly admits 12.0 and 12.1)
arch // 10 == 12:                  True
major_capability == 12:            True

Subtle but important: the old guards used is_device_capability(120) which is exact match for SM 12.0, so they would still incorrectly exclude SM 12.1 even after a partial fix. Switching to is_device_capability_family(120) / major_capability == 12 is the right call - admits both 12.0 and 12.1 with one consistent check.

Note: I am only validating the Python-side gate logic and family-check pattern. The C++ kernel codegen changes (Marlin / CUTLASS) would require a full build to confirm runtime behaviour, which I have not done here.

Would be useful to know if anyone has run the Marlin FP8 kernel tests on a 5090 (SM 12.0) post-build to confirm the codegen path actually emits compatible kernels for both SM 12.0 and 12.1.

eugr commented Apr 30, 2026

@mgoin - can we merge this? It's been a part of my builds for a few weeks, but would be nice to have it merged into main otherwise some FP8 models fail with "unsupported arch" error...

@pavanimajety pavanimajety enabled auto-merge (squash) May 8, 2026 22:12
@DavRodSwede

Quick deployment-evidence note from a 3-node DGX Spark GB10 (SM121) cluster:

This patch has been shipping in eugr/spark-vllm-docker's Dockerfile (applied inline at build time) since their commit 44808f7 on 2026-04-02 — ~38 days of broader DGX Spark community deployment. On our cluster specifically, the patched image has been running Intel/Qwen3.5-397B-A17B-int4-AutoRound at PP=3 (kv_cache_dtype=fp8) since 2026-05-06.

Runtime indicators that the patch is doing its job on real SM121 silicon:

INFO ... [gptq_marlin.py:387] Using MarlinLinearKernel for GPTQMarlinLinearMethod
INFO ... [int_wna16.py:136] Using 'MARLIN' WNA16 MoE backend.
[repeated 2x across cluster]

Zero "Failed to run cutlass"-style errors over the container's lifetime. We also rebuilt fresh from vllm-project/vllm@main + this PR on 2026-05-10: all 6 source patches apply cleanly and the build succeeds.

Hopefully useful as a third-party deployment data point. Happy to provide more detail (logs, full env, hardware specifics) if it'd help unblock the merge. Thanks for the fix, @blake-snc.

eugr commented May 15, 2026

@mgoin - can we merge please?

@mgoin mgoin left a comment

Member

Thanks for the ping

@vllm-bot vllm-bot merged commit 06d020b into vllm-project:main May 15, 2026
153 of 160 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 15, 2026
omerpaz95 pushed a commit to omerpaz95/vllm that referenced this pull request May 18, 2026
vllm-project#35568)

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
Labels

bug (Something isn't working), nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

9 participants