Skip to content

fix: add a check for int32 indices in sampling.py#2127

Merged
yzh119 merged 3 commits intoflashinfer-ai:mainfrom
raayandhar:user/rdhar/int32_sampling_indices
Nov 22, 2025
Merged

fix: add a check for int32 indices in sampling.py#2127
yzh119 merged 3 commits intoflashinfer-ai:mainfrom
raayandhar:user/rdhar/int32_sampling_indices

Conversation

@raayandhar
Copy link
Copy Markdown
Contributor

@raayandhar raayandhar commented Nov 21, 2025

📌 Description

New function to validate that the indices type, when provided, is int32. To close #2115.
There are now two separate functions doing checking in this file. I will move them to the C++ side later when I have some more bandwidth, probably after Thanksgiving. Just a short fix for now. You can close if you'd rather wait for that.

🔍 Related Issues

#2115

Relevant to the issue. Now running their code:

(flashinfer) raayan@uril-1:~/projects/flashinfer$ python test.py 
tensor([1, 1, 0, 0], device='cuda:0', dtype=torch.int32)
Traceback (most recent call last):
  File "/home/raayan/projects/flashinfer/test.py", line 15, in <module>
    incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 1031, in top_k_top_p_sampling_from_logits
    _check_indices_dtype(indices)
  File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 487, in _check_indices_dtype
    raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}")
ValueError: indices must have dtype torch.int32, got torch.int64

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Improvements

    • Enforced that indices passed to sampling operations must use int32, adding runtime validation before sampling.
  • Documentation

    • Clarified docstrings to state the int32 requirement for indices parameters.
  • Tests

    • Updated and expanded tests to cover the new dtype validation paths and related error cases.

✏️ Tip: You can customize this high-level summary in your review settings.

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial type validation for indices tensors within the sampling module, ensuring they are consistently torch.int32. This change directly resolves an existing issue where int64 indices could cause unexpected failures, thereby enhancing the robustness and reliability of the sampling operations across the codebase.

Highlights

  • Indices Type Validation: A new utility function, _check_indices_dtype, has been introduced to validate that indices tensors passed to sampling functions are of torch.int32 dtype. This addresses an issue where int64 indices could lead to runtime errors.
  • Integration of Type Check: The _check_indices_dtype function is now called within all relevant sampling functions (sampling_from_logits, sampling_from_probs, top_p_sampling_from_probs, top_k_sampling_from_probs, min_p_sampling_from_probs, top_k_top_p_sampling_from_logits, top_k_top_p_sampling_from_probs) to enforce the int32 requirement for indices.
  • Documentation Update: The docstrings for the indices parameter in several sampling functions have been updated to explicitly state the torch.int32 dtype requirement, improving clarity for developers.
  • New Test Case: A new test case has been added to test_sampling.py to specifically validate that a ValueError is raised when indices with torch.int64 dtype are provided to min_p_sampling_from_probs, ensuring the new validation works as expected.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Nov 21, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Added an internal _check_indices_dtype(indices) validator enforcing torch.int32 for any provided indices tensor and invoked it at the start of multiple sampling entrypoints; docstrings were updated to state the dtype requirement and tests were adjusted to cover non-int32 rejection.

Changes

Cohort / File(s) Summary
Sampling validation
flashinfer/sampling.py
Added _check_indices_dtype(indices) to validate that indices tensors are torch.int32. Inserted calls to this validator at the start of sampling functions handling logits, probs, top_p, top_k, min_p, and combined top_k/top_p variants. Updated docstrings to document the torch.int32 requirement.
Tests — sampling utils
tests/utils/test_sampling.py
Renamed test_check_tensor_param_min_ptest_tensor_validation_min_p. Updated tests to pass float32 for min_p indicator where applicable and added assertions that non-int32 indices raise a ValueError about dtype; preserved batch-size mismatch checks.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Sampling as Sampling Module
  participant Validator as _check_indices_dtype
  Caller->>Sampling: call sampling_fn(inputs, ..., indices)
  alt indices provided
    Sampling->>Validator: _check_indices_dtype(indices)
    alt dtype == int32
      Validator-->>Sampling: OK
      Sampling->>Sampling: proceed with sampling logic
      Sampling-->>Caller: return samples
    else dtype != int32
      Validator-->>Sampling: raise ValueError("indices must be torch.int32")
      Sampling-->>Caller: propagate ValueError
    end
  else no indices
    Sampling->>Sampling: proceed with sampling logic
    Sampling-->>Caller: return samples
  end
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Pay attention to the exact error message text in _check_indices_dtype() for clarity and test alignment.
  • Verify every sampling entrypoint that accepts indices calls the validator.
  • Confirm docstring updates match function signatures and test expectations.

Poem

