feat: suitable_auto_backends to prune auto backends, bmm_fp8 refactor, heuristic_func intake (#2029)
Conversation
Walkthrough

Per-backend FP8 BMM requirement checks were added, and the auto-backend selection flow was reworked.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User as User Code
    participant Decorator as backend_requirement wrapper
    participant AutoSel as suitable_auto_backends discovery
    participant Backend as Backend checks (cuDNN / cuBLAS / CUTLASS)
    participant Impl as decorated function
    User->>Decorator: Call with backend="auto"
    Decorator->>AutoSel: Enumerate candidate backends
    loop per backend
        AutoSel->>Backend: Run backend-specific requirement
        alt passes
            Backend-->>AutoSel: Suitable
        else fails
            Backend-->>AutoSel: Not suitable
        end
    end
    alt any suitable
        AutoSel-->>Decorator: List of suitable backends
        Decorator->>Impl: Invoke using chosen backend (heuristic/order)
        Impl-->>Decorator: Result
        Decorator-->>User: Return result
    else none suitable
        Decorator-->>User: Raise BackendSupportedError
    end
```
```mermaid
sequenceDiagram
    actor User as User Code
    participant Decorator as backend_requirement wrapper
    participant Requirement as Backend-specific check
    participant Impl as decorated function
    User->>Decorator: Call with backend="cublas"
    Decorator->>Requirement: _cublas_bmm_fp8_requirement + _check_bmm_fp8_problem_size
    alt capability & problem size valid
        Requirement-->>Decorator: OK
        Decorator->>Impl: Execute on cublas
        Impl-->>Decorator: Result
        Decorator-->>User: Return result
    else invalid
        Requirement-->>Decorator: Error
        Decorator-->>User: Raise error
    end
```
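The auto-selection flow in the first diagram can be sketched as a plain-Python decorator. Everything below is illustrative: the names mirror the diagram, but the bodies are a toy stand-in, not the actual flashinfer implementation.

```python
import functools


class BackendSupportedError(RuntimeError):
    """Raised when no backend can serve the requested call."""


def backend_requirement(backend_checks):
    """Toy sketch of the wrapper in the diagram: enumerate candidates for
    backend="auto", or validate a single explicitly requested backend."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, backend="auto", **kwargs):
            if backend == "auto":
                # Run each backend-specific requirement; keep the ones that pass.
                suitable = [name for name, check in backend_checks.items()
                            if check(*args, **kwargs)]
                if not suitable:
                    raise BackendSupportedError(
                        f"No suitable auto backends found for {func.__name__}")
                backend = suitable[0]  # first match stands in for heuristic order
            elif not backend_checks[backend](*args, **kwargs):
                raise BackendSupportedError(
                    f"{func.__name__} does not support backend {backend!r}")
            return func(*args, backend=backend, **kwargs)
        return wrapper
    return decorator


# Hypothetical checks: "cublas" only accepts sizes divisible by 8.
@backend_requirement({"cublas": lambda n: n % 8 == 0, "cudnn": lambda n: True})
def bmm_like(n, backend="auto"):
    return backend
```

With these hypothetical checks, `bmm_like(16)` auto-selects `"cublas"`, `bmm_like(3)` falls through to `"cudnn"`, and an explicit `backend="cublas"` with an incompatible size raises `BackendSupportedError`.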
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
- ❌ Failed checks (2 warnings)
- ✅ Passed checks (1 passed)
Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the backend selection process for floating-point 8 (FP8) matrix multiplication operations.

Highlights
Code Review
This pull request introduces a significant and well-structured refactoring for backend selection and validation, particularly for the auto backend setting. The new @backend_requirement decorator is a good improvement. However, I've identified a critical issue in the new suitable_auto_backends function where it fails to check for GPU compute capability and is not robust against exceptions during requirement checks. I have provided detailed comments and suggestions to address this. Additionally, I've pointed out the removal of a crucial safety check in fp8_gemm_sm100 which could lead to runtime errors. The new tests are a great addition and help validate the new functionality.
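A minimal sketch of the hardening this review asks for: gate each candidate on a caller-supplied compute-capability table before running its check, and treat a check that raises as simply unsuitable. The table, capability numbers, and function name below are hypothetical, not flashinfer's actual data or API.

```python
def suitable_auto_backends(backend_checks, min_capability, capability, *args, **kwargs):
    """Return backends that both support `capability` and pass their
    requirement check; a check that raises disqualifies only that backend."""
    suitable = []
    for name, check in backend_checks.items():
        # Capability gate first: skip backends this device cannot run at all.
        if capability < min_capability.get(name, 0):
            continue
        try:
            if check(*args, **kwargs):
                suitable.append(name)
        except Exception:
            # Be robust: a buggy or raising check must not abort the search.
            continue
    return suitable
```

With a hypothetical `min_capability = {"cutlass": 100, "cublas": 89}`, an SM90 device would never see `cutlass` in the result even when its shape check would have passed, and a check that divides by zero is silently skipped rather than crashing the selection.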
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between da01b1b and 6c1079eacf0ff8b8c1a221c05ea4bb6eb497af8f.
📒 Files selected for processing (3)
- flashinfer/gemm.py (2 hunks)
- flashinfer/utils.py (2 hunks)
- tests/utils/test_decorators.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (4)
- supported_compute_capability (772-852)
- backend_requirement (855-1044)
- suitable_auto_backends (986-995)
- BackendSupportedError (63-66)

flashinfer/gemm.py (2)
flashinfer/utils.py (3)
- supported_compute_capability (772-852)
- backend_requirement (855-1044)
- suitable_auto_backends (986-995)

csrc/bmm_fp8.cu (2)
- bmm_fp8 (23-63)
- bmm_fp8 (23-24)
🪛 Ruff (0.14.3)

tests/utils/test_decorators.py
- 204, 208, 212: Unused function argument: `backend` (ARG001)

flashinfer/gemm.py
- 2014-2020: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2028-2034: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2043-2047: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2050: Avoid specifying long messages outside the exception class (TRY003)
- 2055-2058, 2060-2061: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `out`, `backend` (ARG001)

flashinfer/utils.py
- 1025-1027, 1031-1033, 1035-1037: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/utils/test_decorators.py (1)
194-229: LGTM! Well-structured test for auto-backend selection.

The test correctly validates the new `suitable_auto_backends` mechanism by:

- Using shape-based constraints to filter backends
- Verifying the `suitable_auto_backends` attribute is populated correctly
- Testing the error path when no backends are suitable

Note: The static analysis warnings about unused `backend` parameters are false positives; these requirement check functions must match the decorated function's signature.

flashinfer/gemm.py (4)
2012-2023: LGTM! Proper backend requirement gate for cuDNN.

The function correctly validates cuDNN availability and is decorated with appropriate compute capabilities (SM89, 90, 100, 103, 120).
Note: The unused parameter warnings from static analysis are false positives; requirement check functions must match the decorated function signature.
2026-2036: LGTM! Compute capability gate for cuBLAS backend.

The function serves as a pure compute capability gate for the cuBLAS backend (SM89, 90, 100, 103, 120). No additional validation is needed since cuBLAS is the most permissive backend.
2039-2051: LGTM! Proper validation for CUTLASS backend constraints.

The function correctly:
- Gates on newer compute capabilities (SM100, 103, 110, 120, 121)
- Validates that e5m2 is not used (CUTLASS doesn't support it for FP8 BMM)
Note: Ruff flags the inline error message as a style issue (TRY003), but this is acceptable for this use case.
2054-2064: LGTM! Common validation shared across all backends.

The function serves as the `common_check` parameter in the `@backend_requirement` decorator, validating that the output dtype is either bf16 or fp16, a constraint common to all FP8 BMM backends.

flashinfer/utils.py (1)
986-995: Race condition with `wrapper.suitable_auto_backends` mutation is real; attribute access pattern requires careful review.

The mutation at line 994 sets `wrapper.suitable_auto_backends` as a side effect, creating actual concurrency hazards:

- Concurrent calls: If the decorated function is called from multiple threads, the attribute is overwritten between validation and usage. In flashinfer/gemm.py:2152, `backends = bmm_fp8.suitable_auto_backends` accesses this attribute after the wrapper returns, where it could reflect a different thread's result.
- Skip-check stale data: When `skip_check=True` is passed, `suitable_auto_backends()` is never called, so `wrapper.suitable_auto_backends` is not set. The decorated function accessing this attribute would receive stale data from a prior call or hit an `AttributeError`.

Consider storing the suitable backends in a local variable within the wrapper and passing it to the decorated function (e.g., as a kwarg or via a thread-local context) rather than relying on attribute mutation.
Hi @jimmyzho, would you mind merging with (or rebasing onto) the main branch?
Head branch was pushed to by a user without write access
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/utils.py (1)
1046-1061: Fix: Prevent AttributeError when `skip_check=True` and `backend="auto"` are combined.

When `backend="auto"`, the code at line 1047 calls `suitable_auto_backends()`, which sets `wrapper.suitable_auto_backends`. However, if `skip_check=True`, the entire validation block (lines 1020-1060) is bypassed, so `wrapper.suitable_auto_backends` is never set.

Later, decorated functions like `bmm_fp8` (line 2173 in gemm.py) access `bmm_fp8.suitable_auto_backends`, which will raise `AttributeError` when `skip_check=True` and `backend="auto"` are used together.

Apply this fix to reject the unsupported combination:
```diff
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         # skip_check is an optional argument that the decorator adds to any API function.
         # It prevents the performance overhead of checking.
         skip_check = kwargs.pop("skip_check", False)
+        # backend="auto" requires validation to populate suitable_auto_backends
+        backend = kwargs.get("backend")
+        if backend == "auto" and skip_check:
+            raise ValueError(
+                f"{func.__name__}: backend='auto' requires validation and cannot be used with skip_check=True. "
+                "Either set skip_check=False or specify an explicit backend."
+            )
+
         if not skip_check:
```
🧹 Nitpick comments (1)
flashinfer/utils.py (1)
856-860: Document the new `heuristic_func` parameter in the docstring.

The signature correctly adds an optional `heuristic_func` parameter, but the docstring (lines 870-882) doesn't document it. Add documentation for this parameter describing:
- Its purpose (to order/filter suitable backends)
- Its signature (should accept suitable_backends list and function arguments)
- Expected return value (ordered list of backend names)
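Under the interface described above, a `heuristic_func` receives the list of backends that passed their checks plus the call arguments, and returns an ordered (possibly filtered) list of backend names. A hypothetical example matching that contract:

```python
def prefer_cutlass_for_large_k(suitable_backends, x, k_dim, backend=None):
    """Hypothetical heuristic: promote 'cutlass' to the front only when the
    reduction dimension is large enough; otherwise drop it from the list."""
    ordered = [b for b in suitable_backends if b != "cutlass"]
    if "cutlass" in suitable_backends and k_dim >= 128:
        ordered.insert(0, "cutlass")
    return ordered
```

The function name, `k_dim` threshold, and parameters here are illustrative; the real heuristic would encode whatever problem-size and architecture rules the library needs.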
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between f85c63d744d4adfedf62d6cfda4ee31215242a18 and 392cb79c2af6f7e76240d0ba855df6aba1ed1899.
📒 Files selected for processing (3)
- flashinfer/gemm.py (3 hunks)
- flashinfer/utils.py (3 hunks)
- tests/utils/test_decorators.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (4)
- supported_compute_capability (773-853)
- backend_requirement (856-1067)
- suitable_auto_backends (994-1010)
- BackendSupportedError (64-67)
flashinfer/gemm.py (3)
flashinfer/utils.py (3)
- supported_compute_capability (773-853)
- backend_requirement (856-1067)
- suitable_auto_backends (994-1010)

csrc/bmm_fp8.cu (2)
- bmm_fp8 (23-63)
- bmm_fp8 (23-24)

csrc/flashinfer_gemm_binding.cu (1)
- bmm_fp8 (19-20)
🪛 Ruff (0.14.3)

flashinfer/utils.py
- 1048-1050, 1054-1056, 1058-1060: Avoid specifying long messages outside the exception class (TRY003)

tests/utils/test_decorators.py
- 204, 208, 212, 241, 245, 255: Unused function argument: `backend` (ARG001)
- 248: Unused function arguments: `suitable_backends`, `x`, `backend` (ARG001)

flashinfer/gemm.py
- 2002-2008: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2016-2022: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2031-2035: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2038: Avoid specifying long messages outside the exception class (TRY003)
- 2043-2046, 2048-2049: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `out`, `backend` (ARG001)
- 2059-2063: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
tests/utils/test_decorators.py (2)
194-229: LGTM! Test correctly validates auto-backend selection.

The test properly validates:

- Backend filtering based on problem size constraints
- Correct population of the `suitable_auto_backends` attribute
- Error handling when no suitable backends are found

Note: The static analysis warnings about unused `backend` parameters (lines 204, 208, 212) are false positives; these parameters are required by the `backend_requirement` decorator API to match the decorated function's signature.
231-263: LGTM! Test correctly validates heuristic function behavior.

The test properly verifies:

- Heuristic function is called during auto-backend selection
- Backend ordering is determined by the heuristic function
- Output shape validation after backend selection

Note: The static analysis warnings about unused parameters in `_heuristic_func` (line 248) are false positives for this test fixture; the function signature must match the decorator's expected interface.

flashinfer/utils.py (1)
994-1010: LGTM! Common check validation is correctly implemented.

The function properly validates both `common_check` (lines 995-996) and individual backend checks, addressing the concern from previous reviews. The heuristic function integration (lines 1007-1008) is also correct.

However, note that this function sets `wrapper.suitable_auto_backends` (line 1009), which is later accessed when `backend="auto"`. This creates a potential AttributeError if `skip_check=True`; see the separate comment on lines 1046-1060.

flashinfer/gemm.py (6)
2000-2039: LGTM! Backend requirement functions correctly specify compute capability support.

The three requirement functions properly gate backend support:
- cudnn: SM89, 90, 100, 103, 120
- cublas: SM89, 90, 100, 103, 120
- cutlass: SM100, 103, 110, 120, 121 (includes SM110 and SM121 that were missing in earlier iterations)
The e5m2 rejection for cutlass (lines 2037-2038) is appropriate.
Note: Static analysis warnings about unused arguments are false positives; these functions must match the decorated function signature for the `backend_requirement` decorator to work correctly.
2042-2052: LGTM! Common check validates output dtype.

The function performs the essential validation of output dtype that applies to all backends. The unused arguments are required for signature matching with the decorator API.
2055-2084: LGTM! Heuristic function correctly orders backends and handles SM-specific constraints.

The function properly:
- Filters out cutlass for e5m2 inputs (line 2072)
- Selects appropriate cutlass variant based on SM architecture (lines 2073-2078)
- Applies k_dim >= 128 constraint for SM120/121 (lines 2076-2078), which matches the blockwise scaling requirement at line 244
- Preserves the expected backend ordering: cutlass variants → cublas → cudnn
2087-2095: LGTM! Decorator correctly applied with all backend requirements.

The `@backend_requirement` decorator properly integrates:
- Backend-specific requirement functions for cudnn, cublas, and cutlass
- Common problem size validation
- Heuristic function for backend ordering during auto-selection
367-375: LGTM! SM-specific CUTLASS runner support correctly added.

The changes properly handle:

- `cutlass_sm10x` runner for SM100/103/110 architectures
- `cutlass_sm12x` runner for SM120/121 architectures

This integrates correctly with the heuristic function that selects these runners based on device capability and problem size.
2172-2175: Verify auto-backend selection with skip_check handling.

The logic correctly uses `bmm_fp8.suitable_auto_backends` for automatic backend selection. However, this attribute is only set during validation. If users call `bmm_fp8(... backend="auto", skip_check=True)`, this will raise `AttributeError`.

This issue should be addressed by the fix proposed in flashinfer/utils.py (lines 1046-1061), which will prevent the invalid combination.

Once the fix in utils.py is applied, verify that:

- `backend="auto"` with `skip_check=True` raises a clear error
- `backend="auto"` with `skip_check=False` (default) works correctly
- Explicit backends work with both `skip_check` values
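Those verification scenarios map directly onto assertions. This sketch stubs the combination rules with hypothetical names so the three cases can be exercised without a GPU; it is not the flashinfer wrapper itself.

```python
def select_backends(backend, skip_check, candidates=("cublas", "cudnn")):
    """Stub of the combination rules under review: backend='auto' relies on
    validation to discover candidates, so it cannot be paired with skip_check."""
    if backend == "auto":
        if skip_check:
            raise ValueError("backend='auto' cannot be used with skip_check=True")
        return list(candidates)
    # Explicit backends work regardless of skip_check.
    return [backend]
```

Each checklist item becomes one assertion: auto without skip_check returns candidates, auto with skip_check raises, and explicit backends pass through either way.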
ad73ce6 to c2ff918
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 392cb79c2af6f7e76240d0ba855df6aba1ed1899 and 2cc760441a8e320f010c346183d52d3e55f28736.
📒 Files selected for processing (3)
- flashinfer/gemm.py (3 hunks)
- flashinfer/utils.py (3 hunks)
- tests/utils/test_decorators.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (4)
- supported_compute_capability (773-853)
- backend_requirement (856-1074)
- suitable_auto_backends (994-1010)
- BackendSupportedError (64-67)
flashinfer/gemm.py (3)
flashinfer/utils.py (3)
- supported_compute_capability (773-853)
- backend_requirement (856-1074)
- suitable_auto_backends (994-1010)

csrc/bmm_fp8.cu (2)
- bmm_fp8 (23-63)
- bmm_fp8 (23-24)

csrc/flashinfer_gemm_binding.cu (1)
- bmm_fp8 (19-20)
🪛 Ruff (0.14.3)

flashinfer/utils.py
- 1048-1050, 1054-1056, 1058-1060: Avoid specifying long messages outside the exception class (TRY003)

tests/utils/test_decorators.py
- 204, 208, 212, 241, 245, 255: Unused function argument: `backend` (ARG001)
- 248: Unused function arguments: `suitable_backends`, `x`, `backend` (ARG001)

flashinfer/gemm.py
- 2002-2008: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2016-2022: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2031-2035: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2038: Avoid specifying long messages outside the exception class (TRY003)
- 2043-2046, 2048-2049: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `out`, `backend` (ARG001)
- 2059-2063: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
/bot run
[FAILED] Pipeline #37905418: 5/17 passed
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/utils.py (2)
994-1011: Critical: Missing compute capability filtering in auto backend selection.

The `suitable_auto_backends` function does not filter candidates by compute capability before including them. This means that on devices like A100 (SM80), backends that require SM90+ (e.g., certain CUTLASS variants) could be included in `suitable_auto_backends`, leading to runtime failures when the kernel is invoked.

The explicit backend path validates compute capability via `is_backend_supported(backend, capability)` at line 1053, but the auto path bypasses this check entirely.

Apply this fix to add capability filtering:
```diff
 def suitable_auto_backends(*args, **kwargs):
     if common_check is not None and not common_check(*args, **kwargs):
         return False
+
+    # Extract compute capability from tensor arguments
+    capability = None
+    for value in kwargs.values():
+        if isinstance(value, torch.Tensor):
+            major, minor = get_compute_capability(value.device)
+            capability = major * 10 + minor
+            break
+
     suitable_backends = []
     # Check for each backend support
     for backend in backend_checks:
         try:
-            if backend_checks[backend](*args, **kwargs):
+            if (capability is None or is_backend_supported(backend, capability)) and backend_checks[backend](*args, **kwargs):
                 suitable_backends.append(backend)
         except ValueError:
             continue
     # If a heuristic function is provided, filter the suitable backends based on the heuristic function
     if heuristic_func is not None:
         suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
     if not suitable_backends:
         return False
     wrapper.suitable_auto_backends = suitable_backends
     return True
```

Based on learnings
1047-1067: Critical: AttributeError when `backend="auto"` with `skip_check=True` and no heuristic function.

When a function is called with `backend="auto"` and `skip_check=True`, the `suitable_auto_backends` attribute is only populated if `heuristic_func` is provided (lines 1062-1067). If `heuristic_func` is `None`, the attribute is never set, causing an `AttributeError` when the calling code tries to access `wrapper.suitable_auto_backends`.

Add an explicit guard to reject this unsupported combination:
```diff
     if tensor_arg is not None:
         # Get compute capability from the first tensor
         # Assume all tensors are on the same device/capability
         major, minor = get_compute_capability(tensor_arg.device)
         capability = major * 10 + minor
     if backend == "auto":
+        if skip_check:
+            raise ValueError(
+                f"{func.__name__} does not support backend='auto' with skip_check=True. "
+                f"Please either set skip_check=False or specify an explicit backend."
+            )
         if not suitable_auto_backends(**kwargs_with_defaults):
             raise BackendSupportedError(
                 f"No suitable auto backends found for {func.__name__}"
             )
     else:
         if not is_backend_supported(backend, capability):
             extra = f" with capability {capability}" if capability else ""
             raise BackendSupportedError(
                 f"{func.__name__} does not support backend '{backend}'{extra}"
             )
         if not _is_problem_size_supported(**kwargs_with_defaults):
             raise ValueError(
                 f"Problem size is not supported for {func.__name__}"
             )
-elif skip_check and heuristic_func is not None:
-    bound_args = sig.bind(*args, **kwargs)
-    bound_args.apply_defaults()
-    kwargs_with_defaults = dict(bound_args.arguments)
-    # This needs to be called for heuristic function
-    suitable_auto_backends(**kwargs_with_defaults)
+else:
+    # skip_check is True, still populate suitable_auto_backends if heuristic_func exists
+    if heuristic_func is not None:
+        bound_args = sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+        kwargs_with_defaults = dict(bound_args.arguments)
+        suitable_auto_backends(**kwargs_with_defaults)
 return func(*args, **kwargs)
```
🧹 Nitpick comments (1)
flashinfer/utils.py (1)
856-860: Document the new `heuristic_func` parameter.

The `heuristic_func` parameter was added to the decorator signature but is not documented in the docstring. This makes it difficult for users and future maintainers to understand its purpose and usage.

Add documentation for the new parameter:
```diff
 backend_checks : dict
     A dictionary mapping backend names (str) to requirement checker functions.
     Each checker function should accept the same arguments as the decorated
     function and return True if the problem size is supported, False otherwise.
     Checkers can be decorated with @supported_compute_capability to specify
     which compute capabilities they support.
 common_check : callable, optional
     An optional function that performs additional validation checks common to
     all backends. Should accept the same arguments as the decorated function
     and return True if requirements are met, False otherwise.
+heuristic_func : callable, optional
+    An optional function that reorders and filters suitable backends based on runtime
+    heuristics. Should accept a list of backend names (that passed backend_checks) as
+    the first argument, followed by the same arguments as the decorated function, and
+    return a reordered/filtered list of backend names. This is useful for dynamically
+    selecting preferred backends based on problem characteristics (e.g., problem size,
+    data types, device architecture).
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 2cc760441a8e320f010c346183d52d3e55f28736 and f76b98106a6050de30229b1e36522b99d128d20a.
📒 Files selected for processing (2)
- flashinfer/gemm.py (3 hunks)
- flashinfer/utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (4)
flashinfer/utils.py (3)
- supported_compute_capability (773-853)
- backend_requirement (856-1075)
- suitable_auto_backends (994-1011)

include/flashinfer/trtllm/common.h (1)
- device (83-90)

csrc/bmm_fp8.cu (2)
- bmm_fp8 (23-63)
- bmm_fp8 (23-24)

csrc/flashinfer_gemm_binding.cu (1)
- bmm_fp8 (19-20)
🪛 Ruff (0.14.3)

flashinfer/utils.py
- 1049-1051, 1055-1057, 1059-1061: Avoid specifying long messages outside the exception class (TRY003)

flashinfer/gemm.py
- 2002-2008: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2016-2022: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2031-2035: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
- 2038: Avoid specifying long messages outside the exception class (TRY003)
- 2043-2046, 2048-2049: Unused function arguments: `A`, `B`, `A_scale`, `B_scale`, `out`, `backend` (ARG001)
- 2059-2063: Unused function arguments: `A_scale`, `B_scale`, `dtype`, `out`, `backend` (ARG001)
🔇 Additional comments (5)
flashinfer/gemm.py (5)
2000-2039: LGTM: Backend requirement functions properly structured.

The three requirement functions (`_cudnn_bmm_fp8_requirement`, `_cublas_bmm_fp8_requirement`, `_cutlass_bmm_fp8_requirement`) are correctly decorated with `@supported_compute_capability` to specify their SM support. The unused parameters flagged by static analysis are intentional; the `backend_requirement` decorator requires all checker functions to have matching signatures.

The CUTLASS requirement correctly validates that e5m2 format is not supported (lines 2037-2038).
Note: The static analysis warnings about unused parameters can be safely ignored in this context.
2055-2083: Good implementation of dynamic backend heuristic.

The heuristic function correctly addresses the requirement from past review comments to dynamically order backends based on device architecture and problem characteristics. Key features:

- Filters out CUTLASS for e5m2 inputs (not supported)
- Maps to concrete runner names (`cutlass_sm10x`, `cutlass_sm12x`) for different SM architectures
- Applies the k_dim >= 128 constraint for the SM12x CUTLASS variant
- Preserves the preferred ordering: CUTLASS (when suitable), cuBLAS, cuDNN
Based on learnings
2171-2183: LGTM: Backend selection correctly handles auto, explicit cutlass, and other backends.

The three-way branching properly addresses the different backend selection scenarios:

- Auto mode (lines 2171-2172): Uses the decorator-computed `suitable_auto_backends` list
- Explicit CUTLASS (lines 2173-2176): Applies the heuristic function to derive concrete runner names (`cutlass_sm10x` or `cutlass_sm12x`), fixing the issue from past comments where explicit CUTLASS would produce zero runnable tactics
- Other explicit backends (lines 2177-2180): Passes through as singleton lists
Based on learnings
367-375: LGTM: Runner selection updated for concrete backend names.

The function correctly handles the concrete runner names (`cutlass_sm10x`, `cutlass_sm12x`) produced by the heuristic function, replacing the previous generic backend names. The assertion at line 375 ensures at least one runner is available, providing a clear failure point if backend selection goes wrong.
2086-2094: SM110 exclusion in cuDNN/cuBLAS is intentional and justified.

The compute capability annotations deliberately exclude SM110 from cuDNN and cuBLAS FP8 BMM support. Based on available evidence:
- cuDNN has a known FP8 bug: a scaled-dot-product attention deadlock/bug on sm10.0 (Blackwell), which explains the exclusion.
- cuBLAS theoretically supports SM110, but is likely excluded for consistency and to avoid related compatibility issues with the broader FP8 stack.
- The pattern in the codebase confirms intentionality: other backends explicitly block SM110 (e.g., the TRTLLM backend raises `ValueError("TRTLLM FP8 GEMM is not supported on SM110.")`), indicating that deliberate SM110 FP8 exclusions are standard practice in this codebase due to known issues.

No changes are needed.
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/utils.py (1)
1023-1040: Critical: Missing compute capability filtering in auto backend selection.

The `suitable_auto_backends` function does not extract compute capability or verify `is_backend_supported(backend, capability)` before including backends in the candidate list. This allows unsupported backends (e.g., cuBLAS FP8 on pre-Hopper GPUs) to be selected, causing runtime failures. The explicit backend path (lines 1088-1094) correctly validates capability, but the auto path bypasses this check entirely.
Apply this fix to add capability filtering:
```diff
 def suitable_auto_backends(*args, **kwargs):
     if common_check is not None and not common_check(*args, **kwargs):
         return False
+
+    # Extract compute capability from tensor arguments
+    capability = None
+    tensor_arg = None
+    for value in kwargs.values():
+        if isinstance(value, torch.Tensor):
+            tensor_arg = value
+            break
+    if tensor_arg is not None:
+        major, minor = get_compute_capability(tensor_arg.device)
+        capability = major * 10 + minor
+
     suitable_backends = []
     # Check for each backend support
     for backend in backend_checks:
+        # Skip backends unsupported on current compute capability
+        if capability is not None and not is_backend_supported(backend, capability):
+            continue
         try:
             if backend_checks[backend](*args, **kwargs):
                 suitable_backends.append(backend)
         except ValueError:
             continue
     # If a heuristic function is provided, filter the suitable backends based on the heuristic function
     if heuristic_func is not None:
         suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
     if not suitable_backends:
         return False
     wrapper.suitable_auto_backends = suitable_backends
     return True
```

Based on learnings from past reviews.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between f76b98106a6050de30229b1e36522b99d128d20a and 259a3d3a3b89c31ac66f66891da398b577435992.
📒 Files selected for processing (3)
- `flashinfer/gemm.py` (3 hunks)
- `flashinfer/utils.py` (3 hunks)
- `tests/utils/test_decorators.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (4)
- `supported_compute_capability` (773-853)
- `backend_requirement` (856-1123)
- `suitable_auto_backends` (1023-1040)
- `BackendSupportedError` (64-67)
flashinfer/gemm.py (2)
flashinfer/utils.py (3)
- `supported_compute_capability` (773-853)
- `backend_requirement` (856-1123)
- `suitable_auto_backends` (1023-1040)

csrc/bmm_fp8.cu (2)
- `bmm_fp8` (23-63)
- `bmm_fp8` (23-24)
🪛 Ruff (0.14.3)
tests/utils/test_decorators.py
340-340: Unused function argument: backend
(ARG001)
344-344: Unused function argument: backend
(ARG001)
348-348: Unused function argument: backend
(ARG001)
377-377: Unused function argument: backend
(ARG001)
381-381: Unused function argument: backend
(ARG001)
384-384: Unused function argument: suitable_backends
(ARG001)
384-384: Unused function argument: x
(ARG001)
384-384: Unused function argument: backend
(ARG001)
391-391: Unused function argument: backend
(ARG001)
flashinfer/utils.py
1084-1086: Avoid specifying long messages outside the exception class
(TRY003)
1092-1094: Avoid specifying long messages outside the exception class
(TRY003)
1096-1098: Avoid specifying long messages outside the exception class
(TRY003)
1101-1103: Avoid specifying long messages outside the exception class
(TRY003)
1105-1107: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm.py
2002-2002: Unused function argument: A
(ARG001)
2003-2003: Unused function argument: B
(ARG001)
2004-2004: Unused function argument: A_scale
(ARG001)
2005-2005: Unused function argument: B_scale
(ARG001)
2006-2006: Unused function argument: dtype
(ARG001)
2007-2007: Unused function argument: out
(ARG001)
2008-2008: Unused function argument: backend
(ARG001)
2016-2016: Unused function argument: A
(ARG001)
2017-2017: Unused function argument: B
(ARG001)
2018-2018: Unused function argument: A_scale
(ARG001)
2019-2019: Unused function argument: B_scale
(ARG001)
2020-2020: Unused function argument: dtype
(ARG001)
2021-2021: Unused function argument: out
(ARG001)
2022-2022: Unused function argument: backend
(ARG001)
2031-2031: Unused function argument: A_scale
(ARG001)
2032-2032: Unused function argument: B_scale
(ARG001)
2033-2033: Unused function argument: dtype
(ARG001)
2034-2034: Unused function argument: out
(ARG001)
2035-2035: Unused function argument: backend
(ARG001)
2038-2038: Avoid specifying long messages outside the exception class
(TRY003)
2043-2043: Unused function argument: A
(ARG001)
2044-2044: Unused function argument: B
(ARG001)
2045-2045: Unused function argument: A_scale
(ARG001)
2046-2046: Unused function argument: B_scale
(ARG001)
2048-2048: Unused function argument: out
(ARG001)
2049-2049: Unused function argument: backend
(ARG001)
2059-2059: Unused function argument: A_scale
(ARG001)
2060-2060: Unused function argument: B_scale
(ARG001)
2061-2061: Unused function argument: dtype
(ARG001)
2062-2062: Unused function argument: out
(ARG001)
2063-2063: Unused function argument: backend
(ARG001)
🔇 Additional comments (7)
tests/utils/test_decorators.py (2)
330-365: LGTM! Test structure validates auto backend selection.

The test properly exercises the new `suitable_auto_backends` feature with both success and error paths.

Note: Static analysis warnings about unused `backend` arguments in the requirement checker functions (lines 340, 344, 348) are false positives. These functions must match the signature of the decorated function even when not all parameters are used.
367-399: LGTM! Test validates heuristic function integration.

The test correctly verifies that the heuristic function can reorder the suitable backends list.
Note: Static analysis warnings about unused arguments (lines 377, 381, 384, 391) are false positives. The requirement checker and heuristic functions must match expected signatures.
flashinfer/gemm.py (5)
2000-2011: Verify SM110 support for cuDNN FP8.

The compute capability list includes `[89, 90, 100, 103, 120, 121]` but excludes `110`. Based on past review comments about ensuring complete SM coverage, please confirm whether SM110 should be included in the supported list for cuDNN FP8 BMM.

Note: Static analysis warnings about unused parameters are false positives—requirement checker functions must match the decorated function signature.
Based on learnings from past reviews.
2014-2024: LGTM! cuBLAS requirement checker is correct.

The function correctly defines cuBLAS support for the specified compute capabilities. The same note about SM110 applies here as mentioned in the cuDNN requirement checker.
Note: Unused parameter warnings are false positives.
2027-2039: LGTM! CUTLASS requirement checker correctly validates e5m2 restriction.

The function appropriately restricts CUTLASS to SM100+ architectures and rejects `float8_e5m2` inputs, which aligns with CUTLASS FP8 kernel limitations.

Note: The static analysis warning about a long error message (TRY003) is acceptable here for clarity. Unused parameter warnings are false positives.
2042-2083: LGTM! Heuristic function correctly maps backends to SM-specific runners.

The heuristic properly transforms the logical `"cutlass"` backend into architecture-specific runner names (`"cutlass_sm10x"`, `"cutlass_sm12x"`) based on device capability and problem size (k dimension check for SM120/121). This aligns with the updated `fp8_gemm_sm100` function that expects these specific runner names.

Note: Unused parameter warnings are false positives—these functions must match the decorated function signature.
2086-2094: LGTM! Decorator and implementation correctly handle auto and explicit backends.

The decorator integrates the new per-backend requirements and heuristic function. The implementation correctly:
- Uses `bmm_fp8.suitable_auto_backends` for `backend="auto"` (line 2172)
- Applies the heuristic to explicit `"cutlass"` to resolve runner names (lines 2174-2176)
- Handles `"cudnn"` with an availability check (lines 2177-2178)
- Falls through to other explicit backends (line 2180)
This addresses past review comments about explicit cutlass needing heuristic transformation.
Also applies to: 2171-2183
```python
elif skip_check and heuristic_func is not None:
    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()
    kwargs_with_defaults = dict(bound_args.arguments)
    # This needs to be called for heuristic function
    suitable_auto_backends(*args, **kwargs)
```
Fix: Pass kwargs_with_defaults to suitable_auto_backends.
When skip_check=True and heuristic_func is present, the code applies defaults to create kwargs_with_defaults but then passes the original *args, **kwargs to suitable_auto_backends. This is inconsistent with the main validation path (line 1083) and may cause issues if default parameter values are needed by the heuristic.
Apply this fix:
```diff
 elif skip_check and heuristic_func is not None:
     bound_args = sig.bind(*args, **kwargs)
     bound_args.apply_defaults()
     kwargs_with_defaults = dict(bound_args.arguments)
     # This needs to be called for heuristic function
-    suitable_auto_backends(*args, **kwargs)
+    suitable_auto_backends(**kwargs_with_defaults)
```

🤖 Prompt for AI Agents
+ suitable_auto_backends(**kwargs_with_defaults)🤖 Prompt for AI Agents
In flashinfer/utils.py around lines 1108 to 1113, the code computes
kwargs_with_defaults but still calls suitable_auto_backends with the original
*args, **kwargs; replace that call so the heuristic receives the
applied-defaults arguments by calling suitable_auto_backends(*args,
**kwargs_with_defaults) (i.e., pass kwargs_with_defaults as the keyword args) to
match the main validation path.
[FAILED] Pipeline #37971846: 12/17 passed
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/utils.py (1)
1109-1112: Consider applying defaults in the skip_check path when calling suitable_auto_backends.

When `skip_check=True` and `backend == "auto"`, the code calls `suitable_auto_backends(*args, **kwargs)` without applying defaults from the function signature. This is inconsistent with the main validation path (line 1083), which uses `kwargs_with_defaults`.

The heuristic function and backend requirement checkers may depend on default parameter values. If a parameter has a default value and the caller doesn't pass it explicitly, the heuristic won't see it.

Based on past review comments, there's a trade-off between correctness and keeping the `skip_check` path lightweight. A middle-ground approach:

```diff
 elif skip_check and heuristic_func is not None:
     if kwargs.get("backend") == "auto":
+        # Apply defaults only when needed for auto-backend selection
+        bound_args = sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+        kwargs_with_defaults = dict(bound_args.arguments)
         # This needs to be called for heuristic function
-        suitable_auto_backends(*args, **kwargs)
+        suitable_auto_backends(**kwargs_with_defaults)
```

This adds minimal overhead—only when `skip_check=True` AND `backend="auto"` AND a `heuristic_func` is present.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 259a3d3a3b89c31ac66f66891da398b577435992 and 330f89b.
📒 Files selected for processing (3)
- `flashinfer/gemm.py` (3 hunks)
- `flashinfer/utils.py` (3 hunks)
- `tests/utils/test_decorators.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (4)
- `supported_compute_capability` (773-853)
- `backend_requirement` (856-1122)
- `suitable_auto_backends` (1023-1040)
- `BackendSupportedError` (64-67)
flashinfer/gemm.py (2)
flashinfer/utils.py (3)
- `supported_compute_capability` (773-853)
- `backend_requirement` (856-1122)
- `suitable_auto_backends` (1023-1040)

csrc/bmm_fp8.cu (2)
- `bmm_fp8` (23-63)
- `bmm_fp8` (23-24)
🪛 Ruff (0.14.3)
flashinfer/utils.py
1084-1086: Avoid specifying long messages outside the exception class
(TRY003)
1092-1094: Avoid specifying long messages outside the exception class
(TRY003)
1096-1098: Avoid specifying long messages outside the exception class
(TRY003)
1102-1104: Avoid specifying long messages outside the exception class
(TRY003)
1106-1108: Avoid specifying long messages outside the exception class
(TRY003)
tests/utils/test_decorators.py
340-340: Unused function argument: backend
(ARG001)
344-344: Unused function argument: backend
(ARG001)
348-348: Unused function argument: backend
(ARG001)
377-377: Unused function argument: backend
(ARG001)
381-381: Unused function argument: backend
(ARG001)
385-385: Unused function argument: backend
(ARG001)
388-388: Unused function argument: backend
(ARG001)
403-403: Unused function argument: backend
(ARG001)
flashinfer/gemm.py
2002-2002: Unused function argument: A
(ARG001)
2003-2003: Unused function argument: B
(ARG001)
2004-2004: Unused function argument: A_scale
(ARG001)
2005-2005: Unused function argument: B_scale
(ARG001)
2006-2006: Unused function argument: dtype
(ARG001)
2007-2007: Unused function argument: out
(ARG001)
2008-2008: Unused function argument: backend
(ARG001)
2016-2016: Unused function argument: A
(ARG001)
2017-2017: Unused function argument: B
(ARG001)
2018-2018: Unused function argument: A_scale
(ARG001)
2019-2019: Unused function argument: B_scale
(ARG001)
2020-2020: Unused function argument: dtype
(ARG001)
2021-2021: Unused function argument: out
(ARG001)
2022-2022: Unused function argument: backend
(ARG001)
2031-2031: Unused function argument: A_scale
(ARG001)
2032-2032: Unused function argument: B_scale
(ARG001)
2033-2033: Unused function argument: dtype
(ARG001)
2034-2034: Unused function argument: out
(ARG001)
2035-2035: Unused function argument: backend
(ARG001)
2038-2038: Avoid specifying long messages outside the exception class
(TRY003)
2043-2043: Unused function argument: A
(ARG001)
2044-2044: Unused function argument: B
(ARG001)
2045-2045: Unused function argument: A_scale
(ARG001)
2046-2046: Unused function argument: B_scale
(ARG001)
2048-2048: Unused function argument: out
(ARG001)
2049-2049: Unused function argument: backend
(ARG001)
2059-2059: Unused function argument: A_scale
(ARG001)
2060-2060: Unused function argument: B_scale
(ARG001)
2061-2061: Unused function argument: dtype
(ARG001)
2062-2062: Unused function argument: out
(ARG001)
2063-2063: Unused function argument: backend
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
tests/utils/test_decorators.py (2)
330-365: LGTM - comprehensive test coverage for auto-backend selection.

The test correctly validates that:
- `suitable_auto_backends` is populated with backends passing their requirement checks
- The error path raises `BackendSupportedError` when no suitable backends exist
- The `suitable_auto_backends` attribute is accessible on the decorated function

Note: Static analysis warnings about unused `backend` parameters (lines 340, 344, 348) are false positives—these parameters must match the decorated function's signature even if unused in the requirement check.
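The behavior the test asserts can be mimicked with a torch-free stand-in; `my_op` and its two toy backends are hypothetical, standing in for the real decorated function:

```python
# Stand-in for a backend_requirement-decorated function: auto mode
# records which backends passed on the wrapper, and raises when none do.

class BackendSupportedError(RuntimeError):
    pass

def my_op(x, backend="auto"):
    checks = {"fast": lambda v: len(v) > 2, "safe": lambda v: len(v) > 0}
    if backend == "auto":
        my_op.suitable_auto_backends = [n for n, c in checks.items() if c(x)]
        if not my_op.suitable_auto_backends:
            raise BackendSupportedError("no suitable auto backends")
    return sum(x)

assert my_op([1, 2, 3]) == 6
assert my_op.suitable_auto_backends == ["fast", "safe"]
try:
    my_op([])                      # no backend passes -> error path
except BackendSupportedError:
    print("error path covered")
```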
367-419: LGTM - thorough validation of heuristic function behavior.

This test correctly validates that:
- The heuristic function receives the filtered list of suitable backends (after capability/requirement checks)
- The heuristic can reorder backends dynamically based on runtime parameters
- The `suitable_auto_backends` attribute reflects the heuristic's ordering
- Different input shapes produce different backend orderings as expected
The assertion on line 390 that "cutlass" is not in suitable_backends is particularly good—it confirms that unsuitable backends are filtered before the heuristic runs.
Note: Static analysis warnings about unused parameters (lines 377, 381, 385, 388, 403) are false positives—these are required by the decorator pattern.
flashinfer/gemm.py (6)
367-375: LGTM - cleaner backend selection logic.

The refactoring from SM-specific conditionals to a generic `runner_names` list is cleaner and more maintainable. The new structure supports distinct "cutlass_sm10x" and "cutlass_sm12x" runners, aligning well with the heuristic function's output.

The assertion on line 375 provides a clear error if no runners are configured, which is better than silently failing later.
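The runner-selection shape described here can be sketched as follows (the registry and runner names are illustrative, not the module's actual objects):

```python
# Registered runners are filtered by the concrete names the heuristic
# produced; the assertion is the clear failure point mentioned above.

RUNNERS = {
    "cutlass_sm10x": lambda: "ran cutlass_sm10x",
    "cutlass_sm12x": lambda: "ran cutlass_sm12x",
    "cublas": lambda: "ran cublas",
}

def run_first(runner_names):
    selected = [RUNNERS[name] for name in runner_names if name in RUNNERS]
    assert selected, "no runnable tactics for the requested backends"
    return selected[0]()

print(run_first(["cutlass_sm12x", "cublas"]))  # uses the first match
```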
2000-2039: Verify SM110 exclusion from cuDNN and cuBLAS backends is intentional.

Comparing the `@supported_compute_capability` decorators across the three backend requirement functions:
- cuDNN (line 2000): `[89, 90, 100, 103, 120, 121]` — missing SM110
- cuBLAS (line 2014): `[89, 90, 100, 103, 120, 121]` — missing SM110
- CUTLASS (line 2027): `[100, 103, 110, 120, 121]` — has SM110, but missing SM89/90

This asymmetry may be intentional (reflecting actual backend/architecture support), but given the past review comment asking to verify SM110 and SM121 aren't missing, please confirm:
- Is SM110 truly unsupported by cuDNN and cuBLAS for FP8 BMM?
- Is the absence of SM89/90 from CUTLASS intentional?
If these exclusions are correct, consider adding a brief comment in the code or commit message explaining the architecture support matrix to prevent future confusion.
2042-2052: LGTM - appropriate common validation.

The common check correctly validates the output dtype constraint that applies to all backends. Backend-specific constraints (like e5m2 support) are appropriately handled in the individual requirement functions.
Note: Static analysis warnings about unused parameters are false positives—the signature must match the decorated function even if not all parameters are used in the validation.
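A common check of this shape might look like the sketch below; the allowed dtype strings and parameter names are assumptions, not the function's real constraint set:

```python
# A backend-agnostic pre-check: it mirrors the decorated function's
# signature and validates only what every backend requires (here, the
# output dtype), leaving backend-specific rules to per-backend checkers.

ALLOWED_OUT_DTYPES = {"bfloat16", "float16"}

def common_check(A, B, dtype, out=None, backend="auto"):  # A/B/out unused here
    return dtype in ALLOWED_OUT_DTYPES

print(common_check(None, None, "bfloat16"))        # True
print(common_check(None, None, "float8_e4m3fn"))   # False
```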
2055-2083: Good heuristic logic with SM-specific CUTLASS variant selection.

The heuristic correctly handles:
- e5m2 exclusion for CUTLASS (lines 2066, 2072) — aligns with the requirement check on lines 2037-2038
- SM-specific CUTLASS variants (lines 2073-2078) — maps to the concrete runner names expected by `fp8_gemm_sm100`
- K-dimension constraint for SM12x (line 2077) — the `k >= 128` check aligns with the SM120 groupwise scaling kernel requirement (see lines 242-245 earlier in this file)
Note: Static analysis warnings about unused parameters (lines 2059-2063) are false positives—the heuristic signature must match the decorated function.
2086-2094: LGTM - decorator correctly configured.

The decorator integrates all the new components:
- Backend-specific requirement functions with compute capability decorators
- Common validation for output dtype
- Heuristic function for dynamic backend ordering
This provides a complete auto-backend selection framework for `bmm_fp8`.
2171-2180: Backend selection logic correctly handles all cases.

The implementation properly handles:
- `backend="auto"` (lines 2171-2172): Uses the decorator-populated `suitable_auto_backends` list
- `backend="cutlass"` (lines 2173-2176): Explicitly calls the heuristic to map to SM-specific variants ("cutlass_sm10x" or "cutlass_sm12x"), resolving the past issue where explicit cutlass produced zero runnable tactics
- `backend="cudnn"` (lines 2177-2178): Guards with a `CUDNN_AVAILABLE` check
- Other backends (lines 2179-2180): Passes through as-is (e.g., "cublas")
The explicit cutlass path (lines 2173-2176) is particularly important—it ensures that users who explicitly request "cutlass" get the correct SM-specific variant, not an empty list.
Actionable comments posted: 0
🧹 Nitpick comments (2)
flashinfer/utils.py (2)
1023-1043: Ensure `wrapper.suitable_auto_backends` is always set.

The function returns `False` on lines 1025 and 1041 without setting `wrapper.suitable_auto_backends`. In the skip_check path (line 1121), the return value isn't checked, which could leave callers expecting the attribute to exist when it doesn't.

Consider always setting `wrapper.suitable_auto_backends` to provide consistent behavior:

```diff
 def suitable_auto_backends(cc, *args, **kwargs):
     if common_check is not None and not common_check(*args, **kwargs):
+        wrapper.suitable_auto_backends = []
         return False
     suitable_backends = []
     # Check for each backend support
     for backend in backend_checks:
         req_checker = backend_checks[backend]
         try:
             if req_checker(
                 *args, **kwargs
             ) and req_checker.is_compute_capability_supported(cc):
                 suitable_backends.append(backend)
         except ValueError:
             continue
     # If a heuristic function is provided, filter the suitable backends based on the heuristic function
     if heuristic_func is not None:
         suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
+    wrapper.suitable_auto_backends = suitable_backends
     if not suitable_backends:
         return False
-    wrapper.suitable_auto_backends = suitable_backends
     return True
```

Note: The compute capability validation on line 1033 correctly addresses the concern raised in past reviews.
1083-1116: Optional: Consider extracting error messages to exception classes.

Static analysis flags multiple long error messages (lines 1083-1085, 1092-1094, 1100-1102, 1104-1106, 1110-1112, 1114-1116) under Ruff's TRY003 rule. While the current approach is clear and allows for dynamic context, you could optionally extract these to custom exception classes for better organization:

```python
class NoSuitableBackendsError(BackendSupportedError):
    def __init__(self, func_name):
        super().__init__(f"No suitable auto backends found for {func_name}")


class UnsupportedBackendError(BackendSupportedError):
    def __init__(self, func_name, backend, capability=None):
        extra = f" with capability {capability}" if capability else ""
        super().__init__(f"{func_name} does not support backend '{backend}'{extra}")
```

This is a low-priority style improvement given the "Chill" review mode.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/utils.py(3 hunks)
🧰 Additional context used
🪛 Ruff (0.14.3)
flashinfer/utils.py
1083-1085: Avoid specifying long messages outside the exception class
(TRY003)
1092-1094: Avoid specifying long messages outside the exception class
(TRY003)
1100-1102: Avoid specifying long messages outside the exception class
(TRY003)
1104-1106: Avoid specifying long messages outside the exception class
(TRY003)
1110-1112: Avoid specifying long messages outside the exception class
(TRY003)
1114-1116: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
flashinfer/utils.py (6)
857-859: LGTM: Function signature enhanced with heuristic support.

The addition of the `heuristic_func` parameter is properly typed and aligns with the PR's goal of enabling dynamic backend ordering for auto-selection.
1045-1062: LGTM: Capability extraction is correctly implemented.

The helper function appropriately extracts compute capability from the first tensor argument. The assumption that all tensors share the same device (line 1049) is reasonable for performance.
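The helper's approach can be sketched torch-free; `FakeTensor` stands in for `torch.Tensor` and its stored pair for `get_compute_capability`, so all names here are illustrative:

```python
# Scan positional and keyword arguments for the first tensor-like value
# and derive a capability number (major * 10 + minor) from it, assuming
# all tensors share one device so the first hit is representative.

class FakeTensor:
    def __init__(self, major, minor):
        self.capability = (major, minor)

def get_capability(*args, **kwargs):
    for value in (*args, *kwargs.values()):
        if isinstance(value, FakeTensor):
            major, minor = value.capability
            return major * 10 + minor
    return None  # no tensor argument found

print(get_capability(FakeTensor(9, 0), scale=2.0))  # 90
print(get_capability(scale=2.0))                    # None
```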
1081-1081: LGTM: Capability extraction refactored to helper.

Good refactoring that extracts capability once for reuse across validation paths.
1088-1106: LGTM: Auto and explicit backend validation paths are correctly implemented.

The logic appropriately handles both auto-selection (lines 1088-1094) and explicit backend validation (lines 1095-1106), with proper error messages for unsupported backends.
1108-1116: LGTM: Backend-agnostic validation correctly implemented.

The comment on line 1108 improves clarity, and the validation logic properly handles functions with implicit single backends.
Based on past review feedback.
1117-1121: Consider applying defaults before calling `suitable_auto_backends` in the skip_check path.

On line 1121, `suitable_auto_backends` is called with `*args, **kwargs` without applying defaults, unlike the main validation path (line 1090), which uses `kwargs_with_defaults`. If the heuristic function or backend requirement checkers rely on default parameter values, they could fail or behave incorrectly.

Trade-off: Per past review feedback, keeping the `skip_check` path lightweight is important for performance. However, this creates an inconsistency that could cause subtle bugs.

Recommendation: Document that when using `skip_check=True`, callers must explicitly provide all parameters that the heuristic or backend checks depend on, or apply defaults here:

```diff
 elif skip_check and heuristic_func is not None:
     if kwargs.get("backend") == "auto":
+        # Apply defaults for heuristic function
+        bound_args = sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+        kwargs_with_defaults = dict(bound_args.arguments)
         # This needs to be called for heuristic function
         capability = _get_capability(*args, **kwargs)
-        suitable_auto_backends(capability, *args, **kwargs)
+        suitable_auto_backends(capability, **kwargs_with_defaults)
```

Based on past review feedback.
…, heuristic_func intake (flashinfer-ai#2029)

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

* **Improvements**
  * Expanded FP8 BMM backend support with explicit Cutlass SM10x/SM12x handling, safer fallbacks (no unconditional hard failures), and richer auto-selection that exposes viable backends and respects device capabilities.
  * Added heuristic-driven backend preference for auto and cutlass paths.
* **Refactor**
  * Backend gating reorganized into per-backend capability checks, a shared problem-size pre-check, and heuristic selection; decorator now exposes suitable_auto_backends and capability extraction.
* **Tests**
  * Added tests validating auto backend discovery and heuristic ordering.
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Improvements
Refactor
Tests