Update Mamba selective_state_scan API signature#2392
Conversation
Signed-off-by: Shahar Mor <smor@nvidia.com>
Summary of Changes

Hello @shaharmor98, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant enhancement to the `selective_state_update` API.
Code Review
This pull request introduces an out parameter to the selective_state_update function, allowing users to provide a pre-allocated output tensor. This is a useful feature for performance-critical applications where memory allocation overhead is a concern. The implementation is straightforward.
My main feedback is:

- The provided `out` tensor should be validated for shape, dtype, device, and contiguity to prevent cryptic runtime errors in the CUDA kernel. I've added a specific comment with a code suggestion for this.
- The changes are not covered by tests. Please add a new test case that utilizes the `out` parameter to ensure it functions as expected. This should verify that the result is correctly written to the provided tensor and that the function returns the same tensor object.
```python
if out is None:
    output = torch.empty_like(x)
else:
    output = out
```
While the C++ backend performs some checks on the output tensor, it's good practice to add validation in the Python API for user-provided tensors. This provides clearer error messages and fails earlier if an incompatible tensor is passed.
Please add checks for the `out` tensor's shape, dtype, device, and contiguity to ensure it's compatible with what the kernel expects. The `out` tensor should match the properties of the input tensor `x` after it has been normalized (unsqueezed).
```python
if out is None:
    output = torch.empty_like(x)
else:
    if out.shape != x.shape:
        raise ValueError(
            f"out.shape {out.shape} must be equal to x.shape {x.shape}"
        )
    if out.dtype != x.dtype:
        raise ValueError(
            f"out.dtype {out.dtype} must be equal to x.dtype {x.dtype}"
        )
    if out.device != x.device:
        raise ValueError(
            f"out.device {out.device} must be equal to x.device {x.device}"
        )
    if out.stride(-1) != 1:
        raise ValueError("The last dimension of out must be contiguous")
    output = out
```
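For illustration, the allocate-or-validate pattern suggested above can be sketched framework-free. `scan_update` and the doubling math are hypothetical stand-ins for the real kernel wrapper, shown only to make the fail-early behavior and the same-object return contract concrete:

```python
def scan_update(x, out=None):
    """Pure-Python stand-in for a kernel wrapper that accepts an
    optional pre-allocated output buffer (hypothetical, not the real API)."""
    if out is None:
        out = [0.0] * len(x)              # fallback: allocate on the fly
    elif len(out) != len(x):              # fail early with a clear message
        raise ValueError(f"out length {len(out)} must equal x length {len(x)}")
    for i, v in enumerate(x):
        out[i] = 2.0 * v                  # placeholder for the real scan math
    return out                            # same object the caller passed in

buf = [0.0, 0.0, 0.0]
result = scan_update([1.0, 2.0, 3.0], out=buf)
assert result is buf                      # caller's buffer reused, not copied

try:
    scan_update([1.0], out=[0.0, 0.0])    # mismatched buffer rejected early
except ValueError as e:
    print(e)                              # → out length 2 must equal x length 1
```

A test along these lines (checking both the written values and object identity) is exactly what the review asks for against the real API.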
yzh119 left a comment:
LGTM, added some unittests to this PR.
/bot run

LGTM as well.
📌 Description
The new Mamba `selective_state_update` kernel doesn't allow the user to provide the output tensor to write results to. In some use cases, for example in the SGLang and vLLM implementations, the output tensor is allocated before the kernel call. Because of that, we need to allow the user to provide their own output tensor.

If the user doesn't provide an output tensor, we fall back to allocating one on the fly.
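As a sketch of the motivating use case (names are hypothetical; the real kernel signature differs), a caller can allocate the output once and reuse it across decode steps instead of paying an allocation per call:

```python
import numpy as np

def step_update(x, out=None):
    # Hypothetical decode-step wrapper: writes into `out` when given,
    # otherwise falls back to allocating a fresh buffer.
    if out is None:
        out = np.empty_like(x)
    np.multiply(x, 2.0, out=out)   # placeholder for the real state update
    return out

x = np.array([1.0, 2.0, 3.0])
buf = np.empty_like(x)             # allocated once, before the decode loop
for _ in range(4):                 # e.g. four decode steps, zero per-step allocs
    y = step_update(x, out=buf)
assert y is buf                    # results land in the caller's buffer
```

This is the pattern SGLang- and vLLM-style serving loops rely on: the serving framework owns the buffer, and the kernel only writes into it.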
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (e.g., `unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests