fix: add _check_tensor_params to check correct sampling parameters and dtype validation in decode.py #1652

Merged
yzh119 merged 6 commits into flashinfer-ai:main from raayandhar:check_tensor_params
Sep 26, 2025

@raayandhar
Contributor

@raayandhar raayandhar commented Sep 8, 2025

📌 Description

This adds the _check_tensor_param function in sampling.py, which the sampling functions use to check that their parameters make sense (see related issues). It also adds dtype validation checks in decode.py so that passing the wrong dtype no longer silently produces wrong results.

I tried to make just one small function that would cover the cases. But please let me know if there is a case I missed. Looking forward to the feedback.

🔍 Related Issues

This originated from issue #1634, which showed that fairly bogus inputs can be passed to the sampling functions and still produce (wrong) output. Now, when running the issue's repro code, only the first case passes, which seems correct to me. The second case passes a 2D tensor of shape (2, 4), whereas, as I read the docs, it should have shape (2,) (which passes). If it should be the other way around, let me know. The output is the same as well.

The dtype validation originated from #1654. I added a check that is very similar to the ones I saw in mla.py.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Added unit tests test_check_tensor_params_top_k, test_check_tensor_params_top_p, test_check_tensor_params_min_p that cover different functions and three different cases. Tested the dtype validation changes locally.

Reviewer Notes

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the robustness of the sampling functions by introducing a dedicated input validation utility. The primary goal is to prevent unexpected behavior and incorrect results stemming from improperly shaped or dimensioned tensor parameters, thereby improving the overall reliability of the sampling module.

Highlights

  • New Input Validation Function: Introduced _check_tensor_param in flashinfer/sampling.py to validate tensor parameters passed to various sampling functions.
  • Enhanced Parameter Checking: The new function ensures that tensor inputs are not scalar, not higher than 1D, and have a shape consistent with the main probs or logits tensor, preventing incorrect behavior from malformed inputs.
  • Improved Robustness: This change addresses a known issue (#1634) where invalid tensor inputs could lead to erroneous outputs in sampling functions, making the sampling module more robust.
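
As a rough illustration of the three checks listed above (not a 0-dim tensor, not more than 1D, length matching the batch dimension), here is a minimal sketch. It is not the code merged in this PR: the name check_tensor_param_sketch, the error messages, the use of ValueError, and the assumption that plain Python scalars are accepted unchanged are all illustrative. To keep it runnable without PyTorch, it only reads a `.shape` attribute and uses SimpleNamespace objects as stand-ins for tensors.

```python
from types import SimpleNamespace
from typing import Any


def check_tensor_param_sketch(param: Any, tensor: Any) -> None:
    """Illustrative only: validate a per-request sampling parameter.

    `tensor` stands for the 2D probs/logits input of shape
    (batch_size, vocab_size); `param` is e.g. top_k, top_p or min_p.
    """
    shape = getattr(param, "shape", None)
    if shape is None:
        # Assumption: a plain Python scalar applies to the whole batch.
        return
    shape = tuple(shape)
    batch_size = tuple(tensor.shape)[0]
    if len(shape) == 0:
        raise ValueError("expected a Python scalar or a 1D tensor, got a 0-dim tensor")
    if len(shape) > 1:
        raise ValueError(f"expected a 1D tensor, got shape {shape}")
    if shape[0] != batch_size:
        raise ValueError(f"expected shape ({batch_size},), got {shape}")


# Stand-ins that carry only a `.shape`, so the sketch runs without PyTorch.
probs = SimpleNamespace(shape=(2, 4))
check_tensor_param_sketch(0.9, probs)                          # scalar: accepted
check_tensor_param_sketch(SimpleNamespace(shape=(2,)), probs)  # one value per row: accepted
```

Under these assumptions, a parameter of shape (2, 4), of shape (3,), or a 0-dim tensor would each raise ValueError for a probs tensor of shape (2, 4).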

@raayandhar raayandhar changed the title from "fix: add _check_tensor_params to check correct types" to "fix: add _check_tensor_params to check correct parameters" on Sep 8, 2025
@raayandhar raayandhar changed the title from "fix: add _check_tensor_params to check correct parameters" to "fix: add _check_tensor_params to check correct sampling parameters" on Sep 8, 2025
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a new helper function _check_tensor_param to validate tensor parameters in various sampling functions, which is a good step towards improving input validation. My review focuses on enhancing this new function for better error reporting and optimizing one of its call sites to avoid unnecessary computation.

Additionally, while this PR fixes an important validation issue, it would be beneficial to add unit tests that specifically target these new checks. For instance, tests could be added to tests/test_sampling.py that pass invalid tensor shapes to the sampling functions and assert that a ValueError is raised. This would prevent future regressions.

Overall, the changes are beneficial, and with the suggested improvements, the code will be more robust and efficient.

Comment thread flashinfer/sampling.py Outdated
Comment thread flashinfer/sampling.py Outdated
@netanel-haber
Contributor

Perhaps add a unit test that uses the original issue's repro code with self.assertRaises?

@raayandhar
Contributor Author

raayandhar commented Sep 8, 2025

> Perhaps add a unit test that uses the original issue's repro code with self.assertRaises?

Yes, good idea, will add.

Edit: have added this unit test.
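
A test in the spirit of that suggestion might look like the sketch below. It does not reproduce the tests added in this PR or the repro code from the issue: sample_stub is a hypothetical stand-in for a sampling function that validates top_k before doing any work, and SimpleNamespace objects stand in for tensors so the example runs without PyTorch or a GPU.

```python
import unittest
from types import SimpleNamespace


def sample_stub(probs, top_k):
    """Hypothetical stand-in for a sampling function that validates `top_k` first."""
    shape = getattr(top_k, "shape", None)
    if shape is not None and tuple(shape) != (probs.shape[0],):
        raise ValueError(
            f"top_k must have shape ({probs.shape[0]},), got {tuple(shape)}"
        )
    return [0] * probs.shape[0]  # placeholder: one sampled id per row


class TestTensorParamValidation(unittest.TestCase):
    def setUp(self):
        # Stand-in for a probs tensor with batch size 2 and vocab size 4.
        self.probs = SimpleNamespace(shape=(2, 4))

    def test_1d_param_is_accepted(self):
        out = sample_stub(self.probs, SimpleNamespace(shape=(2,)))
        self.assertEqual(len(out), 2)

    def test_2d_param_raises(self):
        # Mirrors the (2, 4) case discussed in the PR description.
        with self.assertRaises(ValueError):
            sample_stub(self.probs, SimpleNamespace(shape=(2, 4)))
```

Saved to a file, this can be run with `python -m unittest <file>`.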

Comment thread tests/test_sampling.py Outdated
@raayandhar raayandhar changed the title from "fix: add _check_tensor_params to check correct sampling parameters" to "fix: add _check_tensor_params to check correct sampling parameters and dtype validation in decode.py" on Sep 13, 2025
@raayandhar
Contributor Author

I have added dtype validation in decode.py, consistent with how it's done in mla.py. I included this validation not only for last_page_len, as the issue poster describes, but for indptr and indices as well. Please let me know if that is not correct.
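
To make the shape of such a check concrete, here is a minimal sketch of validating the three index tensors in one loop over (tensor, name) pairs. It is not the code in decode.py or mla.py: the helper name check_index_dtypes, the error message, the use of ValueError, and the expected dtype are assumptions (with PyTorch the expected value would presumably be something like torch.int32). The stand-ins carry only a `.dtype` so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Iterable, Tuple


def check_index_dtypes(
    named_tensors: Iterable[Tuple[Any, str]], expected_dtype: Any
) -> None:
    """Illustrative only: raise if any index tensor has an unexpected dtype."""
    for tensor, name in named_tensors:
        if tensor.dtype != expected_dtype:
            raise ValueError(
                f"{name} must have dtype {expected_dtype}, got {tensor.dtype}"
            )


# Stand-ins with only a `.dtype`; real inputs would be tensors.
indptr = SimpleNamespace(dtype="int32")
indices = SimpleNamespace(dtype="int32")
last_page_len = SimpleNamespace(dtype="int64")  # deliberately wrong

message = None
try:
    check_index_dtypes(
        [(indptr, "indptr"), (indices, "indices"), (last_page_len, "last_page_len")],
        expected_dtype="int32",
    )
except ValueError as err:
    message = str(err)  # names the offending argument instead of failing silently
```

Failing early with the argument name in the message is the point of the check: a wrong dtype is reported at the call site rather than surfacing later as incorrect output.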

@raayandhar raayandhar requested a review from yzh119 September 16, 2025 19:37
jimmyzho pushed a commit to jimmyzho/flashinfer that referenced this pull request Sep 25, 2025
Co-authored-by: Alexander Zinoviev <azinoviev@tesla.com>
Collaborator

@yzh119 yzh119 left a comment

Hi @raayandhar, thanks for your contribution. I'm good with the changes here.

We can keep improving this by moving the checks to the C++ side to reduce Python overhead (in later PRs).

Comment thread flashinfer/decode.py

The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
"""
for tensor, name in [

Later on we can move them to the C++ side to save Python overhead.

Comment thread flashinfer/sampling.py
return (None, x)


def _check_tensor_param(param: Any, tensor: torch.Tensor) -> None:

This could also be a C++-side function.

@yzh119 yzh119 merged commit aca1e41 into flashinfer-ai:main Sep 26, 2025
2 checks passed
@raayandhar
Contributor Author

raayandhar commented Sep 26, 2025

> Hi @raayandhar, thanks for your contribution. I'm good with the changes here.
>
> We can keep improving this by moving the checks to the C++ side to reduce Python overhead (in later PRs).

I see, that sounds better. I can try that as well and make a PR later. Beyond the checks here, there are others I was using as reference that can also be moved to the C++ side, so I will try to include those as well.

@yzh119
Collaborator

yzh119 commented Sep 26, 2025

Hi @raayandhar, it's better to do that after #1641, where we changed a lot of the C++ side structure (we will merge it today).

@raayandhar
Contributor Author

> Hi @raayandhar, it's better to do that after #1641, where we changed a lot of the C++ side structure (we will merge it today).

I see, thanks for letting me know.

@raayandhar raayandhar mentioned this pull request Dec 2, 2025
5 tasks
yzh119 pushed a commit that referenced this pull request Dec 3, 2025
## 📌 Description

Originally, in PRs #1652 and #2127, we added better error messages to
prevent silent failures when tensors have the wrong dtype or shape.
However, those checks lived on the Python side, and they can instead
move to the C++ side (actually the `.cu` side, I guess). We already
have various checks there via macros, so this is fairly natural.

## 🔍 Related Issues

See the issues in the PRs above.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Bug Fixes**
  * Native modules now perform stricter runtime validation of int index
    inputs and sampling-parameter shapes; some Python-level preflight checks
    were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
  * Added a centralized, reusable validation helper for optional sampling
    parameters to unify checks.

* **Tests**
  * Updated tests to expect different error types and more general
    error-message matching for the adjusted validation behavior.

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
juju812 pushed a commit to juju812/flashinfer that referenced this pull request Dec 4, 2025
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026