Added workspace check and reflected this in test #1991
nvmbreughe merged 5 commits into flashinfer-ai:main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Added optional workspace metadata creation in the IPC workspace API and metadata-based validation in the all-reduce fusion call. Signatures changed to support these additions.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test
    participant CreateWS as trtllm_create_ipc_workspace_for_all_reduce_fusion
    participant AllReduce as trtllm_allreduce_fusion
    Test->>CreateWS: create_metadata=True, (tp_rank, tp_size, max_token_num, hidden_dim, ...)
    CreateWS-->>Test: (ipc_handles, workspace_tensor, metadata)
    note right of Test: metadata → {tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size}
    Test->>AllReduce: allreduce_in, world_size, token_num, hidden_dim, workspace_ptrs, ..., metadata
    rect rgb(245,250,255)
        note over AllReduce: Validation when metadata provided
        AllReduce->>AllReduce: assert token_num ≤ metadata.max_token_num
        AllReduce->>AllReduce: assert world_size == metadata.tp_size
        AllReduce->>AllReduce: assert hidden_dim == metadata.hidden_dim
    end
    alt Validation passes
        AllReduce->>AllReduce: launch fused kernels / run logic
        AllReduce-->>Test: completion
    else Validation fails
        AllReduce-->>Test: raise ValueError (detailed message)
    end
```
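Rendered as Python, the same flow might look like the sketch below. This is illustrative only: it assumes an already-initialized `torch.distributed` NCCL process group with one rank per GPU, the sizes are placeholders, and the `allreduce_out` buffer handling for the `kAllReduce` pattern is an assumption; keyword arguments mirror the test snippets quoted later in this review.

```python
import torch
import torch.distributed as dist
import flashinfer.comm as comm

# Assumption: dist.init_process_group("nccl") has already run, one rank per GPU.
rank, world_size = dist.get_rank(), dist.get_world_size()
group = dist.group.WORLD
MAX_TOKEN_NUM, hidden_dim, token_num = 1024, 4096, 128  # illustrative sizes
device = torch.device("cuda", rank)

# Opt in to the new metadata return value.
ipc_handles, workspace_tensor, metadata = (
    comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
        rank, world_size, MAX_TOKEN_NUM, hidden_dim,
        group=group, create_metadata=True,
    )
)

allreduce_in = torch.randn(token_num * hidden_dim, dtype=torch.bfloat16, device=device)
allreduce_out = torch.empty_like(allreduce_in)  # assumed output buffer for kAllReduce

# Passing metadata enables the new validation; omitting it keeps the old behavior.
comm.trtllm_allreduce_fusion(
    allreduce_in=allreduce_in,
    world_size=world_size,
    world_rank=rank,
    token_num=token_num,
    hidden_dim=hidden_dim,
    workspace_ptrs=workspace_tensor,
    launch_with_pdl=True,
    trigger_completion_at_end=True,
    fp32_acc=False,
    pattern_code=comm.AllReduceFusionPattern.kAllReduce,
    use_oneshot=None,
    allreduce_out=allreduce_out,
    residual_in=None,
    residual_out=None,
    norm_out=None,
    quant_out=None,
    scale_out=None,
    rms_gamma=None,
    rms_eps=None,
    scale_factor=None,
    layout_code=None,
    metadata=metadata,  # raises ValueError on token_num/world_size/hidden_dim mismatch
)

comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(ipc_handles, group=group)
```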
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a crucial safety mechanism to the all-reduce fusion operations by implementing workspace validation checks. It addresses a potential issue where incorrect token counts or dimension mismatches could lead to memory access errors or system hangs. By optionally providing and consuming workspace metadata, the system can now proactively detect and prevent these inconsistencies, enhancing the robustness and stability of the communication primitives.
Code Review
This pull request introduces a valuable validation check to prevent potential illegal memory access errors by ensuring num_tokens does not exceed the workspace's max_token_num. The changes are implemented in a non-breaking way by using optional parameters, which is a good practice for maintaining API compatibility. The tests have been updated accordingly, and a bug in the test cleanup logic was also fixed.
I have a couple of suggestions to improve the maintainability and type safety of the new API, related to using typing.overload for clearer function signatures and a TypedDict or dataclass for the metadata dictionary. Overall, this is a solid contribution that improves the robustness of the library.
```python
) -> Union[
    Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict]
]:
```
While using Union of tuples with different lengths maintains API compatibility, it can be challenging for static analysis tools and IDEs to handle correctly. For better type safety and developer experience, you could use typing.overload to define the function signatures explicitly. This would make it clear to users of the function what to expect for different inputs, improving usability and reducing potential for errors.
```python
metadata = {
    "tp_rank": tp_rank,
    "tp_size": tp_size,
    "max_token_num": max_token_num,
    "hidden_dim": hidden_dim,
    "use_fp32_lamport": use_fp32_lamport,
    "buffer_size": buffer_size,
    "flag_size": flag_size,
    "lamport_comm_size": lamport_comm_size,
    "lamport_buffer_size": lamport_buffer_size,
}
```
Using a raw dict for metadata is prone to errors from typos in keys, both when creating it here and when consuming it in trtllm_allreduce_fusion. To make this more robust and self-documenting, consider defining a TypedDict or a dataclass for the metadata. This provides static type checking and autocompletion support in IDEs, reducing the chance of runtime errors.
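A minimal sketch of the dataclass flavor of this suggestion (the class name is hypothetical; the fields mirror the metadata dict above, and a `TypedDict` variant appears later in this thread):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AllReduceFusionWorkspaceMetadata:
    # Hypothetical name; one field per key of the metadata dict above.
    tp_rank: int
    tp_size: int
    max_token_num: int
    hidden_dim: int
    use_fp32_lamport: bool
    buffer_size: int
    flag_size: int
    lamport_comm_size: int
    lamport_buffer_size: int
```

`frozen=True` guards against accidental mutation of the workspace parameters after creation.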
Actionable comments posted: 2
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_ar.py (2)
503-506: Return typing is fine; consider a stronger metadata type for clarity.

Union return is OK. To improve API clarity, define a TypedDict for metadata and use it in the signature and docstrings. This avoids untyped dict usage downstream.
```diff
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, TypedDict
+
+class AllReduceFusionWorkspaceMeta(TypedDict):
+    tp_rank: int
+    tp_size: int
+    max_token_num: int
+    hidden_dim: int
+    use_fp32_lamport: bool
+    buffer_size: int
+    flag_size: int
+    lamport_comm_size: int
+    lamport_buffer_size: int
@@
-) -> Union[
-    Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict]
-]:
+) -> Union[
+    Tuple[List[List[int]], torch.Tensor],
+    Tuple[List[List[int]], torch.Tensor, AllReduceFusionWorkspaceMeta],
+]:
```
621-636: Good: metadata emission. Consider adding device and validating key presence.

Including creation params is useful. Add a 'device' (CUDA index) for extra safety and validate required keys on consumption to avoid KeyError if a caller passes a partial dict.
```diff
 metadata = {
     "tp_rank": tp_rank,
     "tp_size": tp_size,
     "max_token_num": max_token_num,
     "hidden_dim": hidden_dim,
     "use_fp32_lamport": use_fp32_lamport,
+    "device": torch.cuda.current_device(),
     "buffer_size": buffer_size,
     "flag_size": flag_size,
     "lamport_comm_size": lamport_comm_size,
     "lamport_buffer_size": lamport_buffer_size,
 }
```

tests/comm/test_trtllm_allreduce_fusion.py (1)
60-71: Good: create workspace with metadata; add quick negative-path checks here.

Leverage the new validation with a small trio of negative assertions right after creating metadata to guard against regressions (no heavy kernels executed).
```diff
 ipc_handles, workspace_tensor, workspace_metadata = (
     comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
         rank,
         world_size,
         MAX_TOKEN_NUM,
         hidden_dim,
         group=group,
         use_fp32_lamport=lamport_use_fp32,
         create_metadata=True,  # Get metadata for validation
     )
 )
+
+# Negative-path validation: token_num overflow
+bad_token_num = MAX_TOKEN_NUM + 1
+with pytest.raises(ValueError, match="token_num .* exceeds"):
+    comm.trtllm_allreduce_fusion(
+        allreduce_in=torch.empty(bad_token_num * hidden_dim, dtype=dtype, device=device),
+        world_size=world_size,
+        world_rank=rank,
+        token_num=bad_token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=workspace_tensor,
+        launch_with_pdl=True,
+        trigger_completion_at_end=True,
+        fp32_acc=False,
+        pattern_code=comm.AllReduceFusionPattern.kAllReduce,
+        use_oneshot=None,
+        allreduce_out=None,
+        residual_in=None,
+        residual_out=None,
+        norm_out=None,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=None,
+        rms_eps=None,
+        scale_factor=None,
+        layout_code=None,
+        metadata=workspace_metadata,
+    )
+
+# Negative-path validation: world_size mismatch
+meta_ws_mismatch = dict(workspace_metadata)
+meta_ws_mismatch["tp_size"] = world_size + 1
+with pytest.raises(ValueError, match="world_size .* does not match"):
+    comm.trtllm_allreduce_fusion(
+        allreduce_in=torch.empty(1 * hidden_dim, dtype=dtype, device=device),
+        world_size=world_size,
+        world_rank=rank,
+        token_num=1,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=workspace_tensor,
+        launch_with_pdl=True,
+        trigger_completion_at_end=True,
+        fp32_acc=False,
+        pattern_code=comm.AllReduceFusionPattern.kAllReduce,
+        use_oneshot=None,
+        allreduce_out=None,
+        residual_in=None,
+        residual_out=None,
+        norm_out=None,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=None,
+        rms_eps=None,
+        scale_factor=None,
+        layout_code=None,
+        metadata=meta_ws_mismatch,
+    )
+
+# Negative-path validation: hidden_dim mismatch
+meta_hd_mismatch = dict(workspace_metadata)
+meta_hd_mismatch["hidden_dim"] = hidden_dim + 64
+with pytest.raises(ValueError, match="hidden_dim .* does not match"):
+    comm.trtllm_allreduce_fusion(
+        allreduce_in=torch.empty(1 * hidden_dim, dtype=dtype, device=device),
+        world_size=world_size,
+        world_rank=rank,
+        token_num=1,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=workspace_tensor,
+        launch_with_pdl=True,
+        trigger_completion_at_end=True,
+        fp32_acc=False,
+        pattern_code=comm.AllReduceFusionPattern.kAllReduce,
+        use_oneshot=None,
+        allreduce_out=None,
+        residual_in=None,
+        residual_out=None,
+        norm_out=None,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=None,
+        rms_eps=None,
+        scale_factor=None,
+        layout_code=None,
+        metadata=meta_hd_mismatch,
+    )
```
🔇 Additional comments (5)
flashinfer/comm/trtllm_ar.py (2)
893-898: LGTM: oneshot fallback guard.

Warning and disabling oneshot when exceeding MAX_COMM_SIZE is appropriate.

818-819: No breaking changes detected; new param added safely.

The new `metadata` parameter is added as a trailing optional parameter with a default value (None), which is the safe way to extend a function signature. The call at line 249 passes 21 positional arguments that correctly match the first 21 function parameters in order. The new trailing parameter does not affect this or any other call site; all existing calls continue to work without modification. Test calls (lines 166, 196) and internal calls (line 906) use keyword arguments, which are unaffected.

tests/comm/test_trtllm_allreduce_fusion.py (3)

188-189: LGTM: pass metadata during warmup.

This exercises the validation path without affecting CUDA graph capture.

218-219: LGTM: pass metadata during capture.

Safe since validation occurs on the Python side before kernel launch; no graph-safety issues.

310-312: LGTM: updated destroy to fusion variant.

Matches the new creation path and keeps resource lifecycle consistent.
```diff
 dist.barrier(group=group)

-comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group)
+comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
```
@yzh119 note this subtle issue of calling the wrong destroy function. It is not a problem here, as they have the same implementation.
As an alternative solution for this, as well as for the root cause of #1986 (workspace size insufficient), we could implement this as an AR class: e.g., trtllm_create_ipc_workspace_for_all_reduce_fusion would return an AR object. The destructor would clean up (using the correct destroy function :)), and we could add members like "max_token_size", etc., for workspace validation.

This PR addresses this by having trtllm_create_ipc_workspace_for_all_reduce_fusion return a metadata dictionary that makes such checks possible. It's not ideal, but it is non-API-breaking. It's not high priority, but we could consider the proposed alternative solution and deprecate both the custom_all_reduce and all_reduce_fusion variants.
agreed, making it a class would be a more reasonable solution
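For concreteness, a rough sketch of what such a class could look like. The class name, signatures, and the internal reuse of the existing helpers are illustrative, not the library's API:

```python
import flashinfer.comm as comm


class AllReduceFusionWorkspace:
    """Hypothetical wrapper that owns the IPC workspace and its creation parameters."""

    def __init__(self, tp_rank, tp_size, max_token_num, hidden_dim, group=None):
        self.tp_size = tp_size
        self.max_token_num = max_token_num
        self.hidden_dim = hidden_dim
        self._group = group
        # Illustrative: delegate to the existing creation helper.
        self.ipc_handles, self.workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank, tp_size, max_token_num, hidden_dim, group=group
            )
        )

    def validate(self, token_num, world_size, hidden_dim):
        # The checks this PR performs via the metadata dict would live here.
        if token_num > self.max_token_num:
            raise ValueError(f"token_num {token_num} exceeds max_token_num {self.max_token_num}")
        if world_size != self.tp_size:
            raise ValueError(f"world_size {world_size} does not match tp_size {self.tp_size}")
        if hidden_dim != self.hidden_dim:
            raise ValueError(f"hidden_dim {hidden_dim} does not match workspace hidden_dim {self.hidden_dim}")

    def destroy(self):
        # Always pairs with the matching (fusion) destroy function.
        comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
            self.ipc_handles, group=self._group
        )
```

An explicit destroy() rather than relying on `__del__` is deliberate in this sketch: the destroy path involves a collective over the process group, and finalizer ordering at interpreter shutdown is not guaranteed.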
| """ | ||
|
|
||
| # Validate against workspace metadata if provided | ||
| if metadata is not None: |
To avoid similar cases where num_tokens > MAX_TOKEN_NUM, the check is necessary anyway — so why not just return meta regardless?
because it would break the API. We can only do that for major bumps if we want to respect semantic versioning.
I didn't address this, but we should keep it in mind for an upcoming major bump. I added a TODO comment so we don't forget.
Actionable comments posted: 1
♻️ Duplicate comments (3)
flashinfer/comm/trtllm_ar.py (3)
510-513: Consider `@overload` for clearer type signatures.

While the `Union` return type maintains compatibility, static analysis tools and IDEs struggle with variable-length tuple unions. Using `typing.overload` would provide explicit signatures for each case, improving type safety and developer experience.

Example:

```python
@overload
def trtllm_create_ipc_workspace_for_all_reduce_fusion(
    tp_rank: int,
    tp_size: int,
    max_token_num: int,
    hidden_dim: int,
    use_fp32_lamport: bool = False,
    group: Optional[ProcessGroup] = None,
    create_metadata: Literal[False] = False,
) -> Tuple[List[List[int]], torch.Tensor]: ...

@overload
def trtllm_create_ipc_workspace_for_all_reduce_fusion(
    tp_rank: int,
    tp_size: int,
    max_token_num: int,
    hidden_dim: int,
    use_fp32_lamport: bool = False,
    group: Optional[ProcessGroup] = None,
    create_metadata: Literal[True] = ...,
) -> Tuple[List[List[int]], torch.Tensor, dict]: ...
```
628-642: Consider `TypedDict` or `dataclass` for type-safe metadata.

Raw dicts are vulnerable to typos in both creation and consumption sites. Defining a structured type would enable static type checking, IDE autocompletion, and early error detection.

Example with `TypedDict`:

```python
from typing import TypedDict

class WorkspaceMetadata(TypedDict):
    tp_rank: int
    tp_size: int
    max_token_num: int
    hidden_dim: int
    use_fp32_lamport: bool
    buffer_size: int
    flag_size: int
    lamport_comm_size: int
    lamport_buffer_size: int
```
913-917: Major: Size calculation should use lamport dtype from metadata.

When metadata is provided, the required lamport size should be calculated using `metadata["use_fp32_lamport"]` rather than `allreduce_in.dtype`. The workspace's lamport buffer dtype determines safety, not the input tensor dtype. Using the input dtype can underestimate required bytes and incorrectly enable oneshot mode, leading to buffer overruns.

Apply this diff:

```diff
+# Determine element size: use metadata's lamport dtype when available
+if metadata is not None:
+    elem_bytes = 4 if metadata["use_fp32_lamport"] else 2
+else:
+    elem_bytes = 4 if allreduce_in.dtype == torch.float32 else 2
+
 required_lamport_comm_size = (
-    token_num * hidden_dim * 2 * world_size
-    if allreduce_in.dtype != torch.float32
-    else token_num * hidden_dim * 4 * world_size
+    token_num * hidden_dim * elem_bytes * world_size
 )
```
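As a quick sanity check of why the element size matters (numbers are illustrative only, using the formula from the diff above):

```python
# Illustrative: the element size doubles the required lamport communication size.
token_num, hidden_dim, world_size = 128, 4096, 8

for use_fp32_lamport in (False, True):
    elem_bytes = 4 if use_fp32_lamport else 2
    required = token_num * hidden_dim * elem_bytes * world_size
    print(use_fp32_lamport, required)
    # False -> 8388608 bytes (8 MiB); True -> 16777216 bytes (16 MiB)
```

Sizing from the input dtype instead of the workspace's lamport dtype can therefore underestimate the requirement by exactly a factor of two.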
🔇 Additional comments (5)
flashinfer/comm/trtllm_ar.py (5)
22-22: LGTM: Import added for deprecation decorators.

The `typing_extensions.deprecated` import supports the deprecation decorators applied to legacy functions, guiding users toward the fusion API.

125-127: LGTM: Deprecation decorators guide users to fusion API.

The deprecation markers on legacy functions provide a clear migration path to `trtllm_create_ipc_workspace_for_all_reduce_fusion` and `trtllm_allreduce_fusion`.

Also applies to: 400-402, 709-711

522-528: LGTM: Docstrings accurately document the new behavior.

The documentation correctly describes the conditional return values and accurately reflects the lamport buffer size computation based on `max_token_num` without referencing the deprecated `OneShotMaxToken`.

Also applies to: 537-538

828-828: LGTM: Optional metadata parameter maintains backward compatibility.

The default `None` value ensures existing callers continue to work while enabling validation when metadata is explicitly provided.

853-855: LGTM: Documentation clearly explains validation behavior.

The docstring effectively communicates the metadata parameter's purpose and the consequences of validation failures.
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_ar.py (1)
893-898: Critical: Inverted logic in `use_fp32_lamport` validation.

Line 894 uses `==`, which raises an error when metadata and dtype correctly match, allowing mismatches to pass. This is inverted and defeats the validation purpose.

Apply this diff to fix the inverted condition:

```diff
-if metadata["use_fp32_lamport"] == (allreduce_in.dtype == torch.float32):
+if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32):
```
🧹 Nitpick comments (2)
flashinfer/comm/trtllm_ar.py (2)
512-514: Consider `typing.overload` for better type safety (optional).

As noted in a previous review, using `@overload` decorators would provide clearer type signatures for different `create_metadata` values and improve IDE autocompletion. While the current `Union` approach works, overloads make it explicit what callers receive based on the boolean flag.

630-640: Consider `TypedDict` or `dataclass` for metadata (recommended).

As noted in a previous review, defining a structured type for the metadata dict would prevent typos in keys and enable static type checking. This would make the API more robust when metadata is created here and consumed in `trtllm_allreduce_fusion`.
🔇 Additional comments (3)
flashinfer/comm/trtllm_ar.py (3)
22-22: LGTM: Deprecation markers correctly applied.

The deprecation decorators properly guide users toward the fusion APIs and maintain backward compatibility.

Also applies to: 125-127, 400-402, 710-712

862-891: LGTM: Validation logic for keys, token_num, world_size, and hidden_dim is correct.

The checks appropriately detect missing keys, token overflow, and dimension mismatches, providing clear error messages.

829-829: LGTM: Metadata parameter properly added with clear documentation.

The optional parameter maintains API compatibility while enabling validation when workspace metadata is provided.
Also applies to: 854-856
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)
859-905: Comprehensive workspace metadata validation addresses issue #1986.

The validation logic correctly prevents illegal memory access by checking token_num against max_token_num and ensuring all workspace parameters match the runtime configuration. The error accumulation pattern provides excellent user feedback by reporting all mismatches in a single error message.

The critical issue from past reviews regarding inverted logic at line 894 has been fixed; the condition now correctly uses `!=` to detect mismatches between `use_fp32_lamport` and the input dtype.

Optional: Minor Pythonic improvement for key checking.

Lines 863-865 could be slightly more concise using set operations:

```diff
-for key in required_keys:
-    if key not in metadata:
-        errors.append(f"Workspace metadata is missing required key: {key}")
+missing_keys = set(required_keys) - set(metadata.keys())
+if missing_keys:
+    errors.append(f"Workspace metadata is missing required keys: {sorted(missing_keys)}")
```

This is purely stylistic and doesn't affect correctness.
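To make the praised accumulate-then-raise pattern concrete, here is a standalone sketch. Key names follow the metadata dict in this PR; the function boundary and exact message wording are illustrative, not the merged code:

```python
def validate_against_metadata(metadata, token_num, world_size, hidden_dim):
    required_keys = (
        "tp_rank", "tp_size", "max_token_num", "hidden_dim", "use_fp32_lamport",
        "buffer_size", "flag_size", "lamport_comm_size", "lamport_buffer_size",
    )
    errors = []
    # Report missing keys first so the value checks below can safely index.
    missing = set(required_keys) - set(metadata)
    if missing:
        errors.append(f"workspace metadata is missing required keys: {sorted(missing)}")
    else:
        if token_num > metadata["max_token_num"]:
            errors.append(f"token_num {token_num} exceeds max_token_num {metadata['max_token_num']}")
        if world_size != metadata["tp_size"]:
            errors.append(f"world_size {world_size} does not match tp_size {metadata['tp_size']}")
        if hidden_dim != metadata["hidden_dim"]:
            errors.append(f"hidden_dim {hidden_dim} does not match workspace hidden_dim {metadata['hidden_dim']}")
    # Collect everything, then raise once with the full picture.
    if errors:
        raise ValueError("; ".join(errors))
```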
🔇 Additional comments (6)
flashinfer/comm/trtllm_ar.py (6)
22-22: LGTM: Appropriate import for API deprecation.

The `deprecated` decorator import is correctly placed and supports the API migration strategy.

125-127: LGTM: Consistent deprecation strategy.

The deprecation decorators are applied consistently to both the custom op registrations and wrapper functions, with clear migration guidance pointing users to the fusion-based APIs.

Also applies to: 400-402, 710-712

503-514: Optional metadata return maintains backward compatibility.

The `create_metadata` parameter allows callers to opt into receiving workspace metadata for validation purposes without breaking existing code. The TODO comment appropriately flags that this should become the default behavior in a future major version.

Note: A past review suggested using `@overload` for better type safety instead of `Union` of tuples, but the current approach is functional and maintains simplicity.

Also applies to: 523-530

538-539: LGTM: Documentation now matches implementation.

The lamport buffer size formula correctly documents the dtype-dependent sizing (2 bytes for fp16/bf16, 4 bytes for fp32), matching the conditional logic in lines 551-555.

629-643: Conditional metadata return logic is correct.

The implementation correctly returns metadata when requested, including all parameters necessary for downstream validation in `trtllm_allreduce_fusion`.

Note: A past review suggested using `TypedDict` or `dataclass` for the metadata structure to improve type safety and reduce typo-related errors, but the current dict-based approach is functional.

829-829: LGTM: Optional metadata parameter enables workspace validation.

The optional `metadata` parameter allows callers to opt into runtime validation while maintaining backward compatibility with existing code that doesn't provide metadata.

Also applies to: 854-856
## 📌 Description

This PR attempts to fix #1986 (to be confirmed by requester).

The issue is that num_tokens was larger than MAX_TOKEN_NUM, which results in an IMA, or even in a hang. To address this, I added a validation check. This required a non-breaking API change:

* `create_ipc_workspace_for_all_reduce_fusion` now has an optional `create_metadata` bool, which results in an additional return value
  * it is made optional, as an additional return value could break the API
* `trtllm_allreduce_fusion` now takes an optional metadata dictionary
  * when provided, this will run the validation check
  * again, this is optional, to avoid breaking the API

In addition, this PR deprecates the older AllReduce functionality so it can be removed in a major version bump.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **API Changes**
  * Workspace creation can optionally return metadata describing the workspace configuration (`create_metadata` flag).
  * Allreduce fusion operations accept optional metadata to validate runtime parameters against the workspace and raise clear errors on mismatch.
  * A workspace destruction endpoint was renamed for naming consistency.
  * Legacy wrappers were marked deprecated and now point users toward the newer fusion variants.