fix: add _check_tensor_params to check correct sampling parameters and dtype validation in decode.py #1652
Summary of Changes
Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the robustness of the sampling functions by introducing a dedicated input validation utility. The primary goal is to prevent unexpected behavior and incorrect results stemming from improperly shaped or dimensioned tensor parameters, thereby improving the overall reliability of the sampling module.
Highlights
- New Input Validation Function: Introduced `_check_tensor_param` in `flashinfer/sampling.py` to validate tensor parameters passed to various sampling functions.
- Enhanced Parameter Checking: The new function ensures that tensor inputs are not scalar, not higher than 1D, and have a shape consistent with the main `probs` or `logits` tensor, preventing incorrect behavior from malformed inputs.
- Improved Robustness: This change addresses a known issue (#1634) where invalid tensor inputs could lead to erroneous outputs in sampling functions, making the sampling module more robust.
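To make the three checks in the highlights concrete, here is a minimal sketch of such a validator. It is not the code from this PR: the name `check_tensor_param`, the `name` argument, and the error messages are illustrative assumptions, and `FakeTensor` is a stand-in that exposes only `.shape` and `.ndim` (both of which `torch.Tensor` also provides) so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in exposing the attributes the check reads."""

    shape: Tuple[int, ...]

    @property
    def ndim(self) -> int:
        return len(self.shape)


def check_tensor_param(param: Any, tensor: Any, name: str = "param") -> None:
    """Raise ValueError if a tensor-valued sampling parameter is malformed.

    `tensor` is the main probs/logits input; its first dimension is the batch.
    """
    # Plain Python scalars (e.g. top_k=5) are not tensors: nothing to validate.
    if not hasattr(param, "ndim"):
        return
    # Reject 0-dim tensors.
    if param.ndim == 0:
        raise ValueError(f"{name} must not be a 0-dim tensor")
    # Reject anything above 1D, such as a (batch, vocab) tensor.
    if param.ndim > 1:
        raise ValueError(f"{name} must be 1D, got {param.ndim}D")
    # A 1D parameter needs exactly one entry per row of probs/logits.
    if param.shape[0] != tensor.shape[0]:
        raise ValueError(
            f"{name} has {param.shape[0]} entries, expected {tensor.shape[0]}"
        )


probs = FakeTensor((2, 4))
check_tensor_param(5, probs, "top_k")                  # Python scalar: accepted
check_tensor_param(FakeTensor((2,)), probs, "top_k")   # one value per row: accepted
```

Under these assumptions, a parameter of shape `(2, 4)`, `()`, or `(3,)` raises `ValueError` for a `probs` tensor of shape `(2, 4)`.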
Code Review
This pull request introduces a new helper function _check_tensor_param to validate tensor parameters in various sampling functions, which is a good step towards improving input validation. My review focuses on enhancing this new function for better error reporting and optimizing one of its call sites to avoid unnecessary computation.
Additionally, while this PR fixes an important validation issue, it would be beneficial to add unit tests that specifically target these new checks. For instance, tests could be added to tests/test_sampling.py that pass invalid tensor shapes to the sampling functions and assert that a ValueError is raised. This would prevent future regressions.
Overall, the changes are beneficial, and with the suggested improvements, the code will be more robust and efficient.
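The kind of regression test suggested above could look roughly like the sketch below. It uses the standard-library `unittest` module and a hypothetical stand-in validator, `check_param_shape`, that works on shape tuples; the real tests in `tests/test_sampling.py` would call the actual sampling functions with real tensors instead.

```python
import unittest
from typing import Tuple


def check_param_shape(param_shape: Tuple[int, ...], batch_size: int) -> None:
    """Hypothetical stand-in for the validator under test."""
    if len(param_shape) != 1 or param_shape[0] != batch_size:
        raise ValueError(f"expected shape ({batch_size},), got {param_shape}")


class TestParamShape(unittest.TestCase):
    def test_valid_shape_passes(self):
        # A 1D parameter with one entry per batch row must be accepted.
        check_param_shape((2,), batch_size=2)

    def test_invalid_shapes_raise(self):
        # 0-dim, 2D, and wrong-length parameters must all be rejected.
        for bad_shape in [(), (2, 4), (3,)]:
            with self.subTest(shape=bad_shape):
                with self.assertRaises(ValueError):
                    check_param_shape(bad_shape, batch_size=2)
```

Saved as a module, this can be run with `python -m unittest`. Using `subTest` keeps each invalid shape reported separately if one of them stops raising.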
Perhaps add a unittest that uses the original issue's repro code with
Yes, good idea, will add. Edit: have added this unit test.
I have added
45f9a5b to a8b6e2d
Co-authored-by: Alexander Zinoviev <azinoviev@tesla.com>
yzh119 left a comment
Hi @raayandhar, thanks for your contribution. I'm good with the changes here.
We can keep improving this by moving the checks to C++ side to reduce python overhead (in later PRs).
The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
"""
for tensor, name in [
Later on we can move them to C++ side to save python overhead.
return (None, x)
def _check_tensor_param(param: Any, tensor: torch.Tensor) -> None:
This could also be C++ side function.
I see, that sounds better. I can try that as well and make a PR later. There are other checks that I was looking at as reference that can also be moved to C++ side beyond the ones here so I will try to include those as well.
Hi @raayandhar it's better to do that after #1641 where we changed a lot on C++ side structure (we will merge it today).
I see, thanks for letting me know.
## 📌 Description

Originally, in #1652 and #2127, we added better error messages and prevented silent failures for tensors with the wrong dtype or shape. However, those checks were on the Python side, when we can instead move them to the C++ side (actually the `.cu` side, I guess). We already have various checks there via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Native modules now perform stricter runtime validation of int index inputs and sampling-parameter shapes; some Python-level preflight checks were removed, deferring certain dtype/shape errors to lower-level code.
* **Refactor**
  * Added a centralized, reusable validation helper for optional sampling parameters to unify checks.
* **Tests**
  * Updated tests to expect different error types and more general error-message matching for the adjusted validation behavior.

Signed-off-by: Raayan Dhar <raayan.dhar@gmail.com>
📌 Description
This adds the `_check_tensor_params` function in `sampling.py` that is used in the different sampling functions to check that the parameters make sense for this function (see related issues). This also adds some `dtype` validation checks in `decode.py` to make sure we don't silently produce wrong results when passing in the wrong `dtype`.

I tried to make just one small function that would cover the cases. But please let me know if there is a case I missed. Looking forward to the feedback.
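As a rough illustration of the `dtype` validation described above, the sketch below loops over `(tensor, name)` pairs and raises on a mismatch. The helper name `check_dtypes`, the tensor names `indptr` and `indices`, and the expected dtype are all assumptions made for the example, not taken from `decode.py`; `SimpleNamespace` objects stand in for tensors so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Iterable, Tuple


def check_dtypes(tensors: Iterable[Tuple[Any, str]], expected: Any) -> None:
    """Raise ValueError for the first tensor whose dtype is not `expected`."""
    for tensor, name in tensors:
        if tensor.dtype != expected:
            raise ValueError(
                f"{name} must have dtype {expected}, got {tensor.dtype}"
            )


# Stand-ins for index tensors; with PyTorch these would be real tensors and
# `expected` would be a torch dtype (assumed here to be a 32-bit integer type).
indptr = SimpleNamespace(dtype="int32")
indices = SimpleNamespace(dtype="int64")

check_dtypes([(indptr, "indptr")], expected="int32")  # accepted
try:
    check_dtypes([(indptr, "indptr"), (indices, "indices")], expected="int32")
except ValueError as err:
    message = str(err)  # names the offending tensor and both dtypes
```

Failing loudly with the tensor's name in the message is what replaces the silent wrong results mentioned in the description.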
🔍 Related Issues
This was originally from issue #1634, which showed that we can pass relatively bogus inputs to the sampling functions and get wrong output. Now, when running the code from that issue, only the first case passes, which seems correct to me. The second case passes a 2D tensor of shape (2, 4) where the docs indicate it should be (2,) (which passes). If it should be the other way around, let me know. The output is the same as well.

For `dtype` validation, this was originally from #1654. I added a check that is very similar to the ones I saw in `mla.py`.

🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Added unit tests `test_check_tensor_params_top_k`, `test_check_tensor_params_top_p`, `test_check_tensor_params_min_p` that cover different functions and three different cases. Tested the `dtype` validation changes locally.

Reviewer Notes