fix: Enable SM121 for mm_fp4 #2012
Conversation
**Walkthrough**

Version 12.1 support is added across FP4 backends in benchmarks and the GEMM library. Benchmark utility mappings are extended to recognize "12.1" alongside existing backend options. Three FP4 backend implementations (cudnn, trtllm, and cutlass) expand their SM compute capability support to include SM121.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
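For illustration, a minimal sketch of the kind of benchmark-utility mapping extension the walkthrough describes. The dictionary name, keys, and helper are hypothetical rather than the actual `flashinfer_benchmark.py` code; the support sets follow the PR description (trtllm excluded on SM120/SM121).

```python
# Hypothetical mapping from compute capability string to the mm_fp4
# backends a benchmark utility would exercise; illustration only.
FP4_BACKENDS_BY_CC = {
    "10.0": ["cudnn", "trtllm", "cutlass"],
    "10.3": ["cudnn", "trtllm", "cutlass"],
    "12.0": ["cudnn", "cutlass"],  # trtllm is not supported on SM120
    "12.1": ["cudnn", "cutlass"],  # newly recognized by this PR
}

def backends_for(cc: str) -> list[str]:
    """Return the FP4 backends to benchmark for a compute capability key."""
    return FP4_BACKENDS_BY_CC.get(cc, [])
```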
/bot run
```diff
- @supported_compute_capability([100, 103, 120])
+ @supported_compute_capability([100, 103, 120, 121])
```
110 is also supported if I remember correctly, cc: @ttyio
It was explicitly disabled on trtllm in the original checks. The other backends support it.
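For context on the snippet under review, here is a minimal sketch of how a capability-gating decorator like `supported_compute_capability` could work. This is an assumption-laden illustration, not flashinfer's actual implementation; the decorated function name is hypothetical.

```python
import functools

import torch

def supported_compute_capability(supported):
    """Illustrative decorator: raise if the current GPU's SM version
    (major * 10 + minor, e.g. (12, 1) -> 121) is not in `supported`."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            major, minor = torch.cuda.get_device_capability()
            cc = major * 10 + minor
            if cc not in supported:
                raise RuntimeError(
                    f"{fn.__name__} requires SM in {supported}, got SM{cc}"
                )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

# Usage mirroring the reviewed diff (hypothetical backend entry point):
@supported_compute_capability([100, 103, 120, 121])
def mm_fp4_cutlass(a, b):
    ...
```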
[CANCELING] Pipeline #37609405: canceled
/bot run
## 📌 Description

In flashinfer-ai#1809 we previously added a compute-capability-based support check for `mm_fp4`. However, we missed enabling SM121 for backend = `cudnn` and `cutlass`. Additionally, we marked `trtllm` as supported on SM120 when it is not. The current PR fixes this.

Example benchmark and pytest runs on SM121 after the fix:

```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
  warnings.warn(
[PERF] cudnn   :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
=============================== test session starts ===============================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
================================ warnings summary =================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================== 450 passed, 2790 skipped, 1 warning in 8.24s ===================
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Expanded hardware compatibility by adding support for newer NVIDIA GPU architectures.
  * FP4 quantized operations now available across multiple backends on supported devices.
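As a usage note, the 2790 skips in the pytest run above presumably come from parameterizations gated off on this device. A hedged sketch of how a test suite could express the support matrix this PR establishes; the `MM_FP4_SUPPORT` dict, `_sm_version` helper, and test name are hypothetical, and trtllm's pre-SM120 set is assumed from the description rather than read from flashinfer source.

```python
import pytest
import torch

def _sm_version() -> int:
    # Encode (major, minor) as major * 10 + minor, e.g. GB10's (12, 1) -> 121.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

# Assumed mm_fp4 support matrix, per the PR description and review thread.
MM_FP4_SUPPORT = {
    "cudnn": {100, 103, 120, 121},
    "cutlass": {100, 103, 120, 121},
    "trtllm": {100, 103},  # marked unsupported on SM120/SM121 by this PR
}

@pytest.mark.parametrize("backend", sorted(MM_FP4_SUPPORT))
def test_mm_fp4_backend_gating(backend):
    sm = _sm_version()
    if sm not in MM_FP4_SUPPORT[backend]:
        pytest.skip(f"{backend} mm_fp4 is unsupported on SM{sm}")
    # The real test would invoke mm_fp4 with this backend here.
    ...
```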