Support checks PoC #1809
Conversation
The checks currently live very far away from the implementation, and keeping them consistent with each other can eventually become a maintenance problem. The conditional checks are also quite tricky to get right. For example, it's not easy to tell whether the mxfp4 checks are correct:

```python
if not use_nvfp4 and block_size != 32:
    raise ValueError("mxfp4 supports block_size = 32.")
if backend != "cudnn" and not use_nvfp4:
    raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
```

Shouldn't the checks be reordered to avoid confusing error messages?
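For illustration, one possible reordering — a sketch only, not code from the PR, and the helper name is hypothetical:

```python
# Sketch only (hypothetical helper): the backend is checked before block_size,
# so an mxfp4 caller on a non-cudnn backend gets the relevant error first
# instead of a block_size complaint.
def _check_mxfp4_args(backend: str, use_nvfp4: bool, block_size: int) -> None:
    if use_nvfp4:
        return  # nvfp4 arguments are validated elsewhere
    if backend != "cudnn":
        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
    if block_size != 32:
        raise ValueError("mxfp4 supports block_size = 32.")
```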
Instead of having one top-level check, the requirements could be attached to each backend.
For example:

```python
def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )
    _check_cudnn_fp4_availability()
    # ...

@requirement(cudnn_gemm_fp4_requirement, capability=["100", "101", "102"])
def execute_cudnn_gemm_fp4_graph(
    # ...
):
    ...

@backend_requirement({
    "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
    "trtllm": ...,  # ...
})
def mm_fp4(
    # ...
):
    ...
```

This also means that all requirements are enforced to be local to their backend and won't affect each other.
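For illustration only, a hypothetical sketch of what a `requirement` decorator along these lines could look like; the name and the exposed `.requirement` attribute follow the snippet above, everything else is an assumption rather than actual FlashInfer code:

```python
import functools

def requirement(check_fn, capability=None):
    """Hypothetical sketch: attach a requirement check to a backend kernel.

    `check_fn` raises if the arguments or hardware are unsupported; `capability`
    is accepted here but not enforced in this sketch.
    """
    def wrap(kernel):
        @functools.wraps(kernel)
        def inner(*args, **kwargs):
            check_fn(*args, **kwargs)  # raises on unsupported configurations
            return kernel(*args, **kwargs)
        inner.requirement = check_fn   # exposed so @backend_requirement can reuse it
        return inner
    return wrap
```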
Thank you @nvjullin for the excellent suggestion.
While both are valid points, we prioritize separating the checks for now. Not all APIs are as clean to separate (2) at the moment, and there is a plan for a more OO Backend class @Anerudhan. That does overlap somewhat with the support checks, as eventually we would be able to do something like `cudnn_backend->check_mmfp4_support()`. So I think, as an intermediate step and to get tighter checks in, we could do something like this:

```python
@supported_compute_capability(["100", "101", "102"])
def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )
    _check_cudnn_fp4_availability()
    # ...

@backend_requirement(
    {
        "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
        "trtllm": ...,  # ...
    },
    common_check=common_fp4_checks,  # to be called by all backend checks
)
def mm_fp4(
    # ...
):
    if backend == "cudnn":
        # cudnn path
        ...
    elif backend == "trtllm":
        # trtllm path
        ...
```
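Likewise illustrative only: a hypothetical sketch of how `backend_requirement` with a `common_check` and `supported_compute_capability` might fit together; the decorator names come from the comment above, and the implementations are assumptions, not FlashInfer's actual code:

```python
import functools

def supported_compute_capability(capabilities):
    """Hypothetical sketch: tag a requirement function with the SM versions it covers."""
    def wrap(check_fn):
        check_fn.supported_capabilities = list(capabilities)
        return check_fn
    return wrap

def backend_requirement(backend_checks, common_check=None):
    """Hypothetical sketch: run the shared check, then the selected backend's check."""
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, backend="cudnn", **kwargs):
            if common_check is not None:
                common_check(*args, **kwargs)        # checks shared by all backends
            backend_checks[backend](*args, **kwargs)  # backend-specific check
            return fn(*args, backend=backend, **kwargs)
        return inner
    return wrap
```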
I wasn't aware, thanks for the info. LGTM.
Thank you for the excellent suggestions, @nvjullin |
Force-pushed 3c9f687 to 151fc7e
/bot run
aleozlx left a comment:
looks like a good step forward
[SUCCESS] Pipeline #36524696: 13/17 passed
sricketts left a comment:
Overall LGTM. Added one suggestion.
From the description of a follow-up PR that references #1809:

## 📌 Description

In #1809 we previously added a compute-capability-based support check for `mm_fp4`. However, we missed enabling SM121 for backend = `cudnn` and `cutlass`. Additionally, we marked `trtllm` as supported on SM120 when it is not. Current PR fixes it.

Example benchmark and pytest command on SM121 after the fix:

```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
  warnings.warn(
[PERF] cudnn   :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
=============================== test session starts ===============================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
================================ warnings summary =================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 450 passed, 2790 skipped, 1 warning in 8.24s ===================
```

## Summary by CodeRabbit

* **New Features**
  * Expanded hardware compatibility by adding support for newer NVIDIA GPU architectures.
  * FP4 quantized operations now available across multiple backends on supported devices.
## 📌 Description

This PR adds `is_*supported` checks for backend and compute capability, through decorators.
Example:
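A minimal, hypothetical usage sketch (not the PR's actual example), assuming the decorator sketches from the review discussion above are in scope; the capability list and check bodies reuse values from that discussion:

```python
# Hypothetical usage sketch: a per-backend requirement check attached to
# mm_fp4 via the decorators discussed above.

@supported_compute_capability(["100", "101", "102"])
def cudnn_gemm_fp4_requirement(a, b, block_size=32, use_nvfp4=True, **kwargs):
    # Raise if the requested quantization configuration isn't supported.
    if not use_nvfp4 and block_size != 32:
        raise ValueError("mxfp4 supports block_size = 32.")

@backend_requirement({
    "cudnn": cudnn_gemm_fp4_requirement,
    # "trtllm": trtllm_gemm_fp4_requirement,  # hypothetical second backend check
})
def mm_fp4(a, b, *, block_size=32, use_nvfp4=True, backend="cudnn"):
    ...  # dispatch to the selected backend's kernel

# A caller could then query support before dispatching via an
# is_*_supported-style helper (exact name not confirmed here).
```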

## 🔍 Related Issues

## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
### ✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

## Reviewer Notes