feat: IdType indices in sampling kernels#2281
Conversation
Signed-off-by: raayandhar <raayan.dhar@gmail.com>
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

The changes extend CUDA sampling functions to accept both int32 and int64 indices via type-dispatch mechanisms, enforce output dtype matching indices dtype, add validation macros to tvm_ffi_utils.h, update sampling.py to allocate outputs with indices-driven dtype, and introduce tests verifying int64 support across all sampling variants.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python Layer<br/>(flashinfer/sampling.py)
    participant Val as Validation Layer<br/>(tvm_ffi_utils.h)
    participant Dispatch as Type Dispatch<br/>(DISPATCH_DLPACK_IDTYPE_TO_CTYPE)
    participant CUDA as CUDA Kernel<br/>(sampling.cu)
    Py->>Py: Allocate output tensor<br/>dtype = indices.dtype or int32
    Py->>Val: Call sampling function with indices
    Val->>Val: CHECK_MAYBE_INPUT_TYPES<br/>indices ∈ {int32, int64}
    Val->>Val: CHECK_MAYBE_SAME_DTYPE<br/>indices.dtype == output.dtype
    Val->>Dispatch: Dispatch based on IdType
    alt IdType == int32
        Dispatch->>CUDA: Execute kernel<br/>with int32 template
    else IdType == int64
        Dispatch->>CUDA: Execute kernel<br/>with int64 template
    end
    CUDA->>CUDA: Sample & populate output
    CUDA-->>Py: Return tensor<br/>with matched dtype
```
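The flow above can be sketched in plain Python for illustration. Everything here is a stand-in: `kernel_i32`, `kernel_i64`, and `sample` are hypothetical names mimicking the dispatch pattern, not flashinfer's actual API, and strings replace torch dtypes so the sketch runs without a GPU.

```python
# Illustrative sketch of the diagrammed flow: dtype-driven output
# allocation, validation, and kernel dispatch. All names are hypothetical.
def kernel_i32(batch_size):
    return [0] * batch_size  # stand-in for the int32 template instantiation

def kernel_i64(batch_size):
    return [0] * batch_size  # stand-in for the int64 template instantiation

DISPATCH = {"int32": kernel_i32, "int64": kernel_i64}

def sample(batch_size, indices_dtype=None):
    # Python layer: output dtype follows the indices dtype, default int32.
    out_dtype = "int32" if indices_dtype is None else indices_dtype
    # Validation layer: indices must be int32 or int64.
    if out_dtype not in DISPATCH:
        raise TypeError(f"indices must be int32 or int64, got {out_dtype}")
    # Type dispatch: pick the kernel instantiation for this IdType.
    samples = DISPATCH[out_dtype](batch_size)
    return samples, out_dtype
```

With int64 indices the output dtype follows along; with no indices it defaults to int32, matching the contract in the diagram.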
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the flexibility of the sampling kernels by introducing support for int64 indices.
Code Review
This pull request adds support for int32 and int64 index types in the sampling kernels, which is a great enhancement. The changes are consistently applied across C++ kernels, Python interfaces, and documentation. The new functionality is also well-covered by a new parameterized test case.
My main feedback is on the Python side (flashinfer/sampling.py), where there's some code duplication for creating the output tensor in various sampling functions. I've left comments with suggestions to make the code more concise and maintainable by reducing this repetition.
Overall, this is a solid contribution.
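One way to realize the reviewer's de-duplication suggestion is a small shared helper. The sketch below is hypothetical: the helper name `make_output` and the string-based dtypes are illustrative stand-ins for the torch calls in `flashinfer/sampling.py`, not the reviewer's actual proposal.

```python
# Hypothetical helper consolidating the repeated output-tensor creation
# logic. In the real code this would call torch.empty(..., dtype=...);
# here a list plus a dtype string keeps the sketch dependency-free.
def make_output(batch_size, indices_dtype=None, default_dtype="int32"):
    out_dtype = default_dtype if indices_dtype is None else indices_dtype
    return [0] * batch_size, out_dtype
```

Each sampling function could then call such a helper instead of repeating the "indices dtype or int32" branch inline.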
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/sampling.cu (1)
201-226: Critical: Type mismatch between Python int32 and C++ IdType cast for `top_k_arr`.

Line 216 casts `maybe_top_k_arr` to `IdType*`, but the Python side always converts it to `int32`:

- Line 314 in `flashinfer/sampling.py`: `maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None`
- Line 231 in `flashinfer/sampling.py`: `maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None`

When `IdType` is `int64_t`, this creates a type mismatch: the Python side passes an `int32` array but the C++ side interprets it as `int64_t*`, leading to incorrect memory access and potentially reading garbage values or crashes.

Solution options:

- Keep the `maybe_top_k_arr` cast as `int32*` in C++ (don't use `IdType*`)
- Update the Python side to match the indices dtype: `maybe_top_k_arr = maybe_top_k_arr.to(indices.dtype if indices is not None else torch.int32)`
- Update the kernel signature to explicitly accept `int32_t*` for `top_k_arr` regardless of `IdType`

```bash
#!/bin/bash
# Verify the kernel signature and how top_k_arr is declared
ast-grep --pattern $'TopKTopPSamplingFromProb<$_, $_>($$$)'
# Check if there are other call sites that might reveal the expected type
rg -n "TopKTopPSamplingFromProb" --type=cpp -C3
```
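To see why such a mismatch is dangerous, note that int32 and int64 elements differ in width, so a kernel reading an int32 buffer through an `int64_t*` fuses adjacent elements. A stdlib-only Python illustration (not flashinfer code):

```python
import struct

# Pack four int32 values, then reinterpret the same 16 bytes as int64:
# the kind of misread the dtype checks exist to prevent.
buf = struct.pack("<4i", 1, 2, 3, 4)   # 16 bytes of little-endian int32
as_i32 = struct.unpack("<4i", buf)     # correct view of the buffer
as_i64 = struct.unpack("<2q", buf)     # wrong view: adjacent values fused
```

Each misread int64 combines two neighboring int32 values into one garbage index, which is exactly the failure mode the comment describes.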
🧹 Nitpick comments (1)
tests/utils/test_sampling.py (1)
888-931: Consider adding a test case with non-identity index mapping.

The new test comprehensively covers all sampling variants with both int32 and int64 dtypes, validating dtype matching, shape, and value ranges. However, it uses `indices = torch.arange(batch_size, dtype=indices_dtype)`, which creates an identity mapping where `indices[i] = i`.

Consider adding a test case that uses non-contiguous or repeated indices (e.g., `indices = torch.tensor([0, 0, 1, 1], dtype=indices_dtype)`) to verify that the int64 index indirection works correctly when mapping multiple outputs to the same probability distribution.

💡 Example test enhancement

```python
@pytest.mark.parametrize("sampling_type", ["from_probs", "top_p", "top_k"])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_int64_indices_mapping(sampling_type, indices_dtype):
    """Test that int64 indices correctly map multiple requests to same distribution."""
    torch.manual_seed(42)
    unique_batch_size = 2
    batch_size = 6
    vocab_size = 1000
    logits = torch.randn(unique_batch_size, vocab_size, device="cuda:0")
    probs = torch.softmax(logits, dim=-1)
    # Map: requests 0,1,2 -> dist 0; requests 3,4,5 -> dist 1
    indices = torch.tensor([0, 0, 0, 1, 1, 1], dtype=indices_dtype, device="cuda:0")
    if sampling_type == "from_probs":
        samples = flashinfer.sampling.sampling_from_probs(probs, indices=indices)
    elif sampling_type == "top_p":
        samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9, indices=indices)
    elif sampling_type == "top_k":
        samples = flashinfer.sampling.top_k_sampling_from_probs(probs, 100, indices=indices)
    assert samples.dtype == indices_dtype
    assert samples.shape == (batch_size,)
    assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/sampling.cu
- csrc/tvm_ffi_utils.h
- flashinfer/sampling.py
- tests/utils/test_sampling.py
🧰 Additional context used
📓 Path-based instructions (3)
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Files:
csrc/sampling.cu
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes; `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/utils/test_sampling.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/sampling.py
🧠 Learnings (1)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
🧬 Code graph analysis (2)
csrc/sampling.cu (2)
- csrc/tvm_ffi_utils.h (1): `get_stream` (306-308)
- csrc/sampling_utils.h (1): `check_tensor_param` (24-41)
tests/utils/test_sampling.py (4)

flashinfer/logits_processor/types.py (2)
- `logits` (74-78), `probs` (81-85)

flashinfer/sampling.py (14)
- `softmax` (53-73, 536-590), `sampling_from_probs` (128-151, 661-730), `sampling_from_logits` (87-112, 594-657), `top_p_sampling_from_probs` (169-199, 734-827), `top_k_sampling_from_probs` (218-246, 831-924), `min_p_sampling_from_probs` (265-295, 928-1017), `top_k_top_p_sampling_from_probs` (300-335, 1154-1276)

csrc/sampling.cu (14)
- `softmax` (25-46, 25-26), `sampling_from_probs` (72-94, 72-73), `sampling_from_logits` (48-70, 48-49), `top_p_sampling_from_probs` (96-123, 96-99), `top_k_sampling_from_probs` (125-155, 125-128), `min_p_sampling_from_probs` (157-188, 157-160), `top_k_top_p_sampling_from_probs` (190-227, 190-195)

csrc/flashinfer_sampling_binding.cu (7)
- `softmax` (20-21), `sampling_from_probs` (23-24), `sampling_from_logits` (26-27), `top_p_sampling_from_probs` (29-32), `top_k_sampling_from_probs` (34-37), `min_p_sampling_from_probs` (39-42), `top_k_top_p_sampling_from_probs` (44-49)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
csrc/tvm_ffi_utils.h (1)
279-290: LGTM! Well-designed dtype validation macros.

The three new macros follow existing patterns and provide clear error messages for dtype validation. They appropriately handle optional tensors and support the int32/int64 index type flexibility introduced by this PR.
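The intent of checks like these can be expressed in a few lines of plain Python (a sketch of the contract only; the real macros in `csrc/tvm_ffi_utils.h` operate on DLPack tensors, and these function names and signatures are illustrative):

```python
# Sketch of optional-tensor dtype validation; dtypes are plain strings here.
def check_maybe_input_types(maybe_dtype, allowed=("int32", "int64"), name="indices"):
    # If the optional tensor is present, its dtype must be in the allowed set.
    if maybe_dtype is not None and maybe_dtype not in allowed:
        raise TypeError(f"{name} must be one of {allowed}, got {maybe_dtype}")

def check_maybe_same_dtype(maybe_dtype, other_dtype, name="indices"):
    # If the optional tensor is present, its dtype must match the output's.
    if maybe_dtype is not None and maybe_dtype != other_dtype:
        raise TypeError(f"{name} dtype {maybe_dtype} != output dtype {other_dtype}")
```

Both checks pass silently when the optional tensor is absent, which mirrors how the macros handle `None`/missing optional arguments.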
tests/utils/test_sampling.py (1)
699-713: LGTM! Test reordering improves readability.

The reordering groups related test cases together logically (batch size mismatch check followed by correct batch size check), and the comment updates accurately reflect the new structure.
csrc/sampling.cu (4)
52-69: LGTM! Correct int32/int64 dispatch for sampling_from_logits.

The changes properly validate the indices dtype, ensure output dtype consistency, and dispatch the kernel call based on the index type. The pointer casts to `IdType*` for both output and indices are correct.
76-93: LGTM! Correct int32/int64 dispatch for sampling_from_probs.

The implementation correctly mirrors the changes in `sampling_from_logits`, with proper dtype validation and type dispatch.
102-122: LGTM! Correct int32/int64 dispatch for top_p_sampling_from_probs.

The dtype validation and dispatch logic are correct. The cast of `maybe_top_p_arr` to `float*` (line 117) aligns with the Python side's `.float()` conversion.
166-187: LGTM! Correct int32/int64 dispatch for min_p_sampling_from_probs.

The dtype validation, dispatch logic, and pointer casts are correct. The `maybe_min_p_arr` cast to `float*` (line 179) aligns with the Python side's `.float()` conversion.

flashinfer/sampling.py (3)
100-101: LGTM! Output dtype correctly matches indices dtype.

The change ensures the output tensor dtype matches the provided indices dtype, with a sensible default to `int32` when indices are not provided. This pattern is consistently applied across all sampling functions.
122-123: LGTM! Consistent dtype handling across all sampling helpers.

All sampling function helpers (both real and fake implementations) consistently apply the pattern of matching output dtype to indices dtype, with appropriate defaults. This ensures type consistency throughout the sampling pipeline.
Also applies to: 139-140, 163-164, 185-186, 211-212, 232-233, 258-259, 281-282, 319-320, 349-350
613-618: LGTM! Documentation accurately describes int32/int64 support.

The docstring updates clearly communicate:
- The `indices` parameter accepts both `torch.int32` and `torch.int64`
- Output dtype matches indices dtype when provided
- Output defaults to `int32` when indices are not provided

The documentation is consistent across all sampling function variants.
Also applies to: 680-685, 763-768, 860-865, 958-963, 1057-1062, 1190-1195
/bot run

[SUCCESS] Pipeline #41058057: 12/20 passed
yzh119
left a comment
LGTM, thanks for working on this feature!
📌 Description
Based on this comment in #2127, we can add support for Int64 indices as well. I decided to do this using `IdType` like it is done in other files.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Test results:
Reviewer Notes