
fix: Enable SM121 for mm_fp4 #2012

Merged: yzh119 merged 1 commit into flashinfer-ai:main from bkryu:enable_spark_mm_fp4 on Oct 30, 2025

Conversation

@bkryu (Collaborator) commented Oct 30, 2025

📌 Description

In #1809 we previously added a compute-capability-based support check for mm_fp4.

However, we missed enabling SM121 for the cudnn and cutlass backends. Additionally, we marked trtllm as supported on SM120 when it is not.
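
For context, here is a minimal sketch of the capability gating involved. The stand-in decorator below only mimics `supported_compute_capability` from flashinfer/utils.py, and the function names are illustrative; only the capability lists come from this PR.

```python
# Minimal sketch, not flashinfer's actual code. The stand-in below mimics
# the supported_compute_capability decorator from flashinfer/utils.py;
# function names are illustrative, and only the capability lists reflect
# this PR's change.
import functools
import torch

def supported_compute_capability(capabilities):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            major, minor = torch.cuda.get_device_capability()
            cc = major * 10 + minor  # e.g. (12, 1) -> 121 for SM121 / GB10
            if cc not in capabilities:
                raise RuntimeError(
                    f"{fn.__name__} is unsupported on SM{cc}; "
                    f"supported: {capabilities}"
                )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

# After this PR: cudnn and cutlass gain SM121 (compute capability 12.1),
# while trtllm no longer lists SM120 at all.
@supported_compute_capability([100, 103, 110, 120, 121])
def mm_fp4_cudnn_sketch(a, b): ...

@supported_compute_capability([100, 103, 120, 121])
def mm_fp4_cutlass_sketch(a, b): ...
```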

This PR fixes both. Example benchmark and pytest runs on SM121 after the fix:

```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning:
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)

  warnings.warn(
[PERF] cudnn          :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass        :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
================================ test session starts =================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
================================= warnings summary ==================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning:
      Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
      Minimum and Maximum cuda capability supported by this version of PyTorch is
      (8.0) - (12.0)

    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================== 450 passed, 2790 skipped, 1 warning in 8.24s ====================
```


🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
  • [x] I have installed the hooks with `pre-commit install`.
  • [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation (https://pre-commit.com/).

🧪 Tests

  • [ ] Tests have been added or updated as needed.
  • [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Expanded hardware compatibility by adding support for newer NVIDIA GPU architectures.
    • FP4 quantized operations now available across multiple backends on supported devices.

@coderabbitai (Bot, Contributor) commented Oct 30, 2025

Walkthrough

Version 12.1 support is added across FP4 backends in the benchmarks and the GEMM library. Benchmark utility mappings are extended to recognize "12.1" alongside existing backend options. The cudnn and cutlass FP4 backends expand their SM compute capability support to include SM121, and the trtllm FP4 capability list is corrected to drop SM120.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Benchmark Backend Version Mappings: benchmarks/routines/flashinfer_benchmark_utils.py | Adds a "12.1" version key to the dtype-to-backends and mm_fp4 backend mappings, both mapping to ["cudnn", "cutlass"]. |
| GEMM FP4 Backend SM Capabilities: flashinfer/gemm.py | cudnn FP4 now supports [100, 103, 110, 120, 121]; cutlass FP4 now supports [100, 103, 120, 121]; trtllm FP4 has SM120 removed from its supported list. |
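
As a hedged illustration of the benchmark-side change: the dictionary name below is hypothetical, and only the "12.1" entry's contents come from the summary above.

```python
# Illustrative sketch of the mapping shape in
# benchmarks/routines/flashinfer_benchmark_utils.py. The name
# MM_FP4_BACKENDS_BY_CC is hypothetical; only the "12.1" entry's
# contents are stated in this PR's change summary.
MM_FP4_BACKENDS_BY_CC = {
    # ... existing compute-capability keys ...
    "12.1": ["cudnn", "cutlass"],  # added by this PR; trtllm excluded
}
```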

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Changes follow repetitive patterns (adding SM121 capability to multiple decorator annotations and version mappings)
  • No logic modifications, control flow changes, or functional behavioral updates
  • Straightforward configuration/compatibility extensions


Suggested reviewers

  • nvmbreughe
  • Anerudhan
  • yzh119
  • cyx-6

Poem

🐰 A new SM arrives on the scene,
With 12.1 compute, crisp and clean!
FP4 backends now stand tall and wide,
SM121's support, our rabbity pride! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

  • Title Check ✅ Passed: The PR title "fix: Enable SM121 for mm_fp4" is clear, concise, and directly captures the main objective of the changeset. The raw summary shows the changes focus on expanding SM121 (compute capability 12.1) support across multiple FP4 backends by updating decorator annotations and version mappings. The title accurately summarizes this primary change without unnecessary details or vague language, making it easy for teammates to understand the purpose when scanning commit history.
  • Description Check ✅ Passed: The PR description addresses all major template sections adequately. The Description section provides specific details about what was fixed (enabling SM121 for the cudnn and cutlass backends, correcting trtllm SM120 support), references the related issue #1809, and includes concrete validation evidence with benchmark results and pytest output showing 450 tests passing. Pre-commit checks are marked complete. While the test checkboxes are not checked, this is reasonable given that the changeset only extends existing compute capability support without adding new test files, and the provided pytest results demonstrate validation. The description contains sufficient detail to understand the changes and their necessity.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9287c9 and 9d6afe9.

📒 Files selected for processing (2)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
  • flashinfer/gemm.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (1)
flashinfer/utils.py (1)
  • supported_compute_capability (772-852)
🔇 Additional comments (4)
flashinfer/gemm.py (3)

1753-1753: LGTM! SM121 support correctly enabled for cuDNN backend.

The addition of 121 to the supported compute capabilities list enables FP4 GEMM operations on SM12.1 hardware for the cuDNN backend.


1833-1833: LGTM! SM121 support correctly enabled for CUTLASS backend.

The addition of 121 to the supported compute capabilities list enables FP4 GEMM operations on SM12.1 hardware for the CUTLASS backend.


1811-1811: No issues found; code change is correct.

The decorator correctly removes SM120 support from trtllm. The test file at tests/gemm/test_mm_fp4.py:40 already documents that "trtllm gemm does not support SM110/SM120/SM121 GPUs" and has a proper skip condition. No other code in the codebase assumes trtllm supports SM120.
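
For reference, a hedged sketch of the kind of skip guard described there; the actual test code is not shown in this PR, so the structure below is illustrative.

```python
# Hedged sketch of a skip guard like the one described at
# tests/gemm/test_mm_fp4.py:40; the real test is not shown in this PR,
# so names and structure here are illustrative.
import pytest
import torch

def _cc() -> int:
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

@pytest.mark.skipif(
    torch.cuda.is_available() and _cc() in (110, 120, 121),
    reason="trtllm gemm does not support SM110/SM120/SM121 GPUs",
)
def test_mm_fp4_trtllm_sketch():
    ...  # would exercise the trtllm FP4 path on supported GPUs
```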

benchmarks/routines/flashinfer_benchmark_utils.py (1)

244-244: LGTM! Benchmark mapping correctly reflects SM12.1 backend support.

The addition of the "12.1" compute capability mapping with ["cudnn", "cutlass"] backends is consistent with the capability changes in flashinfer/gemm.py. The trtllm backend is correctly excluded, matching its lack of SM120/SM121 support.



@bkryu bkryu marked this pull request as ready for review October 30, 2025 18:25
@bkryu (Collaborator, Author) commented Oct 30, 2025

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !101 has been created, and the CI pipeline #37609405 is currently running. I'll report back once the pipeline job completes.

@nvmbreughe (Contributor) left a comment

LGTM. Thanks, Brian!

Comment thread on flashinfer/gemm.py:

```diff
-@supported_compute_capability([100, 103, 120])
+@supported_compute_capability([100, 103, 120, 121])
```
A Member commented:

110 is also supported if I remember correctly, cc: @ttyio

A Contributor replied:

It was explicitly disabled on trtllm in the original checks; the other backends support it.

@flashinfer-bot (Collaborator) commented:

[CANCELING] Pipeline #37609405: canceled

@bkryu (Collaborator, Author) commented Oct 30, 2025

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !101 has been created, and the CI pipeline #37612579 is currently running. I'll report back once the pipeline job completes.

@yzh119 merged commit a5ff033 into flashinfer-ai:main on Oct 30, 2025
4 checks passed
@bkryu deleted the enable_spark_mm_fp4 branch on October 31, 2025 00:06
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request on Mar 13, 2026; its commit message reproduces the PR description above.
