
feat: IdType indices in sampling kernels#2281

Merged
yzh119 merged 5 commits into flashinfer-ai:main from raayandhar:users/rdhar/idtype-sampling-indices
Jan 3, 2026

Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Jan 2, 2026

📌 Description

Based on this comment in #2127, we can add support for int64 indices as well. I implemented this using IdType, as is done in other files.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Test results:

(flashinfer) raayan@uril-1:~/projects/flashinfer$ pytest tests/utils/test_sampling.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/raayan/projects/flashinfer
configfile: pytest.ini
collected 1884 items

tests/utils/test_sampling.py .......................................................................................................... [  5%]
....................................................................................................................................... [ 12%]
....................................................................................................................................... [ 19%]
....................s..s..s..........................................................................sss........................sss.... [ 27%]
....................................................................................................................................... [ 34%]
..........................ssss................................ssss................................ssss................................s [ 41%]
sss................................ssss................................ssss................................ssss........................ [ 48%]
........ssss................................ssss................................ssss................................ssss............... [ 55%]
.................ssss................................ssss................................ssss................................ssss...... [ 62%]
..........................ssss................................ssss................................ssss................................s [ 70%]
sss................................ssss................................ssss................................ssss........................ [ 77%]
........ssss................................ssss................................ssss................................ssss............... [ 84%]
.................ssss.................................................................................................................. [ 91%]
........................................................sss............................................................................ [ 98%]
.......................                                                                                                                 [100%]

================================================ 1764 passed, 120 skipped in 546.33s (0:09:06) ================================================
(flashinfer) raayan@uril-1:~/projects/flashinfer$

Reviewer Notes

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
@coderabbitai
Contributor

coderabbitai Bot commented Jan 2, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough


The changes extend CUDA sampling functions to accept both int32 and int64 indices via type-dispatch mechanisms, enforce output dtype matching indices dtype, add validation macros to tvm_ffi_utils.h, update sampling.py to allocate outputs with indices-driven dtype, and introduce tests verifying int64 support across all sampling variants.

Changes

  • CUDA Type Dispatch & Validation Macros — csrc/sampling.cu, csrc/tvm_ffi_utils.h
    Added three validation macros (CHECK_MAYBE_INPUT_TYPES, CHECK_SAME_DTYPE, CHECK_MAYBE_SAME_DTYPE) to enforce dtype consistency. Replaced direct CUDA calls with DISPATCH_DLPACK_IDTYPE_TO_CTYPE type dispatch in sampling_from_logits, sampling_from_probs, top_p_sampling_from_probs, top_k_sampling_from_probs, min_p_sampling_from_probs, and top_k_top_p_sampling_from_probs to support both int32 and int64 index types.
  • Python Sampling Layer — flashinfer/sampling.py
    Updated output tensor allocation across multiple sampling functions (sampling_from_logits, sampling_from_probs, top_p_sampling_from_probs, top_k_sampling_from_probs, min_p_sampling_from_probs, top_k_top_p_sampling_from_probs, and internal helpers) to derive the dtype from the provided indices tensor, defaulting to int32 when indices are absent. Updated docstrings to document int32/int64 indices support and dtype-matching behavior.
  • Sampling Function Tests — tests/utils/test_sampling.py
    Removed the test validating non-int32 indices rejection; added a new parameterized test, test_int64_indices_sampling, covering all sampling types (from_probs, from_logits, top_p, top_k, min_p, top_k_top_p) to verify output dtype, shape, and value-range correctness with int64 indices. Reorganized tensor-validation test numbering.
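The two rules summarized above — dispatch on the index dtype, and derive the output dtype from the indices tensor — can be sketched in plain Python. Function and kernel names here are illustrative stand-ins, not the actual FlashInfer API:

```python
# Conceptual sketch only; hypothetical names, not the real implementation.

def dispatch_idtype(idtype, kernel_i32, kernel_i64):
    """Runtime branch on the index dtype, analogous to the C++
    DISPATCH_DLPACK_IDTYPE_TO_CTYPE macro selecting an int32_t or
    int64_t template instantiation."""
    if idtype == "int32":
        return kernel_i32
    if idtype == "int64":
        return kernel_i64
    raise TypeError(f"indices must be int32 or int64, got {idtype}")

def output_index_dtype(indices_dtype=None):
    """The sampling.py allocation rule: the output dtype follows the
    indices tensor when one is supplied, else defaults to int32."""
    return indices_dtype if indices_dtype is not None else "int32"
```

For example, `output_index_dtype("int64")` yields `"int64"`, while an unsupported dtype such as `"float32"` passed to `dispatch_idtype` raises `TypeError`.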

Sequence Diagram

sequenceDiagram
    participant Py as Python Layer<br/>(flashinfer/sampling.py)
    participant Val as Validation Layer<br/>(tvm_ffi_utils.h)
    participant Dispatch as Type Dispatch<br/>(DISPATCH_DLPACK_IDTYPE_TO_CTYPE)
    participant CUDA as CUDA Kernel<br/>(sampling.cu)

    Py->>Py: Allocate output tensor<br/>dtype = indices.dtype or int32
    Py->>Val: Call sampling function with indices
    Val->>Val: CHECK_MAYBE_INPUT_TYPES<br/>indices ∈ {int32, int64}
    Val->>Val: CHECK_MAYBE_SAME_DTYPE<br/>indices.dtype == output.dtype
    Val->>Dispatch: Dispatch based on IdType
    alt IdType == int32
        Dispatch->>CUDA: Execute kernel<br/>with int32 template
    else IdType == int64
        Dispatch->>CUDA: Execute kernel<br/>with int64 template
    end
    CUDA->>CUDA: Sample & populate output
    CUDA-->>Py: Return tensor<br/>with matched dtype
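The two validation steps in the diagram can be mimicked as plain functions. A hedged Python sketch follows, with names mirroring the macros but not their actual C++ signatures:

```python
def check_maybe_input_types(dtype, allowed=("int32", "int64"), name="indices"):
    # Analogue of CHECK_MAYBE_INPUT_TYPES: an absent optional tensor
    # passes; a present one must carry an allowed dtype.
    if dtype is not None and dtype not in allowed:
        raise TypeError(f"{name} must be one of {allowed}, got {dtype}")

def check_maybe_same_dtype(maybe_dtype, other_dtype):
    # Analogue of CHECK_MAYBE_SAME_DTYPE: when both tensors are
    # present, their dtypes must agree (indices.dtype == output.dtype).
    if maybe_dtype is not None and maybe_dtype != other_dtype:
        raise TypeError(f"dtype mismatch: {maybe_dtype} vs {other_dtype}")
```

An absent indices tensor (`None`) passes both checks, matching the "default to int32" path in the Python layer.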

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • fix: add a check for int32 indices in sampling.py #2127 — Both PRs modify sampling index-dtype handling in flashinfer/sampling.py; this PR adds int32/int64 support with dtype alignment, whereas the related PR enforces int32-only indices, representing opposing approaches to dtype flexibility.
  • feat: C++ side tensor validation #2160 — Both PRs modify C++ sampling validation paths through changes to csrc/sampling.cu and additions to tvm_ffi_utils.h macro validation, indicating coordinated work on sampling function validation infrastructure.

Suggested reviewers

  • cyx-6
  • aleozlx
  • nvmbreughe
  • wenscarl
  • djmmoss

Poem

🐰 With whiskers twitching in the code, a rabbit hops through types—
From int32 to int64, new dispatch routes take flight!
Validation macros sprout like clover, ensuring dtypes align just right,
Sampling kernels now embrace both paths with gentle might. 🌿✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 31.03%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title check — ✅ Passed: the title clearly and concisely describes the main feature, supporting IdType (64-bit) index types in the sampling kernels, which is the primary focus of the changes across multiple files.
  • Description check — ✅ Passed: the PR description addresses the main purpose (adding int64 indices support using IdType) with a reference to the related discussion, and confirms checklist completion and test passage with detailed output.

@gemini-code-assist
Contributor

Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request enhances the flexibility of the sampling kernels by introducing support for dl_int64 as an IdType for indices. This change allows the system to handle larger index values, which is beneficial for applications involving extensive vocabularies or complex data indexing. The modifications ensure type safety and correct behavior across both the C++ backend and the Python frontend, along with updated documentation and comprehensive testing.

Highlights

  • Expanded IdType Support: The sampling kernels now support dl_int64 (int64_t in C++) for index types, allowing for larger vocabulary sizes or more extensive indexing schemes.
  • Dynamic Output Dtype: The Python frontend for sampling functions (flashinfer/sampling.py) has been updated to dynamically determine the output tensor's data type based on the indices tensor's dtype, defaulting to torch.int32 if no indices are provided.
  • New Type Checking Macros: New utility macros (CHECK_MAYBE_INPUT_TYPES, CHECK_SAME_DTYPE, CHECK_MAYBE_SAME_DTYPE) were introduced in csrc/tvm_ffi_utils.h to facilitate robust type checking for multiple input types and ensure data type consistency across tensors.
  • Updated Documentation and Tests: The documentation for the indices parameter in the Python sampling functions has been updated to reflect int64 support, and a new test case was added to validate int64 indices across all sampling methods.



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for int32 and int64 index types in the sampling kernels, which is a great enhancement. The changes are consistently applied across C++ kernels, Python interfaces, and documentation. The new functionality is also well-covered by a new parameterized test case.

My main feedback is on the Python side (flashinfer/sampling.py), where there's some code duplication for creating the output tensor in various sampling functions. I've left comments with suggestions to make the code more concise and maintainable by reducing this repetition.

Overall, this is a solid contribution.

Comment threads (10) — flashinfer/sampling.py

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/sampling.cu (1)

201-226: Critical: Type mismatch between Python int32 and C++ IdType cast for top_k_arr.

Line 216 casts maybe_top_k_arr to IdType*, but the Python side always converts it to int32:

  • Line 314 in flashinfer/sampling.py: maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
  • Line 231 in flashinfer/sampling.py: maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None

When IdType is int64_t, this creates a type mismatch: the Python side passes an int32 array but the C++ side interprets it as int64_t*, leading to incorrect memory access and potentially reading garbage values or crashes.

Solution options:

  1. Keep maybe_top_k_arr as int32* cast in C++ (don't use IdType*)
  2. Update Python side to match indices dtype: maybe_top_k_arr = maybe_top_k_arr.to(indices.dtype if indices is not None else torch.int32)
  3. Update kernel signature to explicitly accept int32_t* for top_k_arr regardless of IdType
#!/bin/bash
# Verify the kernel signature and how top_k_arr is declared
ast-grep --pattern $'TopKTopPSamplingFromProb<$_, $_>($$$)'

# Check if there are other call sites that might reveal the expected type
rg -n "TopKTopPSamplingFromProb" --type=cpp -C3
🧹 Nitpick comments (1)
tests/utils/test_sampling.py (1)

888-931: Consider adding a test case with non-identity index mapping.

The new test comprehensively covers all sampling variants with both int32 and int64 dtypes, validating dtype matching, shape, and value ranges. However, it uses indices = torch.arange(batch_size, dtype=indices_dtype), which creates an identity mapping where indices[i] = i.

Consider adding a test case that uses non-contiguous or repeated indices (e.g., indices = torch.tensor([0, 0, 1, 1], dtype=indices_dtype)) to verify that the int64 index indirection works correctly when mapping multiple outputs to the same probability distribution.

💡 Example test enhancement
@pytest.mark.parametrize("sampling_type", ["from_probs", "top_p", "top_k"])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_int64_indices_mapping(sampling_type, indices_dtype):
    """Test that int64 indices correctly map multiple requests to same distribution."""
    torch.manual_seed(42)
    unique_batch_size = 2
    batch_size = 6
    vocab_size = 1000
    
    logits = torch.randn(unique_batch_size, vocab_size, device="cuda:0")
    probs = torch.softmax(logits, dim=-1)
    # Map: requests 0,1,2 -> dist 0; requests 3,4,5 -> dist 1
    indices = torch.tensor([0, 0, 0, 1, 1, 1], dtype=indices_dtype, device="cuda:0")
    
    if sampling_type == "from_probs":
        samples = flashinfer.sampling.sampling_from_probs(probs, indices=indices)
    elif sampling_type == "top_p":
        samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9, indices=indices)
    elif sampling_type == "top_k":
        samples = flashinfer.sampling.top_k_sampling_from_probs(probs, 100, indices=indices)
    
    assert samples.dtype == indices_dtype
    assert samples.shape == (batch_size,)
    assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6f1624c and f5a4a94.

📒 Files selected for processing (4)
  • csrc/sampling.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/sampling.py
  • tests/utils/test_sampling.py
🔇 Additional comments (9)
csrc/tvm_ffi_utils.h (1)

279-290: LGTM! Well-designed dtype validation macros.

The three new macros follow existing patterns and provide clear error messages for dtype validation. They appropriately handle optional tensors and support the int32/int64 index type flexibility introduced by this PR.

tests/utils/test_sampling.py (1)

699-713: LGTM! Test reordering improves readability.

The reordering groups related test cases together logically (batch size mismatch check followed by correct batch size check), and the comment updates accurately reflect the new structure.

csrc/sampling.cu (4)

52-69: LGTM! Correct int32/int64 dispatch for sampling_from_logits.

The changes properly validate indices dtype, ensure output dtype consistency, and dispatch the kernel call based on the index type. The pointer casts to IdType* for both output and indices are correct.


76-93: LGTM! Correct int32/int64 dispatch for sampling_from_probs.

The implementation correctly mirrors the changes in sampling_from_logits, with proper dtype validation and type dispatch.


102-122: LGTM! Correct int32/int64 dispatch for top_p_sampling_from_probs.

The dtype validation and dispatch logic are correct. The cast of maybe_top_p_arr to float* (line 117) aligns with the Python side's .float() conversion.


166-187: LGTM! Correct int32/int64 dispatch for min_p_sampling_from_probs.

The dtype validation, dispatch logic, and pointer casts are correct. The maybe_min_p_arr cast to float* (line 179) aligns with the Python side's .float() conversion.

flashinfer/sampling.py (3)

100-101: LGTM! Output dtype correctly matches indices dtype.

The change ensures the output tensor dtype matches the provided indices dtype, with a sensible default to int32 when indices are not provided. This pattern is consistently applied across all sampling functions.


122-123: LGTM! Consistent dtype handling across all sampling helpers.

All sampling function helpers (both real and fake implementations) consistently apply the pattern of matching output dtype to indices dtype, with appropriate defaults. This ensures type consistency throughout the sampling pipeline.

Also applies to: 139-140, 163-164, 185-186, 211-212, 232-233, 258-259, 281-282, 319-320, 349-350


613-618: LGTM! Documentation accurately describes int32/int64 support.

The docstring updates clearly communicate:

  • The indices parameter accepts both torch.int32 and torch.int64
  • Output dtype matches indices dtype when provided
  • Output defaults to int32 when indices are not provided

The documentation is consistent across all sampling function variants.

Also applies to: 680-685, 763-768, 860-865, 958-963, 1057-1062, 1190-1195

Comment thread csrc/sampling.cu
@yzh119
Collaborator

yzh119 commented Jan 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !223 has been created, and the CI pipeline #41058057 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #41058057: 12/20 passed

Collaborator

@yzh119 yzh119 left a comment


LGTM, thanks for working on this feature!

@yzh119 yzh119 merged commit cda8f3f into flashinfer-ai:main Jan 3, 2026
4 checks passed
