feat: C++ side tensor validation #2160
Conversation
> **Note**: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

**Walkthrough**

Adds native runtime validations for index-like and optional sampling tensors (int32 type and scalar/1D batch consistency) in C++/CUDA and removes the equivalent Python preflight checks; public API signatures remain unchanged.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ failed checks (1 warning, 1 inconclusive) · ✅ passed checks (3 passed)
**Summary of Changes**

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the tensor input validation mechanism within the FlashInfer library. It shifts the responsibility of checking tensor data types, dimensions, and batch sizes from the Python API to the underlying C++ (CUDA) implementation. This change aims to consolidate error handling, prevent silent failures, and ensure that invalid tensor configurations are identified and reported more effectively at the execution level, leading to a more robust and maintainable codebase.
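In miniature, the refactor described above keeps the Python wrapper thin and lets the native layer raise. The sketch below is a hypothetical mock (the function names and error strings are illustrative, not the actual FlashInfer API):

```python
# Hypothetical mock of the refactor's shape: the "native" entry point owns
# validation; the Python wrapper does no preflight checks of its own.

def native_sampling_op(probs, indices_dtype):
    # stand-in for the C++/CUDA kernel entry point
    if indices_dtype != "int32":
        raise RuntimeError("Inconsistency of Tensor type: maybe_indices must be int32")
    return "sampled"

def sampling_from_probs(probs, indices_dtype="int32"):
    # thin wrapper: dtype/shape errors surface from the native call
    return native_sampling_op(probs, indices_dtype)

print(sampling_from_probs([0.1, 0.9]))  # sampled
```

With this shape, a wrong dtype raises `RuntimeError` from the native layer rather than a Python-side `ValueError`, which is exactly the error-type shift the updated tests account for.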
Code Review
This pull request successfully moves tensor validation from Python to C++, which is a good improvement for performance and consistency. The changes are well-contained, and the tests are updated accordingly. I've identified a case of code duplication that affects maintainability and a potential bug in the new validation logic that could lead to a crash. My review comments address these issues.
Actionable comments posted: 0
♻️ Duplicate comments (1)
csrc/sampling.cu (1)
**24-42**: Duplicate code with `csrc/renorm.cu`. As noted in the review of `csrc/renorm.cu`, this function is duplicated. Consider extracting to `csrc/tvm_ffi_utils.h`.
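For reference, here is a minimal Python analogue of the contract the reviewers describe (optional, scalar-or-1D, with a batch-size match). The real helper lives in C++ in `csrc/sampling.cu`/`csrc/renorm.cu`; the error strings below are illustrative, not copied from the implementation:

```python
from types import SimpleNamespace

def check_tensor_param(param, batch_size, name="param"):
    """Illustrative Python mirror of the C++ check_tensor_param contract."""
    if param is None:  # optional parameter not supplied: nothing to check
        return
    if param.ndim == 0:
        raise ValueError(f"Expected a 1D tensor of shape (batch_size,) for {name}, "
                         "but got a 0-dimensional tensor.")
    if param.ndim != 1:
        raise ValueError(f"Expected a 1D tensor of shape (batch_size,) for {name}, "
                         f"but got {param.ndim} dimensions.")
    if param.shape[0] != batch_size:
        raise ValueError(f"Sampling parameter {name} batch size mismatch: "
                         f"expected {batch_size}, got {param.shape[0]}")

# stand-in tensors: only .ndim and .shape matter to the check
check_tensor_param(SimpleNamespace(ndim=1, shape=(4,)), batch_size=4)  # passes silently
```

Keeping one such helper in a shared header means the 0D, 2D, and mismatch error messages cannot drift between the sampling and renorm entry points.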
🧹 Nitpick comments (3)
csrc/renorm.cu (1)
**24-42**: Duplicate code: `check_tensor_param` is duplicated in `csrc/sampling.cu`. This helper function has an identical implementation in `csrc/sampling.cu` (lines 24-42). Consider moving it to `csrc/tvm_ffi_utils.h` alongside the other validation utilities to avoid duplication.

tests/utils/test_sampling.py (2)
**634-676**: Make the `match` pattern for the top-p batch-size mismatch a raw regex string. `match="Sampling parameter.*batch size mismatch"` uses `.` and `*` as regex metacharacters but isn't a raw string, which is what Ruff RUF043 is flagging. Switching to a raw string keeps the intent and silences the lint:

```diff
-        with pytest.raises(
-            RuntimeError, match="Sampling parameter.*batch size mismatch"
-        ):
+        with pytest.raises(
+            RuntimeError, match=r"Sampling parameter.*batch size mismatch"
+        ):
```

(Based on static analysis hints.)
**682-726**: Likewise, use a raw regex for the top-k batch-size mismatch. Same issue here: `match="Sampling parameter.*batch size mismatch"` is intended as a regex but isn't raw, triggering RUF043. Mirroring the top-p fix:

```diff
-        with pytest.raises(
-            RuntimeError, match="Sampling parameter.*batch size mismatch"
-        ):
+        with pytest.raises(
+            RuntimeError, match=r"Sampling parameter.*batch size mismatch"
+        ):
```

You could also optionally unify the min-p mismatch check to use the same `Sampling parameter.*batch size mismatch` regex for future-proofing, but that's not required by this change.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 89e1adb and abc3a392e8a11a4badbcfa278621de723509dbef.
📒 Files selected for processing (12)
- csrc/batch_decode.cu (2 hunks)
- csrc/batch_decode_mla_plan.cu (1 hunks)
- csrc/batch_decode_mla_run.cu (1 hunks)
- csrc/batch_mla_plan.cu (1 hunks)
- csrc/batch_mla_run.cu (1 hunks)
- csrc/renorm.cu (3 hunks)
- csrc/sampling.cu (7 hunks)
- csrc/tvm_ffi_utils.h (1 hunks)
- flashinfer/decode.py (0 hunks)
- flashinfer/mla.py (0 hunks)
- flashinfer/sampling.py (1 hunks)
- tests/utils/test_sampling.py (10 hunks)
💤 Files with no reviewable changes (2)
- flashinfer/mla.py
- flashinfer/decode.py
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/sampling.cu (1)
csrc/renorm.cu (2)
- `check_tensor_param` (25-42)
- `check_tensor_param` (25-25)
csrc/renorm.cu (1)
csrc/sampling.cu (2)
- `check_tensor_param` (25-42)
- `check_tensor_param` (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py
665-665: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
715-715: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🔇 Additional comments (18)
csrc/tvm_ffi_utils.h (1)
**257-260**: LGTM! The macro correctly handles optional tensor type validation, following the existing pattern of validation macros in this file.
csrc/batch_decode.cu (2)
**45-46**: LGTM! Early type validation for the `indptr` tensor ensures proper dtype before proceeding with the plan computation.

**91-94**: LGTM! Consistent type validation for all paged KV cache index tensors. These checks align with the subsequent casts to `IdType*`.

csrc/renorm.cu (3)
**50-50**: LGTM! Validation ensures the optional top-p parameter tensor has the correct shape and batch size.

**69-69**: LGTM! Consistent parameter validation for the top-k array.

**89-89**: LGTM! Validation correctly uses `logits` as the reference tensor for batch size comparison.

flashinfer/sampling.py (1)

**19-19**: LGTM! Import cleanup reflects the removal of Python-side validation helpers. Validation is now handled in the C++ layer.
csrc/sampling.cu (6)
**71-71**: LGTM! Validates optional indices tensor dtype before use.

**89-89**: LGTM! Consistent index type validation.

**109-112**: LGTM! Both index type and top-p parameter validation are properly placed.

**135-138**: LGTM! Consistent validation pattern for top-k sampling.

**161-164**: LGTM! Consistent validation pattern for min-p sampling.

**190-194**: LGTM! Validates indices, top-k, and top-p parameters for the combined sampling function.
tests/utils/test_sampling.py (1)
**575-629**: Min-p tensor validation test updates look consistent with the new C++ checks. Using `RuntimeError` plus the more permissive dtype error regex for `indices`, and distinct cases for scalar/1D/2D shapes, gives good coverage of the new C++ validation paths. I don't see issues in this block.

csrc/batch_decode_mla_run.cu (1)

**20-22**: Int32 checks for paged-KV metadata are in the right place. Validating `paged_kv_indptr`, `paged_kv_indices`, and `paged_kv_last_page_len` as `dl_int32` before constructing `DecodePlanInfo` and setting up `paged_kv_mla_t` is consistent with the rest of the decode path and prevents misuse of non-int32 tensors as index buffers. Looks good.

csrc/batch_decode_mla_plan.cu (1)

**18-19**: Plan-side `indptr` type check is consistent with downstream usage. The `CHECK_INPUT_TYPE(indptr, dl_int32);` at function entry matches the `IdType*` usage in `DecodePlan` and aligns with other newly-added decode checks. No issues here.

csrc/batch_mla_plan.cu (1)

**32-35**: Good: MLA plan now validates all index/length tensors as int32. Adding `CHECK_INPUT_TYPE` for `qo_indptr`, `kv_indptr`, and `kv_len` up front is aligned with how these buffers are used (`IdType*`) and with the broader move of validation into C++/CUDA. The change looks correct and non-disruptive.

csrc/batch_mla_run.cu (1)

**40-41**: Int32 validation for `kv_indices` matches MLA run expectations. The new `CHECK_INPUT_TYPE(kv_indices, dl_int32);` is appropriately placed before plan/materialization and consistent with the `IdType*` usage in `params.kv_indices`. This should surface dtype mistakes early without affecting correct callers.
abc3a39 to 6c2d7c9
Actionable comments posted: 0
♻️ Duplicate comments (1)
csrc/sampling.cu (1)
**24-42**: `check_tensor_param` matches the renorm.cu logic; consider de-duplicating into a shared header. The helper correctly enforces that sampling parameter tensors are scalar or 1D and have a batch size matching the reference tensor, with error messages aligned to the updated tests. However, an essentially identical `check_tensor_param` already exists in `csrc/renorm.cu`; moving this helper into a shared header like `tvm_ffi_utils.h` and reusing it in both places would avoid future drift and centralize the sampling-parameter contract.
🧹 Nitpick comments (3)
tests/utils/test_sampling.py (2)
**644-666**: top-p tensor-param tests align with the C++ checks; consider a raw string for the regex. The move to `RuntimeError` and the regex patterns for the 2D, 0D, and batch-size mismatch cases line up with the new `check_tensor_param` helper in C++. To satisfy Ruff's RUF043 and make it explicit that `match` is a regex, you can switch the last pattern to a raw string:

```diff
         with pytest.raises(
             RuntimeError,
-            match="Sampling parameter.*batch size mismatch",
+            match=r"Sampling parameter.*batch size mismatch",
         ):
```
**694-716**: top-k tensor-param tests mirror the new C++ validation; the same raw-string tweak applies. The updated `RuntimeError` expectations and regexes for shape/batch mismatches are consistent with the C++ `check_tensor_param` behavior. As with the top-p tests, consider making the final `match` a raw string to quiet RUF043 and clarify intent:

```diff
         with pytest.raises(
             RuntimeError,
-            match="Sampling parameter.*batch size mismatch",
+            match=r"Sampling parameter.*batch size mismatch",
         ):
```

csrc/sampling.cu (1)
**71-71**: New C++-side dtype and shape validations for sampling look correct and align with the Python tests.

- `CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32)` in `sampling_from_logits`, `sampling_from_probs`, and all sampling variants correctly enforces that optional indices are `int32`, matching how they're cast to `int*` in the kernels and how Python tests now assert `RuntimeError` for non-int32 indices.
- The `check_tensor_param` uses in the top-p, top-k, min-p, and joint top-k/top-p functions properly restrict parameter tensors to 1D (or scalar via the separate scalar argument) and enforce `param.size(0) == probs.size(0)`, which matches the new tests around 2D, 0D, and batch-size mismatches.

One small consistency nit: `top_k_sampling_from_probs`, `min_p_sampling_from_probs`, and `top_k_top_p_sampling_from_probs` validate both `probs` and `output` (including device consistency), whereas `top_p_sampling_from_probs` only checks `probs`. For symmetry and slightly better diagnostics, you might also add `CHECK_INPUT(output)` / `CHECK_DEVICE(output, probs)` there, though this is a pre-existing gap and not introduced by this change.

Also applies to: 89-89, 109-113, 136-139, 161-165, 190-195
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between abc3a392e8a11a4badbcfa278621de723509dbef and 6c2d7c987ba754925f30948a61868bb050d3de20.
📒 Files selected for processing (12)
- csrc/batch_decode.cu (2 hunks)
- csrc/batch_decode_mla_plan.cu (1 hunks)
- csrc/batch_decode_mla_run.cu (1 hunks)
- csrc/batch_mla_plan.cu (1 hunks)
- csrc/batch_mla_run.cu (1 hunks)
- csrc/renorm.cu (3 hunks)
- csrc/sampling.cu (7 hunks)
- csrc/tvm_ffi_utils.h (1 hunks)
- flashinfer/decode.py (0 hunks)
- flashinfer/mla.py (0 hunks)
- flashinfer/sampling.py (1 hunks)
- tests/utils/test_sampling.py (10 hunks)
💤 Files with no reviewable changes (2)
- flashinfer/mla.py
- flashinfer/decode.py
🚧 Files skipped from review as they are similar to previous changes (4)
- csrc/batch_mla_plan.cu
- csrc/batch_decode_mla_run.cu
- csrc/tvm_ffi_utils.h
- csrc/renorm.cu
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/sampling.cu (1)
csrc/renorm.cu (2)
- `check_tensor_param` (25-42)
- `check_tensor_param` (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py
665-665: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
715-715: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
csrc/batch_decode_mla_plan.cu (1)
**18-18**: Early int32 check for `indptr` is correct and well-placed. Validating `indptr` with `CHECK_INPUT_TYPE(indptr, dl_int32)` before passing it as `IdType*` into `DecodePlan` gives a clear, early failure for wrong dtypes and matches the kernel's expectations. No further changes needed here.

csrc/batch_mla_run.cu (1)

**40-40**: Int32 validation on `kv_indices` matches kernel expectations. The new `CHECK_INPUT_TYPE(kv_indices, dl_int32)` is consistent with using `kv_indices` as `IdType*` in the dispatched MLA kernel and will now fail fast on incorrect dtypes instead of producing undefined behavior.

csrc/batch_decode.cu (1)

**45-45**: Dtype checks for `indptr` and the paged-KV index tensors are appropriate safeguards. The `CHECK_INPUT_TYPE(..., dl_int32)` calls on `indptr`, `paged_kv_indptr`, `paged_kv_indices`, and `paged_kv_last_page_len` match their use as `IdType*` in the decode plan/run paths and will surface dtype mismatches as clear runtime errors rather than silent corruption. This is a solid move of validation into the C++ layer.

Also applies to: 91-93

tests/utils/test_sampling.py (1)

**585-618**: Min-p validation tests correctly target the C++ RuntimeError behavior. Switching these expectations to `RuntimeError` and broadening the regex to allow both legacy and new message forms (including the int32 index checks) matches the new C++/TVM-FFI validation path. The test coverage for 2D/0D/mismatched-batch and non-int32 indices looks comprehensive.

flashinfer/sampling.py (1)

**20-20**: Typing imports now accurately reflect usage after the validation refactor. Switching to `from typing import Optional, Tuple, Union` is correct given the current annotations, and the removal of Python-side validation in favor of C++ checks keeps the wrapper thin while tests cover the new error surface.
Actionable comments posted: 0
🧹 Nitpick comments (1)
csrc/tvm_ffi_utils.h (1)
**241-260**: Validation logic is sound; a minor clarity improvement is possible. The guard `if (param.ndim() == 0)` before accessing `param.size(0)` correctly prevents out-of-bounds access (addressing the prior review concern), and the logic properly rejects invalid shapes.

One minor observation: the error message mentions "or scalar" as acceptable, but the function rejects 0-dimensional tensors. This is likely intentional (users should use the scalar value parameter like `top_p_val` instead of a 0-dim tensor), but could be clarified:

```diff
-      << "Expected a 1D tensor of shape (batch_size,) or scalar for the sampling parameter, "
+      << "Expected a 1D tensor of shape (batch_size,) for the sampling parameter array, "
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 6c2d7c987ba754925f30948a61868bb050d3de20 and 61b9af59fdc78e543edc327da7f29a21b89ea7fd.
📒 Files selected for processing (3)
- csrc/renorm.cu (3 hunks)
- csrc/sampling.cu (6 hunks)
- csrc/tvm_ffi_utils.h (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/sampling.cu (1)
csrc/tvm_ffi_utils.h (1)
- `check_tensor_param` (241-260)
csrc/renorm.cu (1)
csrc/tvm_ffi_utils.h (1)
- `check_tensor_param` (241-260)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (10)
csrc/tvm_ffi_utils.h (1)
**279-282**: LGTM! The macro correctly wraps the existing `CHECK_INPUT_TYPE` with an optional guard, providing consistent dtype validation for optional tensor inputs.

csrc/renorm.cu (3)

**30-30**: LGTM! Early validation of `maybe_top_p_arr` against `probs` ensures shape and batch size consistency before any CUDA kernel execution, providing clear error messages for misconfigured inputs.

**49-49**: LGTM! Consistent application of `check_tensor_param` for `maybe_top_k_arr` validation.

**69-69**: LGTM! Correct validation of `maybe_top_k_arr` against the `logits` reference tensor.

csrc/sampling.cu (6)

**51-51**: LGTM! Proper dtype validation for `maybe_indices` ensures the optional index tensor is `int32` before being cast and passed to the CUDA kernel.

**69-69**: LGTM! Consistent `int32` type validation for the indices tensor in `sampling_from_probs`.

**89-92**: LGTM! Both validations are correctly placed: a dtype check for `maybe_indices` and shape/batch validation for `maybe_top_p_arr` against the `probs` reference tensor.

**115-118**: LGTM! Consistent validation pattern for `top_k_sampling_from_probs`: indices dtype check and parameter tensor shape validation.

**141-144**: LGTM! Proper validation for `min_p_sampling_from_probs`, matching the pattern used in other sampling functions.

**170-174**: LGTM! The combined `top_k_top_p_sampling_from_probs` correctly validates all optional inputs: indices dtype and both parameter arrays for shape consistency.
Got it, sorry about that. Makes sense.
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
c3a3c48 to 52e562d
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Actionable comments posted: 0
♻️ Duplicate comments (2)
csrc/sampling.cu (1)
**24-42**: The `check_tensor_param` helper looks correct but is duplicated across .cu files. The shape/batch-size validation logic here is sound and, importantly, only reads `param.size(0)` when `ndim() == 1`, which addresses the earlier risk of touching `size(0)` on a 0D tensor. However, this helper is now duplicated in both `csrc/sampling.cu` and `csrc/renorm.cu`; consider moving it into `tvm_ffi_utils.h` (or another shared header) to avoid divergence in future edits.

csrc/renorm.cu (1)

**24-42**: The shared `check_tensor_param` logic is now safe for 0D tensors but should be centralized. This version correctly throws on 0D and >1D tensors and only inspects `size(0)` for the 1D case, avoiding the earlier out-of-bounds risk mentioned in prior review comments. Since the same helper now exists in both `csrc/renorm.cu` and `csrc/sampling.cu`, it would be cleaner to move it into a shared header (e.g., `tvm_ffi_utils.h`) and include it from both sites.
🧹 Nitpick comments (1)
tests/utils/test_sampling.py (1)
**664-667**: Use raw strings for the regex patterns to satisfy Ruff and clarify intent. These `match="Sampling parameter.*batch size mismatch"` patterns intentionally use `.*` as a regex, so they should be raw strings to both silence RUF043 and make the intent explicit.

```diff
-        with pytest.raises(ValueError, match="Sampling parameter.*batch size mismatch"):
+        with pytest.raises(ValueError, match=r"Sampling parameter.*batch size mismatch"):
```

Also applies to: 712-715
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 61b9af59fdc78e543edc327da7f29a21b89ea7fd and c3a3c48a762891858ae04ac64db9b4ed9623e409.
📒 Files selected for processing (4)
- csrc/renorm.cu (3 hunks)
- csrc/sampling.cu (7 hunks)
- csrc/tvm_ffi_utils.h (1 hunks)
- tests/utils/test_sampling.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/tvm_ffi_utils.h
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/renorm.cu (1)
csrc/sampling.cu (2)
- `check_tensor_param` (25-42)
- `check_tensor_param` (25-25)
csrc/sampling.cu (2)
csrc/nvshmem_binding.cu (4)
- `tensor` (59-63)
- `tensor` (59-59)
- `tensor` (64-64)
- `tensor` (64-64)

csrc/renorm.cu (2)
- `check_tensor_param` (25-42)
- `check_tensor_param` (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py
664-664: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
712-712: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🔇 Additional comments (4)
tests/utils/test_sampling.py (1)
**603-612**: The RuntimeError expectation and flexible regex for non-int32 indices look correct. Switching to `RuntimeError` here matches how TVM-style `CHECK_*` failures surface, and the regex covering multiple possible messages (`indices must have dtype...`, `Inconsistency of Tensor type...`, int64 vs. int32) makes the test robust to backend wording changes.

csrc/sampling.cu (2)

**67-72**: New dtype and shape checks on optional tensors align with the C++-side validation. Adding `CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32)` and routing the optional sampling-parameter tensors through `check_tensor_param(...)` gives consistent, early validation for indices/top-p/top-k/min-p across these entry points, and matches the updated Python tests (`RuntimeError` for non-int32 indices, `ValueError` for bad shapes/batch sizes).

Also applies to: 85-90, 103-112, 126-139, 152-165, 179-195

**143-147**: Verify the expected dtype for `maybe_top_k_arr` in `top_k_sampling_from_probs`. Here `maybe_top_k_arr` is passed to `TopKSamplingFromProb<float, int>` as `float*`, whereas in `top_k_top_p_sampling_from_probs` (this file) and the renorm kernels (`csrc/renorm.cu`) the same per-batch top-k tensor is treated as `int*`. That inconsistency strongly suggests `maybe_top_k_arr` is meant to be an int32 tensor everywhere, and this call should likely cast to `int*` instead:

```diff
-      has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
+      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
```

Please double-check the signature of `TopKSamplingFromProb` in `include/flashinfer/sampling.cuh` and the expected dtype of any Python-facing `top_k` tensor before changing.

csrc/renorm.cu (1)
**44-61**: The renorm and mask kernels now correctly validate optional sampling-parameter tensors. The added `check_tensor_param` calls in `top_p_renorm_probs`, `top_k_renorm_probs`, and `top_k_mask_logits` enforce the same 0D/2D/batch-size rules as the sampling kernels, which lines up with the new Python tests around tensor-parameter validation and ensures kernels never see inconsistent per-batch parameters.

Also applies to: 63-81, 83-101
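The risk these dtype checks guard against is bytewise reinterpretation: casting the data pointer to the wrong element type silently reads garbage instead of raising. A small illustration with Python's standard `struct` module, mimicking what a wrong `static_cast<float*>` over int32 storage would read:

```python
import struct

# Reinterpret the 4 bytes of the int32 value 5 as a float32.
as_float = struct.unpack("<f", struct.pack("<i", 5))[0]

# The result is a tiny denormal near zero, nothing like 5.0 -- exactly the
# kind of silent corruption the CHECK_INPUT_TYPE guards now surface as errors.
assert as_float != 5.0 and 0.0 < as_float < 1e-40
```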
Actionable comments posted: 0
🧹 Nitpick comments (5)
csrc/tvm_ffi_utils.h (2)
**246-248**: The `CHECK_LAST_DIM_CONTIGUOUS` macro looks malformed. `CHECK_LAST_DIM_CONTIGUOUS` is missing a stream insertion before the string literal, so the expansion will not compile once used. Consider aligning it with the other macros:

```diff
-#define CHECK_LAST_DIM_CONTIGUOUS(x) \
-  TVM_FFI_ICHECK_EQ(x.stride(-1), 1) \
-  #x "must be contiguous at last dimension";
+#define CHECK_LAST_DIM_CONTIGUOUS(x)  \
+  TVM_FFI_ICHECK_EQ(x.stride(-1), 1)  \
+      << #x " must be contiguous at last dimension";
```
**258-261**: The optional-input dtype helper is fine; just keep usage as a standalone statement. `CHECK_MAYBE_INPUT_TYPE` correctly gates `CHECK_INPUT_TYPE` on `has_value()` and matches the `Optional<TensorView>` usage in the CUDA entry points. Since it expands to a bare `if` block, it should only be used as a standalone statement (not directly under another `if` without braces), which is consistent with how the other macros here are used.

csrc/batch_decode_mla_plan.cu (1)
**18-18**: Good to add the dtype check; consider also enforcing device/contiguity. Adding `CHECK_INPUT_TYPE(indptr, dl_int32);` is the right guard before casting `indptr.data_ptr()` to `IdType*`. You might also consider `CHECK_INPUT_AND_TYPE(indptr, dl_int32);` (or equivalent CUDA/contiguity checks) so we fail fast if `indptr` is on the wrong device or non-contiguous before passing it into `DecodePlan`.

tests/utils/test_sampling.py (1)
**664-664**: Use raw regex strings for the `match` patterns (RUF043). The `"Sampling parameter.*batch size mismatch"` patterns intentionally use `.*` as a regex wildcard. To satisfy Ruff and make the intent explicit, consider raw strings:

```diff
-    with pytest.raises(ValueError, match="Sampling parameter.*batch size mismatch"):
+    with pytest.raises(ValueError, match=r"Sampling parameter.*batch size mismatch"):
```

Apply the same change at both call sites.

Also applies to: 712-712

csrc/sampling.cu (1)
csrc/sampling.cu (1)
**112-112**: Good use of `check_tensor_param`; consider also validating device if needed. Calling `check_tensor_param` on `maybe_top_p_arr`, `maybe_top_k_arr`, and `maybe_min_p_arr` before launching the kernels gives consistent, centralized validation for parameter tensor rank and batch-size alignment. If there's any chance these tensors could be created on a different device than `probs`, you might also add a lightweight device check (e.g., `CHECK_DEVICE(param, tensor)` inside `check_tensor_param`) to catch CPU/GPU mismatches early. Otherwise this looks solid.

Also applies to: 138-138, 164-164, 193-194
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between c3a3c48a762891858ae04ac64db9b4ed9623e409 and 5abee0a.
📒 Files selected for processing (12)
- csrc/batch_decode.cu (2 hunks)
- csrc/batch_decode_mla_plan.cu (1 hunks)
- csrc/batch_decode_mla_run.cu (1 hunks)
- csrc/batch_mla_plan.cu (1 hunks)
- csrc/batch_mla_run.cu (1 hunks)
- csrc/renorm.cu (3 hunks)
- csrc/sampling.cu (7 hunks)
- csrc/tvm_ffi_utils.h (1 hunks)
- flashinfer/decode.py (0 hunks)
- flashinfer/mla.py (0 hunks)
- flashinfer/sampling.py (1 hunks)
- tests/utils/test_sampling.py (3 hunks)
💤 Files with no reviewable changes (2)
- flashinfer/decode.py
- flashinfer/mla.py
🚧 Files skipped from review as they are similar to previous changes (6)
- flashinfer/sampling.py
- csrc/renorm.cu
- csrc/batch_mla_run.cu
- csrc/batch_decode.cu
- csrc/batch_decode_mla_run.cu
- csrc/batch_mla_plan.cu
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/sampling.cu (1)
csrc/renorm.cu (2)
- `check_tensor_param` (25-42)
- `check_tensor_param` (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py
664-664: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
712-712: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
tests/utils/test_sampling.py (1)
**604-607**: The exception expectation now matches the C++-side dtype validation. Switching this case to expect a `RuntimeError` with the generic `"Inconsistency of Tensor type.*maybe_indices"` message is consistent with the C++ `CHECK_INPUT_TYPE`/`CHECK_MAYBE_INPUT_TYPE` path, which surfaces as a runtime error rather than a Python `ValueError`.

csrc/sampling.cu (2)
**24-42**: The centralized sampling-parameter shape checks look correct. `check_tensor_param` cleanly enforces that:

- only scalar-or-1D parameter tensors are allowed, with clear error messages for 0D and >1D, and
- the 1D length must match the reference tensor batch size.

Using `TVM_FFI_THROW(ValueError)` here aligns with treating these as user argument errors and matches the updated tests.
**71-71**: Int32 dtype checks for optional indices are a solid safety improvement. Adding `CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);` across all sampling entry points ensures we fail fast on mis-typed index tensors before casting to `int*`, avoiding undefined behavior and surfacing a consistent error message instead.

Also applies to: 89-89, 109-109, 135-135, 161-161, 190-190
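The optional-input macro reduces to "validate only when present". A hypothetical Python mirror of that behavior (the names and message are modeled on, not copied from, the C++ side):

```python
from types import SimpleNamespace

def check_maybe_input_type(maybe_tensor, expected_dtype):
    """Mirror of CHECK_MAYBE_INPUT_TYPE: no-op for an absent optional,
    dtype check when the optional tensor is supplied."""
    if maybe_tensor is None:
        return
    if maybe_tensor.dtype != expected_dtype:
        raise RuntimeError(f"Inconsistency of Tensor type: expected {expected_dtype}, "
                           f"got {maybe_tensor.dtype}")

check_maybe_input_type(None, "int32")                            # absent optional: OK
check_maybe_input_type(SimpleNamespace(dtype="int32"), "int32")  # present, right dtype: OK
```

This is why callers that omit `maybe_indices` are unaffected, while callers passing an int64 tensor now get an immediate error instead of a misread index buffer.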
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Actionable comments posted: 0
🧹 Nitpick comments (1)
csrc/sampling_utils.h (1)
**24-41**: LGTM: Robust validation logic with clear error messages. The validation correctly checks dimensionality before accessing `size(0)`, addressing the concern raised in past reviews. The batch size validation will catch mismatches early with descriptive error messages.

Optional: consider clarifying the error message on lines 29-30. The phrase "or scalar" might be slightly confusing since PyTorch treats 0-dimensional tensors as scalars. However, since the API design expects scalar values to be passed as primitives (e.g., `top_p_val`, `top_k_val`) rather than 0-dim tensors, rejecting 0-dim tensors is correct behavior. Consider rephrasing to:

```diff
-      << "Expected a 1D tensor of shape (batch_size,) or scalar for the sampling parameter, "
+      << "Expected a 1D tensor of shape (batch_size,) for per-batch sampling parameter, "
```

This makes it clearer that the function expects either a 1D batch-specific tensor or no tensor at all (with scalar values passed separately).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- csrc/renorm.cu (4 hunks)
- csrc/sampling.cu (7 hunks)
- csrc/sampling_utils.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/sampling.cu (2)
flashinfer/logits_processor/types.py (1)
- `probs` (81-85)

csrc/sampling_utils.h (1)
- `check_tensor_param` (24-41)
csrc/renorm.cu (2)
csrc/sampling_utils.h (1)
- `check_tensor_param` (24-41)

flashinfer/logits_processor/types.py (2)
- `probs` (81-85)
- `logits` (74-78)
🪛 Clang (14.0.6)
csrc/sampling_utils.h
[error] 17-17: 'tvm/ffi/container/tensor.h' file not found
(clang-diagnostic-error)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (7)
csrc/renorm.cu (2)
**18-18**: LGTM: Clean dependency addition. The include for `sampling_utils.h` appropriately supports the new validation calls added in this file.

**31-31**: LGTM: Early validation prevents silent failures. The `check_tensor_param` calls appropriately validate optional sampling parameters before kernel launches, catching shape and batch size mismatches early with clear error messages.

Also applies to: 50-50, 70-70
csrc/sampling.cu (3)
**18-18**: LGTM: Include addition for validation utilities. Appropriately includes the shared validation utilities used throughout this file.

**52-52**: LGTM: Consistent type validation for optional indices. The `CHECK_MAYBE_INPUT_TYPE` macro appropriately validates that optional `maybe_indices` parameters have int32 dtype when present, catching type mismatches before kernel execution.

Also applies to: 70-70, 90-90, 116-116, 142-142, 171-171

**93-93**: LGTM: Comprehensive shape validation for sampling parameters. The `check_tensor_param` calls validate that optional per-batch sampling parameters (the top_p, top_k, and min_p arrays) are 1D tensors with batch sizes matching the reference tensor, preventing shape mismatches before kernel launches.

Also applies to: 119-119, 145-145, 174-175
csrc/sampling_utils.h (2)
**16-22**: LGTM: Clean header structure. The header guard, includes, and type aliases are appropriate for a shared validation utility.

**17-17**: Static analysis false positive: the TVM header is valid. The static analyzer reports the TVM FFI header as not found, but this is a false positive: the code compiles and tests pass. The analyzer simply lacks the TVM headers in its include path.
## 📌 Description

Originally, in flashinfer-ai#1652 and flashinfer-ai#2127, we added better error messaging to prevent silent failures on wrong tensor dtypes/shapes. However, those checks lived on the Python side when we can instead move them to the C++ side (actually the `.cu` side, I guess). We already have various checks there via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

- **Bug Fixes**
  - Native modules now perform stricter runtime validation of int index inputs and sampling-parameter shapes; some Python-level preflight checks were removed, deferring certain dtype/shape errors to lower-level code.
- **Refactor**
  - Added a centralized, reusable validation helper for optional sampling parameters to unify checks.
- **Tests**
  - Updated tests to expect different error types and more general error-message matching for the adjusted validation behavior.

Signed-off-by: Raayan Dhar <raayan.dhar@gmail.com>
<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>