feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn'#1979
yzh119 merged 19 commits into flashinfer-ai:main
Conversation
Walkthrough
Replaces static mm_fp4 backend listings with runtime support checks and an "auto" backend selector; adds cuDNN/CUTLASS FP4 runner factories, tactic-aware graph execution, and runtime backend validation/pruning in benchmarks and the CLI.
Sequence Diagram(s)
```mermaid
sequenceDiagram
participant User
participant MM as mm_fp4(auto)
participant Heu as _heuristic_func_mm_fp4
participant RunnerC as cudnn_runner
participant RunnerU as cutlass_runner
participant Bench as Benchmark/Test
User->>MM: call mm_fp4(..., backend="auto")
MM->>Heu: evaluate shapes/CUDA/cuDNN to rank candidates
Heu-->>MM: ordered backend candidates
loop try candidates (lazy init)
MM->>RunnerC: init & capability trial
RunnerC-->>MM: success / fail
MM->>RunnerU: init & capability trial
RunnerU-->>MM: success / fail
end
alt some backends failed
MM->>MM: prune unsupported backends
end
MM->>MM: autotune & warmup across remaining backends
Bench->>MM: run cross-backend validation (cosine >= 0.97)
MM-->>User: return execution result from chosen backend
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
bc94c4c to 254827a
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/gemm.py (1)
2096-2134: Consider extracting auto-backend selection into a helper function.
The auto-backend selection logic (lines 2096-2134) is complex and involves:
- CUDA/cuDNN version inspection
- Backend ordering heuristics
- Problem size validation
- Exception handling for unsupported configurations
This logic could benefit from extraction into a dedicated helper function (e.g., `_select_mm_fp4_backends`) to improve readability and testability.
Additionally, the bare `except Exception` at lines 2131-2132 might hide unexpected errors. Consider either:
- Catching more specific exceptions (e.g., `ValueError`, `RuntimeError`)
- Adding logging to track which backends fail validation and why
Example refactoring:
```python
def _select_mm_fp4_backends(
    cuda_major: int,
    cudnn_version: int,
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    out: torch.Tensor,
    block_size: int,
    use_8x4_sf_layout: bool,
    use_nvfp4: bool,
) -> List[str]:
    """Select supported backends for mm_fp4 based on device capabilities."""
    # Backend ordering heuristics
    if cuda_major >= 13 and cudnn_version >= 91400:
        candidate_backends = ("cudnn", "cutlass")
    else:
        candidate_backends = ("cutlass", "cudnn")

    # Filter by problem size support
    backends = []
    for candidate in candidate_backends:
        try:
            _check_mm_fp4_problem_size(
                a,
                b,
                a_descale,
                b_descale,
                alpha,
                out_dtype,
                out,
                block_size,
                use_8x4_sf_layout,
                cast(Literal["cudnn", "trtllm", "cutlass", "auto"], candidate),
                use_nvfp4,
            )
            backends.append(candidate)
        except (ValueError, RuntimeError):
            pass  # Backend not supported for this problem
    return backends
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between b9287c9 and 254827a5c54c10d69705743d5068e9acd7299776.
📒 Files selected for processing (4)
benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
benchmarks/routines/gemm.py (5 hunks)
flashinfer/gemm.py (17 hunks)
tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (4)
flashinfer/jit/cpp_ext.py (1)
get_cuda_version (64-83)
flashinfer/autotuner.py (9)
TunableRunner (194-247), get_valid_tactics (196-214), OptimizationProfile (168-183), forward (220-244), AutoTuner (335-784), get (362-365), TuningConfig (101-141), choose_one (400-529), get_opt_shapes (177-183)
flashinfer/trtllm_low_latency_gemm.py (2)
get_valid_tactics (52-77), forward (79-109)
flashinfer/utils.py (4)
supported_compute_capability (772-852), get_compute_capability (251-254), is_compute_capability_supported (966-972), backend_requirement (855-1028)
🪛 Ruff (0.14.2)
flashinfer/gemm.py
96-96: Unused function argument: device
(ARG001)
432-432: Unused method argument: inputs
(ARG002)
433-433: Unused method argument: profile
(ARG002)
441-441: Unused method argument: do_preparation
(ARG002)
442-442: Unused method argument: kwargs
(ARG002)
1722-1722: Unused method argument: profile
(ARG002)
1733-1733: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1736-1736: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1772-1772: Unused method argument: do_preparation
(ARG002)
1773-1773: Unused method argument: kwargs
(ARG002)
1855-1855: Avoid specifying long messages outside the exception class
(TRY003)
1876-1876: Unused function argument: backend
(ARG001)
1934-1934: Unused function argument: backend
(ARG001)
1956-1956: Unused function argument: backend
(ARG001)
1957-1957: Unused function argument: use_nvfp4
(ARG001)
1965-1965: Unused function argument: b
(ARG001)
1966-1966: Unused function argument: a_descale
(ARG001)
1967-1967: Unused function argument: b_descale
(ARG001)
1968-1968: Unused function argument: alpha
(ARG001)
1969-1969: Unused function argument: out_dtype
(ARG001)
1970-1970: Unused function argument: out
(ARG001)
1971-1971: Unused function argument: block_size
(ARG001)
1972-1972: Unused function argument: use_8x4_sf_layout
(ARG001)
1973-1973: Unused function argument: backend
(ARG001)
1974-1974: Unused function argument: use_nvfp4
(ARG001)
2099-2099: Unpacked variable cc_major is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2099-2099: Unpacked variable cc_minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2131-2132: try-except-pass detected, consider logging the exception
(S110)
2131-2131: Do not catch blind exception: Exception
(BLE001)
2163-2163: Avoid specifying long messages outside the exception class
(TRY003)
2509-2509: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2510-2510: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2511-2511: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2513-2513: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2516-2516: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2530-2530: Unused method argument: do_preparation
(ARG002)
2531-2531: Unused method argument: kwargs
(ARG002)
🔇 Additional comments (12)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
241-243: LGTM! Auto backend addition is correct.
The addition of "auto" to the supported backends list for `mm_fp4` at compute capabilities 10.0, 10.3, and 12.0 is consistent with the PR objectives and aligns with the auto-backend implementation in `flashinfer/gemm.py`.
benchmarks/routines/gemm.py (2)
134-134: LGTM! Backend choices updated correctly.
The addition of "auto" to the `--backends` argument choices is consistent with the auto-backend support introduced in this PR.
793-793: LGTM! Auto backend support properly integrated.
The changes correctly:
- Add "auto" to the list of autotune-supported backends for `mm_fp4`
- Implement backend filtering logic for "auto" that respects the `use_128x4_sf_layout` constraint
- Include "auto" in the `run_backend` execution path
The filtering logic at lines 836-842 appropriately mirrors the filtering done for other backends (cudnn, cutlass) and ensures "auto" is removed when layout constraints aren't met.
Also applies to: 836-842, 899-899
flashinfer/gemm.py (7)
425-465: LGTM! Runner refactoring improves consistency.
The refactoring of CUTLASS FP4 GEMM into `cutlass_fp4_gemm_runner` with the helper function `_create_cutlass_fp4_gemm_module` improves naming consistency and aligns with the pattern used for other runners (e.g., `trtllm_fp4_gemm_runner`).
1270-1294: LGTM! cuDNN tactic support enables fine-grained autotuning.
The addition of a `tactic` parameter to `build_plans_cudnn_fp4_gemm_graph` and `execute_cudnn_gemm_fp4_graph` enables plan-specific execution for autotuning. The logic correctly:
- Builds a specific plan when `tactic != -1`
- Builds all plans when `tactic == -1` (fallback)
- Executes the selected plan or uses default execution
This aligns with the autotuning framework's expectations and follows the pattern established by other tunable runners.
Also applies to: 1306-1331
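As a rough illustration of that tactic-aware pattern (the graph object and its plan-building methods below are placeholders, not the exact cuDNN frontend API used by flashinfer):

```python
def build_plans_for_tactic(graph, tactic: int):
    # `tactic` is an execution-plan index chosen by the autotuner; -1 means
    # "no specific choice", matching the convention described above.
    if tactic != -1:
        graph.build_plan_at_index(tactic)  # build only the selected plan
    else:
        graph.build_plans()                # build all plans, use the default
    return graph
```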
1665-1802: LGTM! cuDNN FP4 runner properly implements TunableRunner interface.
The new `_cudnn_gemm_fp4` and `_cudnn_gemm_fp4_runner` functions correctly:
- Encapsulate cuDNN FP4 GEMM execution with tactic support
- Implement the `TunableRunner` interface with `get_valid_tactics` and `forward` methods
- Query available execution plans from the cuDNN graph
- Support tactic-specific execution for autotuning
The implementation follows the established pattern for tunable runners and integrates well with the autotuning framework.
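Schematically, a runner that plugs into this interface looks roughly like the following (class and helper names are invented for illustration; the real `_cudnn_gemm_fp4_runner` builds and executes a cuDNN graph):

```python
from typing import Any, List

import torch


class CudnnFp4RunnerSketch:
    """Toy stand-in that mirrors the TunableRunner-style interface."""

    def __init__(self, num_plans: int):
        self._num_plans = num_plans  # the real runner queries the cuDNN graph

    def get_valid_tactics(self, inputs: List[torch.Tensor], profile: Any) -> List[int]:
        # One tactic per available execution plan; -1 (default) is handled in forward().
        return list(range(self._num_plans))

    def forward(self, inputs: List[torch.Tensor], tactic: int = -1, **kwargs) -> torch.Tensor:
        # The real implementation executes the cuDNN graph with the chosen plan;
        # here we just return the first input to keep the sketch runnable.
        return inputs[0]
```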
1962-1997: LGTM! Auto backend requirement validation is well-implemented.
The `_auto_gemm_fp4_requirement` function correctly validates that the "auto" backend can be used by:
- Checking compute capability support for candidate backends (cudnn, cutlass)
- Explicitly excluding trtllm due to its different interface (as documented in the PR description)
- Returning True if at least one backend is supported
The implementation ensures that "auto" will only be accepted on devices where at least one compatible backend is available.
2136-2163: LGTM! Runner construction logic handles all backend cases correctly.
The runner construction for each backend (cudnn, trtllm, cutlass) correctly:
- Creates appropriate runner instances based on backend type
- Handles dtype conversions for cutlass backend (uint8 ↔ float8_e4m3fn)
- Dispatches to the correct module based on device architecture (SM120 vs SM100/SM103)
- Falls through to a clear error for unsupported backends
The logic is well-structured and handles all supported backend configurations.
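A sketch of the dispatch shape being described (the factory names are placeholders; the real constructors live in flashinfer/gemm.py):

```python
def build_runners_sketch(backend: str, compute_capability: int) -> list:
    # Placeholder factories standing in for the real runner constructors.
    if backend == "cudnn":
        return [make_cudnn_fp4_runner()]
    if backend == "trtllm":
        return [make_trtllm_fp4_runner()]
    if backend == "cutlass":
        # SM120 uses a different CUTLASS module than SM100/SM103.
        module = make_sm120_module() if compute_capability == 120 else make_sm100_module()
        return [make_cutlass_fp4_runner(module)]
    raise ValueError(f"Unsupported backend: {backend}")
```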
2165-2217: LGTM! Autotuning integration is well-structured.
The autotuning setup correctly:
- Defines dynamic tensor specs for batch size variation (power-of-2 bucketing)
- Sets constraint specs to maintain shape relationships
- Prepares input tensors in the expected format
- Uses `AutoTuner.choose_one` to select the best (runner, tactic) combination
- Executes the chosen runner with the selected tactic
The integration follows the established autotuning framework patterns and enables cross-backend tuning when `backend="auto"`.
2487-2563: LGTM! TRTLLM FP4 runner refactoring enables autotuning.
The refactoring of `trtllm_fp4_gemm_runner` to:
- Accept `use_8x4_sf_layout` as a parameter
- Implement the `TunableRunner` interface with tactic support
- Return a properly configured runner instance
This change aligns the TRTLLM backend with the autotuning framework and maintains consistency with other FP4 runners. The implementation correctly handles the `use_8x4_sf_layout` parameter throughout the runner lifecycle.
tests/gemm/test_mm_fp4.py (2)
15-95: LGTM! Test refactoring improves maintainability.
Extracting the test logic into `_test_mm_fp4` is a good refactoring that:
- Eliminates code duplication between test functions
- Makes the test logic reusable and easier to maintain
- Consolidates backend support checks and skip conditions
The updated skip condition at lines 34-35 correctly limits mxfp4 support to the cudnn and auto backends, which aligns with the implementation in `flashinfer/gemm.py`.
97-127: LGTM! Test split provides good coverage of auto backend.
The split between `test_mm_fp4` (non-auto backends) and `test_mm_fp4_backend_auto` (auto backend) is well-designed:
- `test_mm_fp4` maintains full parameter coverage for individual backends
- `test_mm_fp4_backend_auto` tests the auto backend with a reduced but representative parameter space
- The reduced parameter space (fewer m/n/k combinations, only `use_128x4_sf_layout=True`) is appropriate for auto backend testing and helps keep test execution time reasonable
This approach provides comprehensive coverage while avoiding combinatorial explosion of test cases.
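A hypothetical shape of that split (parameter values and the `_test_mm_fp4` call here are illustrative, not copied from the test file):

```python
import pytest


@pytest.mark.parametrize("m,n,k", [(1024, 7168, 4608)])  # reduced grid for auto
@pytest.mark.parametrize("use_128x4_sf_layout", [True])  # only the 128x4 layout
@pytest.mark.parametrize("use_nvfp4", [True, False])
def test_mm_fp4_backend_auto(m, n, k, use_128x4_sf_layout, use_nvfp4):
    # Delegate to the shared helper with backend="auto".
    _test_mm_fp4(m, n, k, use_128x4_sf_layout, use_nvfp4, backend="auto")
```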
```diff
@@ -823,7 +823,7 @@ def testMmFp4(args):
     print(
         "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
     )
-    backends.remove("cutlass")
+    remove_cutlass = True
```
Another way to avoid these remove_backend_x bools is to call the related backend check (which should be annotated with the decorator), or have the decorator return a filtered list as I proposed. #2000 (comment)
Regardless of whether you stuff it into the decorator, this will be a pattern that recurs for all APIs, so we should think about encapsulating the "if backend and checks_dont_pass: filter_it_out" logic.
This is a good idea. I have removed these hard-coded checks entirely and have started using the checkers in the latest commit.
Now this is addressed with the latest decorator update #2029
```python
    )
# Auto-select the best backend
if backend == "auto":
    cuda_major, _ = get_cuda_version(a.device)
```
These checks should be part of the _auto_gemm_fp4_requirement check.
I think a cleaner way would be to move the generation of the list of candidate_backends into the @backend_requirement decorator, where the "auto" backend is treated specially. It already lists the required checks for each backend. An alternative is to create a separate decorator that composes and uses the backend checks of backend_requirement.
The danger here is that we may be repeating some checks, but not all of them.
When writing the code path for this PR, I noted that the following questions had to be answered at different times by the auto backend logic:
- Is there at least one runnable backend for the given input params -- for early error raising
- What are the runnable backends for the given input params -- to consider which backends to choose from
- In the current GPU/CUDA/cuDNN environment, what is the preferred ordering of backends -- for heuristics
The current implementation in the PR answers 1 in @backend_requirement and 2 & 3 in the body of mm_fp4, while you're suggesting putting 2 inside @backend_requirement. I agree that this helps us avoid repeating checks, but, as you raised, it will involve special treatment for the auto backend and a change to backend_requirement. We can discuss.
Now this is addressed with the latest decorator update #2029
| candidate_backends = ("cutlass", "cudnn") | ||
|
|
||
| # Filter to only supported backends for this compute capability | ||
| # Note: The requirement function already validated that at least one backend is supported |
So this is the dangerous part: at this point, we know 1 backend replied that its check is ok. But we are considering all backends. Maybe cudnn supports it but not trtllm or cutlass.
You are correct here. In the latest commit, I now check whether each backend is supported both in general and for the given inputs.
Now this is addressed with the latest decorator update #2029
```python
for candidate in candidate_backends:
    # mypy requires explicit type casting for the backend literal
    backend_literal = cast(
        Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
```
why is auto added back?
Auto is actually not being added here since the cast() is telling pre-commit tests that backend_literal will be one of ["cudnn", "trtllm", "cutlass", "auto"] while candidate_backends will never contain auto.
However, there is no need for auto to be there and I can see it being confusing, so I have removed it in the latest commit.
```python
    )
elif cur_backend == "cutlass":
    if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
        a_descale = a_descale.view(torch.uint8)
```
This seems like an implementation detail, and maybe needs to be moved to the cutlass runner itself, just like we do with the cudnn_runner.
Agreed; this also allows removing the if-then-else structure above. Updated in the latest commit.
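The agreed-upon change amounts to something like the following inside the runner (a sketch; the real CUTLASS runner takes more inputs and calls the actual kernel):

```python
import torch


class CutlassRunnerSketch:
    def forward(self, a: torch.Tensor, a_descale: torch.Tensor) -> torch.Tensor:
        # Keep the dtype reinterpretation as an internal detail of the runner:
        # the kernel expects raw uint8 scale factors, so reinterpret the FP8
        # storage in place (no copy) before launching.
        if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
            a_descale = a_descale.view(torch.uint8)
        return self._launch(a, a_descale)

    def _launch(self, a: torch.Tensor, a_descale: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("placeholder for the actual CUTLASS FP4 GEMM launch")
```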
254827a to c9f3d52
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 254827a5c54c10d69705743d5068e9acd7299776 and c9f3d52788926f9ca228f415faa12eaffbf87400.
📒 Files selected for processing (4)
benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
benchmarks/routines/gemm.py (4 hunks)
flashinfer/gemm.py (17 hunks)
tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/gemm.py (2)
flashinfer/gemm.py (1)
_check_mm_fp4_problem_size (1812-1870)
flashinfer/autotuner.py (1)
autotune(251-262)
flashinfer/gemm.py (2)
flashinfer/jit/cpp_ext.py (1)
get_cuda_version (64-83)
flashinfer/utils.py (4)
supported_compute_capability (772-852), get_compute_capability (251-254), is_compute_capability_supported (966-972), backend_requirement (855-1028)
🪛 Ruff (0.14.2)
benchmarks/routines/gemm.py
900-900: Do not catch blind exception: Exception
(BLE001)
flashinfer/gemm.py
96-96: Unused function argument: device
(ARG001)
436-436: Unused method argument: inputs
(ARG002)
437-437: Unused method argument: profile
(ARG002)
445-445: Unused method argument: do_preparation
(ARG002)
446-446: Unused method argument: kwargs
(ARG002)
1730-1730: Unused method argument: profile
(ARG002)
1741-1741: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1744-1744: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1780-1780: Unused method argument: do_preparation
(ARG002)
1781-1781: Unused method argument: kwargs
(ARG002)
1863-1863: Avoid specifying long messages outside the exception class
(TRY003)
1884-1884: Unused function argument: backend
(ARG001)
1942-1942: Unused function argument: backend
(ARG001)
1964-1964: Unused function argument: backend
(ARG001)
1965-1965: Unused function argument: use_nvfp4
(ARG001)
1973-1973: Unused function argument: b
(ARG001)
1974-1974: Unused function argument: a_descale
(ARG001)
1975-1975: Unused function argument: b_descale
(ARG001)
1976-1976: Unused function argument: alpha
(ARG001)
1977-1977: Unused function argument: out_dtype
(ARG001)
1978-1978: Unused function argument: out
(ARG001)
1979-1979: Unused function argument: block_size
(ARG001)
1980-1980: Unused function argument: use_8x4_sf_layout
(ARG001)
1981-1981: Unused function argument: backend
(ARG001)
1982-1982: Unused function argument: use_nvfp4
(ARG001)
2109-2109: Unpacked variable cc_major is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2109-2109: Unpacked variable cc_minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2153-2154: try-except-pass detected, consider logging the exception
(S110)
2153-2153: Do not catch blind exception: Exception
(BLE001)
2517-2517: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2518-2518: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2519-2519: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2521-2521: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2524-2524: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2538-2538: Unused method argument: do_preparation
(ARG002)
2539-2539: Unused method argument: kwargs
(ARG002)
fd4dfd6 to 1b9ffbf
/bot run
/bot stop
The GitLab CI pipeline #38303944 has been cancelled.
c19ffc4 to fe2070b
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/gemm/test_mm_fp4.py (1)
15-25: `backend='auto'` tests are always skipped due to `is_backend_supported` misuse.
`mm_fp4.is_backend_supported("auto", cc)` returns `False` because `"auto"` is not a real backend key in `backend_checks`. As a result, `_test_mm_fp4` skips all `test_mm_fp4_backend_auto` parameterizations, so the new auto-backend tests never actually run. You're getting "coverage" numbers without exercising the
backend="auto"path.A minimal fix is to avoid using
`is_backend_supported` for the synthetic `"auto"` backend and let the decorator logic drive error handling:
```diff
- compute_capability = get_compute_capability(torch.device(device="cuda"))
- compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
- if not mm_fp4.is_backend_supported(backend, compute_capability_number):
-     pytest.skip(
-         f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
-     )
+ compute_capability = get_compute_capability(torch.device(device="cuda"))
+ compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
+ # For concrete backends, pre-skip unsupported CCs. For backend='auto', rely on
+ # the decorator to raise `BackendSupportedError` if no candidate backend exists.
+ if backend != "auto" and not mm_fp4.is_backend_supported(
+     backend, compute_capability_number
+ ):
+     pytest.skip(
+         f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
+     )
```
Optionally, if you want an early skip for auto as well, you can instead check
`mm_fp4.is_compute_capability_supported(compute_capability_number)`.
This will ensure
`test_mm_fp4_backend_auto` actually exercises both the heuristic and autotune behavior.
Also applies to: 124-127
flashinfer/gemm/gemm_base.py (1)
1979-2015: Fix tensor index mismatch in TRTLLM FP4 autotuning.
The issue is confirmed. In
`TrtllmFp4GemmRunner.get_valid_tactics()` (lines 2553-2620), using `a_tensor_index = 1` and `b_tensor_index = 2` incorrectly maps to `b` and `a_descale` respectively, since the `mm_fp4` inputs list places `a` at index 0 and `b` at index 1.
Correct this by setting
`a_tensor_index = 0` and `b_tensor_index = 1`:
```diff
 def get_valid_tactics(
     self,
     inputs: List[torch.Tensor],
     profile: OptimizationProfile,
 ) -> List[int]:
-    a_tensor_index = 1
-    b_tensor_index = 2
+    a_tensor_index = 0
+    b_tensor_index = 1
+    opt_shapes = profile.get_opt_shapes()
-    a = profile.get_opt_shapes()[a_tensor_index]
-    b = profile.get_opt_shapes()[b_tensor_index]
+    a = opt_shapes[a_tensor_index]
+    b = opt_shapes[b_tensor_index]
     m = a[0]
     n = b[0]
     k = a[1] * 2
```
This ensures
`m, n, k` passed to `trtllm_gemm_tactics` match the actual problem dimensions during autotuning.
🧹 Nitpick comments (3)
tests/utils/test_decorators.py (1)
347-359: Heuristic test helper is correct; consider minor cleanups to avoid shadowing.
The
`_heuristic_func` implementation matches the new `backend_requirement` contract and exercises the auto-backend path correctly. One small readability nit is reusing `backend` as the loop variable, which shadows the function argument and slightly hurts clarity; you can also simplify the `candidate_backends`/`heuristic_backends` construction.
For example:
```diff
-def _heuristic_func(suitable_backends, x, backend):
-    candidate_backends = None
-    if x.shape[0] > 5:
-        candidate_backends = ["cudnn", "cutlass"]
-    else:
-        candidate_backends = ["cutlass", "cudnn"]
-
-    heuristic_backends = []
-    for backend in candidate_backends:
-        if backend in suitable_backends:
-            heuristic_backends.append(backend)
-    return heuristic_backends
+def _heuristic_func(suitable_backends, x, backend):
+    if x.shape[0] > 5:
+        candidate_backends = ["cudnn", "cutlass"]
+    else:
+        candidate_backends = ["cutlass", "cudnn"]
+
+    return [b for b in candidate_backends if b in suitable_backends]
```
This keeps behavior identical while making the helper a bit clearer.
Also applies to: 361-367
flashinfer/utils.py (1)
924-930: `heuristic_func` semantics look good; prefer explicit error over `assert` for robustness.
The extended docstring clearly defines
`heuristic_func`'s contract, and the change to always run it in `suitable_auto_backends` enforces the intended rule that any API exposing `backend="auto"` must supply a heuristic.
assert heuristic_func is not None, "Heuristic function must be provided"means this safety net disappears under
python -O, and callers would then silently fall back to unorderedsuitable_backends. It’s safer to raise an explicit error whenbackend="auto"is used without a heuristic, e.g.:- assert heuristic_func is not None, "Heuristic function must be provided" - suitable_backends = heuristic_func(suitable_backends, *args, **kwargs) + if heuristic_func is None: + raise RuntimeError( + f"backend='auto' requires a heuristic_func for {func.__name__}" + ) + suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)This keeps the behavior “loud” in all runtime modes while matching the intent discussed in earlier review threads.
Also applies to: 1086-1087
benchmarks/routines/gemm.py (1)
129-135: Runtime backend probing for mm_fp4 is sensible; narrow the catch to avoid hiding real bugs.Letting
testMmFp4discover supported backends by actually callingflashinfer.gemm.mm_fp4is a nice improvement over the old static CC map, and adding"auto"to--backendsplusautotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]lines up with the new auto-backend behavior.The main concern is this block:
for backend in backends: ... try: flashinfer.gemm.mm_fp4(...) except Exception as e: print( f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" ) backends_to_remove.append(backend)Catching bare
Exceptionmeans any failure in the runner (logic bug, memory issue, etc.) is treated as “backend unsupported” and the benchmark keeps going, which can silently mask regressions.Given the decorator and kernels raise well-typed errors on unsupported configs, you can be more precise here, e.g.:
- try: - flashinfer.gemm.mm_fp4(...) - except Exception as e: + try: + flashinfer.gemm.mm_fp4(...) + except (LibraryError, BackendSupportedError, ValueError) as e: print( f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" ) backends_to_remove.append(backend) + except Exception as e: + # Treat unexpected failures as real errors so they don't get hidden. + raiseThis keeps the probing-based filtering but makes genuine bugs in mm_fp4 visible during benchmarking instead of being silently filtered away.
Also applies to: 793-799, 843-884, 907-917
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between c19ffc4b32ad2dca37ab97bab0199f59342b2a16 and fe2070b.
📒 Files selected for processing (6)
benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
benchmarks/routines/gemm.py (5 hunks)
flashinfer/gemm/gemm_base.py (18 hunks)
flashinfer/utils.py (2 hunks)
tests/gemm/test_mm_fp4.py (3 hunks)
tests/utils/test_decorators.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
benchmarks/routines/gemm.py, flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (3)
benchmarks/routines/gemm.py (2)
flashinfer/gemm/gemm_base.py (1)
mm_fp4 (2027-2186)
flashinfer/autotuner.py (1)
autotune(251-262)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (1)
backend_requirement(897-1179)
flashinfer/gemm/gemm_base.py (2)
flashinfer/autotuner.py (7)
TunableRunner (194-247), OptimizationProfile (168-183), AutoTuner (335-786), TuningConfig (101-141), DynamicTensorSpec (41-82), ConstraintSpec (86-97), choose_one (400-529)
flashinfer/utils.py (3)
backend_requirement (897-1179), suitable_auto_backends (1071-1091), get_compute_capability (253-256)
🪛 Ruff (0.14.5)
benchmarks/routines/gemm.py
871-871: Do not catch blind exception: Exception
(BLE001)
flashinfer/gemm/gemm_base.py
417-417: Unused method argument: inputs
(ARG002)
418-418: Unused method argument: profile
(ARG002)
426-426: Unused method argument: do_preparation
(ARG002)
427-427: Unused method argument: kwargs
(ARG002)
483-483: Avoid specifying long messages outside the exception class
(TRY003)
1671-1671: Unused function argument: out
(ARG001)
1748-1748: Unused method argument: profile
(ARG002)
1761-1761: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1785-1785: Unused method argument: do_preparation
(ARG002)
1786-1786: Unused method argument: kwargs
(ARG002)
1824-1824: Unused function argument: out
(ARG001)
1826-1826: Unused function argument: use_8x4_sf_layout
(ARG001)
1827-1827: Unused function argument: backend
(ARG001)
1881-1881: Unused function argument: out
(ARG001)
1884-1884: Unused function argument: backend
(ARG001)
1888-1888: Avoid specifying long messages outside the exception class
(TRY003)
1936-1936: Unused function argument: a
(ARG001)
1937-1937: Unused function argument: b
(ARG001)
1938-1938: Unused function argument: a_descale
(ARG001)
1939-1939: Unused function argument: b_descale
(ARG001)
1940-1940: Unused function argument: alpha
(ARG001)
1942-1942: Unused function argument: out
(ARG001)
1943-1943: Unused function argument: block_size
(ARG001)
1944-1944: Unused function argument: use_8x4_sf_layout
(ARG001)
1945-1945: Unused function argument: backend
(ARG001)
1949-1949: Avoid specifying long messages outside the exception class
(TRY003)
1960-1960: Unused function argument: a
(ARG001)
1961-1961: Unused function argument: b
(ARG001)
1962-1962: Unused function argument: a_descale
(ARG001)
1963-1963: Unused function argument: b_descale
(ARG001)
1964-1964: Unused function argument: alpha
(ARG001)
1965-1965: Unused function argument: out_dtype
(ARG001)
1966-1966: Unused function argument: out
(ARG001)
1967-1967: Unused function argument: block_size
(ARG001)
1969-1969: Unused function argument: backend
(ARG001)
1973-1973: Avoid specifying long messages outside the exception class
(TRY003)
1975-1975: Avoid specifying long messages outside the exception class
(TRY003)
1990-1990: Unused function argument: backend
(ARG001)
1991-1991: Unused function argument: use_nvfp4
(ARG001)
2569-2569: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2570-2570: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2571-2571: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2573-2573: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2576-2576: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2590-2590: Unused method argument: do_preparation
(ARG002)
2591-2591: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
238-239: Comment correctly reflects mm_fp4's new dynamic support handling.
Not listing `mm_fp4` in `routine_cc_to_supported_backends` and documenting that it relies on runtime support checkers aligns with the new decorator-based validation in `mm_fp4`. This avoids stale hard-coded backend lists in the benchmark helper.
flashinfer/gemm/gemm_base.py (1)
410-447: The original review comment incorrectly states that `a_descale.view(torch.uint8)` will raise a `TypeError` at runtime. In current PyTorch, `torch.Tensor.view` does support a dtype overload that accepts a `torch.dtype` argument (e.g., `x.view(torch.uint8)`), which reinterprets the underlying data with the given dtype without copying. This is the correct idiomatic approach for the FP8-to-uint8 reinterpretation shown in `CutlassFp4GemmRunner.forward`.
The code is valid as written. No changes needed on this section.
Likely an incorrect or invalid review comment.
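A tiny standalone check of the dtype-view behavior described above (standard PyTorch, independent of flashinfer):

```python
import torch

scales = torch.randn(8).to(torch.float8_e4m3fn)  # 1 byte per element
raw = scales.view(torch.uint8)                   # reinterpret, no copy

assert raw.dtype == torch.uint8
assert raw.shape == scales.shape                 # same element count: both dtypes are 1 byte
assert raw.data_ptr() == scales.data_ptr()       # shares the original storage
```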
…udnn' (flashinfer-ai#1979) <!-- .github/pull_request_template.md --> Current PR: * Introduces an `auto` backend to `mm_fp4` that can be autotuned. **It replaces `cudnn` as the default.** * Implementation matches `bmm_fp8`'s auto backend support. * Allows `cudnn` backend to be autotuned. * Added unit test test cases for backend=auto Behavior of `auto` backend: * Examines CUDA version & cuDNN version and calls either `cutlass` or `cudnn` kernel backends. `trtllm` kernel is not considered due to a non-interchangeable interface with other backends. * `auto` backend therefore only supports inputs runnable by `cutlass` and/or `cudnn. * Non-autotuned behavior: * Constructs an ordered list of backends (cudnn, cutlass) or (cutlass, cudnn) where ordering is based on previous microbenchmark study results. * If CUDA 12 --> cutlass comes to front. * If CUDA 13 and cuDNN version < 9.15 --> cutlass comes front * If CUDA 13 and cuDNN version >= 9.15 --> cudnn comes front * If kernel is not available from a support check, it is removed from the list. * Autotune behavior: * If backend is explicitly provided --> Autotunes within the backend. Same as previous behavior, but now autotuning is supported for cudnn. * If `backend='auto'` --> Autotunes within and across backends (cudnn & cutlass) and chooses the best config of best backend. `trtllm` kernel is not considered * A lot of helper functions to `mm_fp4` were refactored to enable cross-backend autotuning. Refactoring was done to match cross-backend autotune-enabled `bmm_fp8` as a reference. `pytest tests/gemm/test_mm_fp4.py` * SM100 (B200) CUDA 13 & cuDNN 9.15: `900 passed, 2532 skipped in 125.19s (0:02:05)` * SM100 (B200) CUDA 12 & cuDNN 9.15: `900 passed, 2532 skipped in 125.67s (0:02:05)` * SM120 (RTX 5090) CUDA 13 & cuDNN 9.15: `720 passed, 2712 skipped in 76.50s (0:01:16)` On SM100 (B200) CUDA 13 & cuDNN 9.15 ``` flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck [PERF] cudnn :: median time 0.018 ms; std 0.000 ms; achieved tflops 3797.932 TFLOPs/sec; achieved tb_per_sec 1.884 TB/sec [PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3440.640 TFLOPs/sec; achieved tb_per_sec 1.707 TB/sec [PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec [PERF] auto :: median time 0.018 ms; std 0.000 ms; achieved tflops 3840.714 TFLOPs/sec; achieved tb_per_sec 1.905 TB/sec /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. 
[PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec [PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3237.753 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec median time 0.009 ms; std 0.000 ms; achieved tflops 938.356 TFLOPs/sec; achieved tb_per_sec 2.069 TB/sec /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune 2025-11-11 23:43:23,715 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:25,789 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:25,790 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:26,251 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:26,251 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:26,327 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:26,327 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:26,335 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4129.171 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec [PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3513.845 TFLOPs/sec; achieved tb_per_sec 1.743 TB/sec [PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2613.338 TFLOPs/sec; achieved tb_per_sec 1.296 TB/sec [PERF] auto_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4128.768 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. 2025-11-11 23:43:37,942 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:43:43,116 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:43:43,116 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 
2025-11-11 23:43:43,124 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.154 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec [PERF] auto_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.692 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec ``` On SM100 (B200) CUDA 12 & cuDNN 9.15 ``` flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck [PERF] cudnn :: median time 0.023 ms; std 0.001 ms; achieved tflops 2975.898 TFLOPs/sec; achieved tb_per_sec 1.476 TB/sec [PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.423 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec [PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec [PERF] auto :: median time 0.020 ms; std 0.000 ms; achieved tflops 3371.229 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec (py312) root@84ef83abb1b5:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec [PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec /flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune 2025-11-11 23:42:43,378 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,451 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:45,451 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,910 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:45,910 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:45,986 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:45,986 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 
2025-11-11 23:42:45,993 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3190.355 TFLOPs/sec; achieved tb_per_sec 1.583 TB/sec [PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.330 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec [PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2621.440 TFLOPs/sec; achieved tb_per_sec 1.300 TB/sec [PERF] auto_autotune :: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.628 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. 2025-11-11 23:42:55,176 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:58,600 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends 2025-11-11 23:42:58,601 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ... 2025-11-11 23:42:58,608 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends [PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec [PERF] auto_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec ``` On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15 ``` /flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck [INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120 [PERF] cudnn :: median time 0.058 ms; std 0.000 ms; achieved tflops 1167.143 TFLOPs/sec; achieved tb_per_sec 0.579 TB/sec [PERF] cutlass :: median time 0.060 ms; std 0.000 ms; achieved tflops 1135.056 TFLOPs/sec; achieved tb_per_sec 0.563 TB/sec [PERF] auto :: median time 0.058 ms; std 0.000 ms; achieved tflops 1158.952 TFLOPs/sec; achieved tb_per_sec 0.575 TB/sec /flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck [INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization. [INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120 [PERF] cudnn :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec [PERF] auto :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec ``` <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> <!-- Link any related issues here --> Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. 
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> * **New Features** * "auto" backend selection for FP4 ops to choose backend at runtime * cuDNN, CUTLASS and TRTLLM selectable as FP4 GEMM backends * CUDA/cuDNN version awareness to guide auto-backend heuristics * **Improvements** * Runtime capability checks replace static backend lists; unsupported backends are removed dynamically * Heuristic-driven auto-backend selection required for automatic mode * Expanded autotuning/warmup across backends and relaxed FP4 validation tolerance * **Tests** * Tests updated and added to exercise auto-backend scenarios and relaxed constraints <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
📌 Description
Current PR:
- Introduces an `auto` backend to `mm_fp4` that can be autotuned. It replaces `cudnn` as the default.
- Implementation matches `bmm_fp8`'s auto backend support.
- Allows the `cudnn` backend to be autotuned.
- Added unit test cases for backend=auto.

Behavior of the `auto` backend:
- Examines the CUDA and cuDNN versions and calls either the `cutlass` or `cudnn` kernel backend. The `trtllm` kernel is not considered due to a non-interchangeable interface with the other backends.
- The `auto` backend therefore only supports inputs runnable by `cutlass` and/or `cudnn`.
- If `backend='auto'` with autotuning enabled, it autotunes within and across backends (cudnn & cutlass) and chooses the best config of the best backend; the `trtllm` kernel is not considered.
- Many helper functions of `mm_fp4` were refactored to enable cross-backend autotuning, using the cross-backend autotune-enabled `bmm_fp8` as a reference.

Pytest outputs (`pytest tests/gemm/test_mm_fp4.py`):
- SM100 (B200) CUDA 13 & cuDNN 9.15: 900 passed, 2532 skipped in 125.19s (0:02:05)
- SM100 (B200) CUDA 12 & cuDNN 9.15: 900 passed, 2532 skipped in 125.67s (0:02:05)
- SM120 (RTX 5090) CUDA 13 & cuDNN 9.15: 720 passed, 2712 skipped in 76.50s (0:01:16)

Example microbenchmark outputs:
On SM100 (B200) CUDA 13 & cuDNN 9.15
On SM100 (B200) CUDA 12 & cuDNN 9.15
On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15
🔍 Related Issues
#1722
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests
✏️ Tip: You can customize this high-level summary in your review settings.