
Update Mamba selective_state_scan API signature#2392

Merged
yzh119 merged 3 commits into flashinfer-ai:main from shaharmor98:fix/allow-user-provided-out-tensor-mamba-kernel
Jan 22, 2026

Conversation

@shaharmor98
Contributor

@shaharmor98 shaharmor98 commented Jan 21, 2026

📌 Description

The new Mamba selective_state_update kernel doesn't allow the user to provide the output tensor to write results to.
In some use cases, such as the SGLang and vLLM integrations, the output tensor is allocated before the kernel call, so the API needs to accept a user-provided output tensor.
If the user doesn't provide an output tensor, we fall back to allocating one on the fly.
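The allocate-or-reuse behavior described above can be sketched in plain Python (a stand-in doubler replaces the real kernel, and a list stands in for a torch.Tensor, since this sketch doesn't assume flashinfer or torch is installed; all names here are illustrative):

```python
# Minimal sketch of the out-parameter pattern this PR adds,
# using a plain Python list as a stand-in for a torch.Tensor.
def selective_scan_sketch(x, out=None):
    # Fall back to allocating an output when the caller gives none,
    # mirroring the torch.empty_like(x) fallback in the real API.
    output = out if out is not None else [0.0] * len(x)
    for i, v in enumerate(x):
        output[i] = 2.0 * v  # stand-in for the actual kernel computation
    return output

# Caller-managed buffer, as in the SGLang/vLLM use case described above:
buf = [0.0, 0.0, 0.0]
result = selective_scan_sketch([1.0, 2.0, 3.0], out=buf)
assert result is buf            # results land in the caller's buffer
assert buf == [2.0, 4.0, 6.0]

# No buffer supplied: the function allocates one itself.
assert selective_scan_sketch([1.0]) == [2.0]
```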

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added an optional output-tensor argument to selective state update operations, allowing callers to provide a pre-allocated tensor for reduced allocations and controlled memory reuse.
  • Tests

    • Extended tests to cover usage of the optional output-tensor, verifying returned object identity when a tensor is supplied and preserving numerical behavior.

✏️ Tip: You can customize this high-level summary in your review settings.

Signed-off-by: Shahar Mor <smor@nvidia.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @shaharmor98, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the selective_state_update API within FlashInfer's Mamba implementation. By adding an optional out parameter, the function now allows callers to supply their own pre-allocated output tensor. This modification provides greater flexibility for integrating with systems like SGLang and vLLM, which often manage tensor allocation externally, thereby optimizing memory usage and potentially improving performance by avoiding unnecessary reallocations.

Highlights

  • API Enhancement: The selective_state_update function now accepts an optional out parameter, allowing users to provide a pre-allocated output tensor.
  • Memory Management: If the out parameter is not provided, the function will fall back to allocating a new tensor, maintaining backward compatibility.
  • Use Case Support: This change supports specific use cases in frameworks like SGLang and vLLM where output tensors are often pre-allocated.


@coderabbitai
Contributor

coderabbitai Bot commented Jan 21, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

The selective_state_update function in flashinfer.mamba gained an optional out: torch.Tensor | None parameter; when provided the function writes into that tensor, otherwise it allocates a new output via torch.empty_like(x). Internal call to _selective_state_update is updated to receive the chosen output tensor and docstring/tests updated.

Changes

  • Core implementation (flashinfer/mamba/selective_state_update.py): Added optional out: torch.Tensor | None = None parameter to selective_state_update; choose the provided out or allocate via torch.empty_like(x); pass the chosen output into the internal _selective_state_update; updated the docstring.
  • Tests (tests/mamba/test_selective_state_update.py): Added use_out_tensor parameterization to tests; conditionally allocate and pass an out tensor, assert returned-tensor identity when out is supplied, and retain the numerical assertions.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested reviewers

  • cyx-6
  • jimmyzho

Poem

🐰 A tiny hop, a tensor new,
Pass it in or let it brew.
Output found where I decree,
Pre-fill or leave it free.
Mamba whispers: "Use me, please!" ✨

🚥 Pre-merge checks: 2 passed, 1 inconclusive

❓ Inconclusive (1)
  • Description check: The description includes the required Description section explaining the motivation and implementation details, but it is missing the Related Issues section and has not marked all Test checklist items as complete. Resolution: link any related issues in the Related Issues section and indicate whether tests have been added/updated and are passing by checking the corresponding boxes.

✅ Passed (2)
  • Title check: The title clearly describes the main change: updating the API signature of the Mamba selective_state_scan function to accept an optional output tensor parameter.
  • Docstring coverage: Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces an out parameter to the selective_state_update function, allowing users to provide a pre-allocated output tensor. This is a useful feature for performance-critical applications where memory allocation overhead is a concern. The implementation is straightforward.

My main feedback is:

  • The provided out tensor should be validated for shape, dtype, device, and contiguity to prevent cryptic runtime errors in the CUDA kernel. I've added a specific comment with a code suggestion for this.
  • The changes are not covered by tests. Please add a new test case that utilizes the out parameter to ensure it functions as expected. This should verify that the result is correctly written to the provided tensor and that the function returns the same tensor object.
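The test being requested could be sketched as follows (a pure-Python stand-in replaces selective_state_update so the sketch stays self-contained without flashinfer or pytest installed; the stub's arithmetic and all names are hypothetical):

```python
# Hedged sketch of the requested test: verify the result is written
# into the provided tensor and that the same object is returned.
def selective_update_stub(x, out=None):
    # Stand-in for selective_state_update; a list replaces a torch.Tensor.
    output = out if out is not None else [0.0] * len(x)
    for i, v in enumerate(x):
        output[i] = v + 1.0  # placeholder computation
    return output

# Mimic the use_out_tensor parameterization the PR's tests add.
for use_out_tensor in (False, True):
    x = [1.0, 2.0]
    out = [0.0, 0.0] if use_out_tensor else None
    result = selective_update_stub(x, out=out)
    assert result == [2.0, 3.0]      # numerical behavior preserved
    if use_out_tensor:
        assert result is out         # returned object identity
```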

Comment on lines +121 to +124

    if out is None:
        output = torch.empty_like(x)
    else:
        output = out

Severity: high

While the C++ backend performs some checks on the output tensor, it's good practice to add validation in the Python API for user-provided tensors. This provides clearer error messages and fails earlier if an incompatible tensor is passed.

Please add checks for the out tensor's shape, dtype, device, and contiguity to ensure it's compatible with what the kernel expects. The out tensor should match the properties of the input tensor x after it has been normalized (unsqueezed).

    if out is None:
        output = torch.empty_like(x)
    else:
        if out.shape != x.shape:
            raise ValueError(
                f"out.shape {out.shape} must be equal to x.shape {x.shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out.dtype {out.dtype} must be equal to x.dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError(
                f"out.device {out.device} must be equal to x.device {x.device}"
            )
        if out.stride(-1) != 1:
            raise ValueError("The last dimension of out must be contiguous")
        output = out
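To illustrate the fail-fast intent of the suggestion above without assuming torch is installed, here is a pure-Python stand-in for two of the checks (FakeTensor and all names are hypothetical; it only imitates the attributes the checks read):

```python
# Hypothetical stand-in mirroring the shape/dtype validation suggested above.
class FakeTensor:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

def check_out(out, x):
    # Raise a clear error before the kernel ever sees an incompatible tensor.
    if out.shape != x.shape:
        raise ValueError(f"out.shape {out.shape} must be equal to x.shape {x.shape}")
    if out.dtype != x.dtype:
        raise ValueError(f"out.dtype {out.dtype} must be equal to x.dtype {x.dtype}")

x = FakeTensor((2, 4), "float16")
check_out(FakeTensor((2, 4), "float16"), x)      # compatible: passes silently
try:
    check_out(FakeTensor((2, 4), "float32"), x)  # dtype mismatch: raises early
    raised = False
except ValueError:
    raised = True
assert raised
```

This way an incompatible out tensor produces a readable Python exception instead of a cryptic failure inside the CUDA kernel.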

Collaborator

@yzh119 yzh119 left a comment


LGTM, added some unittests to this PR.

@yzh119
Collaborator

yzh119 commented Jan 22, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !256 has been created, and the CI pipeline #42252945 is currently running. I'll report back once the pipeline job completes.

@ishovkun
Contributor

LGTM as well.

@yzh119 yzh119 merged commit b464fbd into flashinfer-ai:main Jan 22, 2026
21 checks passed
@claude claude Bot mentioned this pull request Jan 23, 2026
5 tasks

4 participants