🐰 I nibble bugs beneath the code's glow,
I check each index — int32, steady and slow.
No long sneaks past where the samples hop,
I guard the paths from quiet flop.
Hooray — clean dtype, onward we go! 🥕

Pre-merge checks and finishing touches

✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: adding validation for int32 dtype on indices in sampling.py functions.
Description check ✅ Passed The description includes motivation (closing issue #2115), related issue link, and working example demonstrating the fix. However, it lacks explicit mention of the test additions beyond a checkbox.
Linked Issues check ✅ Passed The PR directly addresses issue #2115 by implementing validation to reject indices with non-int32 dtypes, which prevents silent failures when long/int64 indices are used.
Out of Scope Changes check ✅ Passed All changes are directly scoped to the objective: adding _check_indices_dtype validation function, applying it across sampling functions, updating docstrings, and adding corresponding tests.
Docstring Coverage ✅ Passed Docstring coverage is 81.82% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71e48ea and 214a7a9.

📒 Files selected for processing (1)
  • tests/utils/test_sampling.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
tests/utils/test_sampling.py (1)

575-602: Revamped min-p tensor validation test looks correct and aligned with API semantics

Renaming to test_tensor_validation_min_p and switching the 2D and 0D min_p tensors to dtype=torch.float32 matches the intended “probability” semantics and keeps the error-path checks (rank and batch-size handling) intact. The scalar/1D success cases and the 2D/0D failure cases together give good coverage of the Python-side validation for min_p_sampling_from_probs.

Tip

📝 Customizable high-level summaries are now available in beta!

You can now customize how CodeRabbit generates the high-level summary in your pull requests — including its content, structure, tone, and formatting.

  • Provide your own instructions using the high_level_summary_instructions setting.
  • Format the summary however you like (bullet lists, tables, multi-section layouts, contributor stats, etc.).
  • Use high_level_summary_in_walkthrough to move the summary from the description to the walkthrough section.

Example instruction:

"Divide the high-level summary into five sections:

  1. 📝 Description — Summarize the main change in 50–60 words, explaining what was done.
  2. 📓 References — List relevant issues, discussions, documentation, or related PRs.
  3. 📦 Dependencies & Requirements — Mention any new/updated dependencies, environment variable changes, or configuration updates.
  4. 📊 Contributor Summary — Include a Markdown table showing contributions:
    | Contributor | Lines Added | Lines Removed | Files Changed |
  5. ✔️ Additional Notes — Add any extra reviewer context.
    Keep each section concise (under 200 words) and use bullet or numbered lists for clarity."

Note: This feature is currently in beta for Pro-tier users, and pricing will be announced later.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly adds a check to ensure that the indices tensor, when provided, has a torch.int32 dtype. The new validation function _check_indices_dtype is properly integrated into the various sampling functions, and the docstrings have been updated accordingly. I have one suggestion to improve the clarity of the newly added test case.

Comment thread tests/utils/test_sampling.py Outdated
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be easy to support int64 indices for these kernels as well, but let's left them for future PRs.

@yzh119 yzh119 merged commit 5e11004 into flashinfer-ai:main Nov 22, 2025
4 checks passed
@raayandhar raayandhar mentioned this pull request Dec 2, 2025
5 tasks
yzh119 pushed a commit that referenced this pull request Dec 3, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Originally in this PR
#1652 and
#2127 we added better
error messaging / prevent silent failures for wrong dtype / shape of
tensors. However, it was on the python side when we can instead move to
the C++ side (actually the `.cu` side, I guess). We already have various
checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Native modules now perform stricter runtime validation of int index
inputs and sampling-parameter shapes; some Python-level preflight checks
were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
* Added a centralized, reusable validation helper for optional sampling
parameters to unify checks.

* **Tests**
* Updated tests to expect different error types and more general
error-message matching for the adjusted validation behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
juju812 pushed a commit to juju812/flashinfer that referenced this pull request Dec 4, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Originally in this PR
flashinfer-ai#1652 and
flashinfer-ai#2127 we added better
error messaging / prevent silent failures for wrong dtype / shape of
tensors. However, it was on the python side when we can instead move to
the C++ side (actually the `.cu` side, I guess). We already have various
checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Native modules now perform stricter runtime validation of int index
inputs and sampling-parameter shapes; some Python-level preflight checks
were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
* Added a centralized, reusable validation helper for optional sampling
parameters to unify checks.

* **Tests**
* Updated tests to expect different error types and more general
error-message matching for the adjusted validation behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
yzh119 pushed a commit that referenced this pull request Jan 3, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Based on this
[comment](#2127 (review))
in #2127, we can add
support for Int64 indices as well. I decided to do this using `IdType`
like it is done in other files.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Test results:

```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ pytest tests/utils/test_sampling.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/raayan/projects/flashinfer
configfile: pytest.ini
collected 1884 items

tests/utils/test_sampling.py .......................................................................................................... [  5%]
....................................................................................................................................... [ 12%]
....................................................................................................................................... [ 19%]
....................s..s..s..........................................................................sss........................sss.... [ 27%]
....................................................................................................................................... [ 34%]
..........................ssss................................ssss................................ssss................................s [ 41%]
sss................................ssss................................ssss................................ssss........................ [ 48%]
........ssss................................ssss................................ssss................................ssss............... [ 55%]
.................ssss................................ssss................................ssss................................ssss...... [ 62%]
..........................ssss................................ssss................................ssss................................s [ 70%]
sss................................ssss................................ssss................................ssss........................ [ 77%]
........ssss................................ssss................................ssss................................ssss............... [ 84%]
.................ssss.................................................................................................................. [ 91%]
........................................................sss............................................................................ [ 98%]
.......................                                                                                                                 [100%]

================================================ 1764 passed, 120 skipped in 546.33s (0:09:06) ================================================
(flashinfer) raayan@uril-1:~/projects/flashinfer$
```


## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

---------

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

New function to validate that the indices type, when provided, is
`int32`. To close
flashinfer-ai#2115.
There are now two separate functions doing checking in this file. I will
move them to the C++ side later when I have some more bandwidth,
probably after Thanksgiving. Just a short fix for now. You can close if
you'd rather wait for that.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

flashinfer-ai#2115
<!-- Link any related issues here -->

Relevant to the issue. Now running their code:
```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ python test.py 
tensor([1, 1, 0, 0], device='cuda:0', dtype=torch.int32)
Traceback (most recent call last):
  File "/home/raayan/projects/flashinfer/test.py", line 15, in <module>
    incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 1031, in top_k_top_p_sampling_from_logits
    _check_indices_dtype(indices)
  File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 487, in _check_indices_dtype
    raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}")
ValueError: indices must have dtype torch.int32, got torch.int64
```

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Improvements**
* Enforced that indices passed to sampling operations must use int32,
adding runtime validation before sampling.

* **Documentation**
* Clarified docstrings to state the int32 requirement for indices
parameters.

* **Tests**
* Updated and expanded tests to cover the new dtype validation paths and
related error cases.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Originally in this PR
flashinfer-ai#1652 and
flashinfer-ai#2127 we added better
error messaging / prevent silent failures for wrong dtype / shape of
tensors. However, it was on the python side when we can instead move to
the C++ side (actually the `.cu` side, I guess). We already have various
checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Native modules now perform stricter runtime validation of int index
inputs and sampling-parameter shapes; some Python-level preflight checks
were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
* Added a centralized, reusable validation helper for optional sampling
parameters to unify checks.

* **Tests**
* Updated tests to expect different error types and more general
error-message matching for the adjusted validation behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Originally in this PR
flashinfer-ai/flashinfer#1652 and
flashinfer-ai/flashinfer#2127 we added better
error messaging / prevent silent failures for wrong dtype / shape of
tensors. However, it was on the python side when we can instead move to
the C++ side (actually the `.cu` side, I guess). We already have various
checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Native modules now perform stricter runtime validation of int index
inputs and sampling-parameter shapes; some Python-level preflight checks
were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
* Added a centralized, reusable validation helper for optional sampling
parameters to unify checks.

* **Tests**
* Updated tests to expect different error types and more general
error-message matching for the adjusted validation behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Based on this
[comment](flashinfer-ai/flashinfer#2127 (review))
in flashinfer-ai/flashinfer#2127, we can add
support for Int64 indices as well. I decided to do this using `IdType`
like it is done in other files.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Test results:

```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ pytest tests/utils/test_sampling.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/raayan/projects/flashinfer
configfile: pytest.ini
collected 1884 items

tests/utils/test_sampling.py .......................................................................................................... [  5%]
....................................................................................................................................... [ 12%]
....................................................................................................................................... [ 19%]
....................s..s..s..........................................................................sss........................sss.... [ 27%]
....................................................................................................................................... [ 34%]
..........................ssss................................ssss................................ssss................................s [ 41%]
sss................................ssss................................ssss................................ssss........................ [ 48%]
........ssss................................ssss................................ssss................................ssss............... [ 55%]
.................ssss................................ssss................................ssss................................ssss...... [ 62%]
..........................ssss................................ssss................................ssss................................s [ 70%]
sss................................ssss................................ssss................................ssss........................ [ 77%]
........ssss................................ssss................................ssss................................ssss............... [ 84%]
.................ssss.................................................................................................................. [ 91%]
........................................................sss............................................................................ [ 98%]
.......................                                                                                                                 [100%]

================================================ 1764 passed, 120 skipped in 546.33s (0:09:06) ================================================
(flashinfer) raayan@uril-1:~/projects/flashinfer$
```


## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

---------

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Silent failure in top_k_top_p_sampling_from_logits() when indices are long

2 participants