fix: add a check for int32 indices in sampling.py #2127
yzh119 merged 3 commits into flashinfer-ai:main
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Caller
participant Sampling as Sampling Module
participant Validator as _check_indices_dtype
Caller->>Sampling: call sampling_fn(inputs, ..., indices)
alt indices provided
Sampling->>Validator: _check_indices_dtype(indices)
alt dtype == int32
Validator-->>Sampling: OK
Sampling->>Sampling: proceed with sampling logic
Sampling-->>Caller: return samples
else dtype != int32
Validator-->>Sampling: raise ValueError("indices must be torch.int32")
Sampling-->>Caller: propagate ValueError
end
else no indices
Sampling->>Sampling: proceed with sampling logic
Sampling-->>Caller: return samples
end
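The flow in the diagram above can be sketched in a few lines of Python. This is a minimal illustration, not the code from `sampling.py`: `FakeTensor` and `sampling_fn` are hypothetical stand-ins (so the sketch runs without PyTorch or a GPU), dtypes are modeled as plain strings, and treating a missing `indices` as valid inside the validator is an assumption. Only the name `_check_indices_dtype` and the error text are taken from the PR.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; dtype is a plain string here."""

    dtype: str


def _check_indices_dtype(indices: Optional[FakeTensor]) -> None:
    # Assumption: a missing `indices` argument is valid and skips the check.
    if indices is None:
        return
    if indices.dtype != "torch.int32":
        raise ValueError(
            f"indices must have dtype torch.int32, got {indices.dtype}"
        )


def sampling_fn(probs: List[float], indices: Optional[FakeTensor] = None) -> int:
    """Hypothetical sampling entry point: validate first, then sample."""
    _check_indices_dtype(indices)
    # Placeholder for the real sampling logic: return the argmax position.
    return max(range(len(probs)), key=probs.__getitem__)
```

Validating before any sampling work means a wrong dtype fails loudly at the call site instead of silently producing incorrect samples.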
Code Review
This pull request correctly adds a check to ensure that the `indices` tensor, when provided, has a `torch.int32` dtype. The new validation function `_check_indices_dtype` is properly integrated into the various sampling functions, and the docstrings have been updated accordingly. I have one suggestion to improve the clarity of the newly added test case.
yzh119
left a comment
It should be easy to support int64 indices for these kernels as well, but let's leave that for future PRs.
## 📌 Description

Originally, in PRs #1652 and #2127, we added better error messaging to prevent silent failures for tensors of the wrong dtype or shape. However, that was on the Python side, when we can instead move it to the C++ side (actually the `.cu` side, I guess). We already have various checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Native modules now perform stricter runtime validation of int index inputs and sampling-parameter shapes; some Python-level preflight checks were removed, deferring certain dtype/shape errors to lower-level code.
* **Refactor**
  * Added a centralized, reusable validation helper for optional sampling parameters to unify checks.
* **Tests**
  * Updated tests to expect different error types and more general error-message matching for the adjusted validation behavior.

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
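Because the checks moved from Python into native code, the exception type and exact wording a caller sees may differ from the earlier Python-side `ValueError`. A test can be made tolerant of that by matching a pattern instead of an exact message. The sketch below is illustrative only: `call_and_expect_error` and `fake_native_sampling` are hypothetical names, and the choice of `(ValueError, RuntimeError)` is an assumption about which exception types might surface, not something the PR states.

```python
import re
from typing import Callable, Tuple, Type


def call_and_expect_error(
    fn: Callable[[], object],
    pattern: str,
    exc_types: Tuple[Type[BaseException], ...] = (ValueError, RuntimeError),
) -> BaseException:
    """Run fn and return the raised exception if its message matches pattern."""
    try:
        fn()
    except exc_types as exc:
        if re.search(pattern, str(exc)) is None:
            raise AssertionError(
                f"message {str(exc)!r} does not match {pattern!r}"
            )
        return exc
    raise AssertionError("expected an error, but the call succeeded")


def fake_native_sampling(indices_dtype: str) -> None:
    # Hypothetical stand-in for a native entry point that validates the dtype.
    if indices_dtype != "int32":
        raise RuntimeError(f"Check failed: indices dtype is {indices_dtype}")
```

With pytest, `pytest.raises((ValueError, RuntimeError), match=...)` expresses the same idea more compactly.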
## 📌 Description

Based on this review comment in #2127, we can add support for Int64 indices as well. I decided to do this using `IdType` like it is done in other files.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Test results:

```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ pytest tests/utils/test_sampling.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/raayan/projects/flashinfer
configfile: pytest.ini
collected 1884 items

tests/utils/test_sampling.py .......................................................................................................... [  5%]
....................................................................................................................................... [ 12%]
....................................................................................................................................... [ 19%]
....................s..s..s..........................................................................sss........................sss.... [ 27%]
....................................................................................................................................... [ 34%]
..........................ssss................................ssss................................ssss................................s [ 41%]
sss................................ssss................................ssss................................ssss........................ [ 48%]
........ssss................................ssss................................ssss................................ssss............... [ 55%]
.................ssss................................ssss................................ssss................................ssss...... [ 62%]
..........................ssss................................ssss................................ssss................................s [ 70%]
sss................................ssss................................ssss................................ssss........................ [ 77%]
........ssss................................ssss................................ssss................................ssss............... [ 84%]
.................ssss.................................................................................................................. [ 91%]
........................................................sss............................................................................ [ 98%]
....................... [100%]

================================================ 1764 passed, 120 skipped in 546.33s (0:09:06) ================================================
(flashinfer) raayan@uril-1:~/projects/flashinfer$
```

---------

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
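One way to picture the `IdType` approach from the Python side is a lookup from the index dtype to a kernel variant, with unsupported dtypes rejected up front. The sketch below is an illustration under stated assumptions, not FlashInfer's implementation: dtypes are modeled as strings, and `KERNELS`, `sample_i32`, `sample_i64`, and `dispatch_by_id_type` are hypothetical names.

```python
from typing import Callable, Dict, List


def sample_i32(indices: List[int]) -> str:
    # Hypothetical placeholder for a kernel instantiated with 32-bit indices.
    return f"int32 kernel on {len(indices)} indices"


def sample_i64(indices: List[int]) -> str:
    # Hypothetical placeholder for a kernel instantiated with 64-bit indices.
    return f"int64 kernel on {len(indices)} indices"


# One kernel variant per supported index dtype.
KERNELS: Dict[str, Callable[[List[int]], str]] = {
    "int32": sample_i32,
    "int64": sample_i64,
}


def dispatch_by_id_type(indices: List[int], dtype: str) -> str:
    """Pick the kernel variant for the given index dtype, or fail loudly."""
    try:
        kernel = KERNELS[dtype]
    except KeyError:
        raise ValueError(
            f"indices must have dtype int32 or int64, got {dtype}"
        ) from None
    return kernel(indices)
```

Keeping the supported dtypes in one table means adding a new index type touches a single place, and anything outside the table still produces an explicit error.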
## 📌 Description

New function to validate that the indices type, when provided, is `int32`. To close flashinfer-ai#2115.

There are now two separate functions doing checking in this file. I will move them to the C++ side later when I have some more bandwidth, probably after Thanksgiving. Just a short fix for now. You can close if you'd rather wait for that.

## 🔍 Related Issues

flashinfer-ai#2115

Relevant to the issue. Now running their code:

```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ python test.py
tensor([1, 1, 0, 0], device='cuda:0', dtype=torch.int32)
Traceback (most recent call last):
  File "/home/raayan/projects/flashinfer/test.py", line 15, in <module>
    incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 1031, in top_k_top_p_sampling_from_logits
    _check_indices_dtype(indices)
  File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 487, in _check_indices_dtype
    raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}")
ValueError: indices must have dtype torch.int32, got torch.int64
```

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Improvements**
  * Enforced that indices passed to sampling operations must use int32, adding runtime validation before sampling.
* **Documentation**
  * Clarified docstrings to state the int32 requirement for indices parameters.
* **Tests**
  * Updated and expanded tests to cover the new dtype validation paths and related error cases.

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>