Skip to content

feat: C++ side tensor validation#2160

Merged
yzh119 merged 5 commits intoflashinfer-ai:mainfrom
raayandhar:users/rdhar/cpp_side_validation
Dec 3, 2025
Merged

feat: C++ side tensor validation#2160
yzh119 merged 5 commits intoflashinfer-ai:mainfrom
raayandhar:users/rdhar/cpp_side_validation

Conversation

@raayandhar
Copy link
Copy Markdown
Contributor

@raayandhar raayandhar commented Dec 2, 2025

📌 Description

Originally in this PR #1652 and #2127 we added better error messaging / prevent silent failures for wrong dtype / shape of tensors. However, it was on the python side when we can instead move to the C++ side (actually the .cu side, I guess). We already have various checks here via macros, so it is somewhat natural.

🔍 Related Issues

See the issues in the PRs above.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Native modules now perform stricter runtime validation of int index inputs and sampling-parameter shapes; some Python-level preflight checks were removed, deferring certain dtype/shape errors to lower-level code.
  • Refactor

    • Added a centralized, reusable validation helper for optional sampling parameters to unify checks.
  • Tests

    • Updated tests to expect different error types and more general error-message matching for the adjusted validation behavior.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Dec 2, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds native runtime validations for index-like and optional sampling tensors (int32 type and scalar/1D batch consistency) in C++/CUDA and removes equivalent Python preflight checks; public API signatures remain unchanged.

Changes

Cohort / File(s) Summary
CUDA decode functions
csrc/batch_decode.cu, csrc/batch_decode_mla_plan.cu, csrc/batch_decode_mla_run.cu
Inserted CHECK_INPUT_TYPE(..., dl_int32) checks for various index-like inputs (indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len) at function entry points.
CUDA attention / MLA functions
csrc/batch_mla_plan.cu, csrc/batch_mla_run.cu
Added CHECK_INPUT_TYPE(..., dl_int32) validations for attention index tensors (qo_indptr, kv_indptr, kv_len, kv_indices) before workspace/plan computations and plan construction.
CUDA sampling & renorm helpers
csrc/sampling.cu, csrc/renorm.cu, csrc/sampling_utils.h
Introduced check_tensor_param (in csrc/sampling_utils.h) and applied check_tensor_param / CHECK_MAYBE_INPUT_TYPE in sampling and renorm codepaths to validate optional sampling params (scalar or 1D matching batch) and optional index tensors prior to kernel invocations.
TVM FFI utilities
csrc/tvm_ffi_utils.h
Added macro CHECK_MAYBE_INPUT_TYPE(maybe_x, st) to conditionally validate dtype for optional inputs when present.
Python API layer (flashinfer)
flashinfer/decode.py, flashinfer/mla.py, flashinfer/sampling.py
Removed Python-side preflight dtype/shape checks and helper functions for indices and sampling parameters, deferring validation to native C++/CUDA layers.
Tests
tests/utils/test_sampling.py
Adjusted expected error types/messages (broader regexes, RuntimeError expectations) to match moved validation behavior in native code.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Focus review on:
    • csrc/sampling.cu — ensure all sampling entry points consistently use check_tensor_param and CHECK_MAYBE_INPUT_TYPE.
    • csrc/renorm.cu — verify scalar vs 1D detection and batch-size consistency logic.
    • csrc/tvm_ffi_utils.h — confirm macro safety, edge cases, and error message consistency.
    • Python removals (flashinfer/*) — confirm downstream error surfaces and user-facing behavior align with tests.

Poem

🐰 I hopped from Python down to C,

I tucked my checks where kernels be.
Indices whisper thirty-two,
batches counted straight and true.
Hopping light, validations free.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 36.36% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Linked Issues check ❓ Inconclusive The PR description references related PRs #1652 and #2127 and states 'See the issues in the PRs above' but does not provide direct links to specific issues. Add explicit links to the related GitHub issues (not just PR references) to clarify which issues this PR addresses or resolves.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat: C++ side tensor validation' accurately reflects the main change: moving tensor validation checks from Python to C++ code.
Description check ✅ Passed The description includes context from related PRs and explains the motivation for moving validation. However, the PR checklist sections are present but the key content section is minimal.
Out of Scope Changes check ✅ Passed All changes are scoped to tensor validation logic in C++/CUDA files and corresponding Python binding removals, with test updates reflecting the validation move.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the tensor input validation mechanism within the FlashInfer library. It shifts the responsibility of checking tensor data types, dimensions, and batch sizes from the Python API to the underlying C++ (CUDA) implementation. This change aims to consolidate error handling, prevent silent failures, and ensure that invalid tensor configurations are identified and reported more effectively at the execution level, leading to a more robust and maintainable codebase.

Highlights

  • Moved Tensor Validation to C++: The core logic for validating tensor data types, dimensions, and batch sizes has been migrated from the Python frontend to the C++ (CUDA) backend for improved efficiency and error handling.
  • Enhanced C++ Validation Macros: New C++ macros, such as CHECK_INPUT_TYPE and CHECK_MAYBE_INPUT_TYPE, along with an inline function check_tensor_param, have been introduced or utilized to perform robust validation directly within the C++ kernel code.
  • Removed Redundant Python Checks: Corresponding Python-side validation functions and checks for tensor dtypes, dimensions, and batch sizes have been removed, streamlining the Python codebase and reducing redundancy.
  • Improved Error Handling: By centralizing validation in C++, errors related to incorrect tensor inputs will now be caught earlier at the runtime level, providing more precise RuntimeError messages from the C++ backend.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully moves tensor validation from Python to C++, which is a good improvement for performance and consistency. The changes are well-contained, and the tests are updated accordingly. I've identified a case of code duplication that affects maintainability and a potential bug in the new validation logic that could lead to a crash. My review comments address these issues.

Comment thread csrc/renorm.cu Outdated
Comment thread csrc/sampling.cu Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
csrc/sampling.cu (1)

24-42: Duplicate code with csrc/renorm.cu.

As noted in the review of csrc/renorm.cu, this function is duplicated. Consider extracting to csrc/tvm_ffi_utils.h.

🧹 Nitpick comments (3)
csrc/renorm.cu (1)

24-42: Duplicate code: check_tensor_param is duplicated in csrc/sampling.cu.

This helper function has an identical implementation in csrc/sampling.cu (lines 24-42). Consider moving it to csrc/tvm_ffi_utils.h alongside the other validation utilities to avoid duplication.

tests/utils/test_sampling.py (2)

634-676: Make match pattern for top‑p batch‑size mismatch a raw regex string

match="Sampling parameter.*batch size mismatch" uses . and * as regex metacharacters but isn’t a raw string, which is what Ruff RUF043 is flagging. Switching to a raw string keeps the intent and silences the lint:

-        with pytest.raises(
-            RuntimeError, match="Sampling parameter.*batch size mismatch"
-        ):
+        with pytest.raises(
+            RuntimeError, match=r"Sampling parameter.*batch size mismatch"
+        ):

(Based on static analysis hints.)


682-726: Likewise, use a raw regex for top‑k batch‑size mismatch

Same issue here: match="Sampling parameter.*batch size mismatch" is intended as a regex but isn’t raw, triggering RUF043. Mirroring the top‑p fix:

-        with pytest.raises(
-            RuntimeError, match="Sampling parameter.*batch size mismatch"
-        ):
+        with pytest.raises(
+            RuntimeError, match=r"Sampling parameter.*batch size mismatch"
+        ):

You could also optionally unify the min‑p mismatch check to use the same Sampling parameter.*batch size mismatch regex for future‑proofing, but that’s not required by this change.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 89e1adb and abc3a392e8a11a4badbcfa278621de723509dbef.

📒 Files selected for processing (12)
  • csrc/batch_decode.cu (2 hunks)
  • csrc/batch_decode_mla_plan.cu (1 hunks)
  • csrc/batch_decode_mla_run.cu (1 hunks)
  • csrc/batch_mla_plan.cu (1 hunks)
  • csrc/batch_mla_run.cu (1 hunks)
  • csrc/renorm.cu (3 hunks)
  • csrc/sampling.cu (7 hunks)
  • csrc/tvm_ffi_utils.h (1 hunks)
  • flashinfer/decode.py (0 hunks)
  • flashinfer/mla.py (0 hunks)
  • flashinfer/sampling.py (1 hunks)
  • tests/utils/test_sampling.py (10 hunks)
💤 Files with no reviewable changes (2)
  • flashinfer/mla.py
  • flashinfer/decode.py
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/sampling.cu (1)
csrc/renorm.cu (2)
  • check_tensor_param (25-42)
  • check_tensor_param (25-25)
csrc/renorm.cu (1)
csrc/sampling.cu (2)
  • check_tensor_param (25-42)
  • check_tensor_param (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py

665-665: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


715-715: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)

🔇 Additional comments (18)
csrc/tvm_ffi_utils.h (1)

257-260: LGTM!

The macro correctly handles optional tensor type validation, following the existing pattern of validation macros in this file.

csrc/batch_decode.cu (2)

45-46: LGTM!

Early type validation for the indptr tensor ensures proper dtype before proceeding with the plan computation.


91-94: LGTM!

Consistent type validation for all paged KV cache index tensors. These checks align with the subsequent casts to IdType*.

csrc/renorm.cu (3)

50-50: LGTM!

Validation ensures the optional top-p parameter tensor has the correct shape and batch size.


69-69: LGTM!

Consistent parameter validation for the top-k array.


89-89: LGTM!

Validation correctly uses logits as the reference tensor for batch size comparison.

flashinfer/sampling.py (1)

19-19: LGTM!

Import cleanup reflects the removal of Python-side validation helpers. Validation is now handled in the C++ layer.

csrc/sampling.cu (6)

71-71: LGTM!

Validates optional indices tensor dtype before use.


89-89: LGTM!

Consistent index type validation.


109-112: LGTM!

Both index type and top-p parameter validation are properly placed.


135-138: LGTM!

Consistent validation pattern for top-k sampling.


161-164: LGTM!

Consistent validation pattern for min-p sampling.


190-194: LGTM!

Validates indices, top-k, and top-p parameters for the combined sampling function.

tests/utils/test_sampling.py (1)

575-629: Min‑p tensor validation test updates look consistent with new C++ checks

Using RuntimeError plus the more permissive dtype error regex for indices and distinct cases for scalar/1D/2D shapes gives good coverage of the new C++ validation paths. I don’t see issues in this block.

csrc/batch_decode_mla_run.cu (1)

20-22: Int32 checks for paged‑KV metadata are in the right place

Validating paged_kv_indptr, paged_kv_indices, and paged_kv_last_page_len as dl_int32 before constructing DecodePlanInfo and setting up paged_kv_mla_t is consistent with the rest of the decode path and prevents misuse of non‑int32 tensors as index buffers. Looks good.

csrc/batch_decode_mla_plan.cu (1)

18-19: Plan‑side indptr type check is consistent with downstream usage

The CHECK_INPUT_TYPE(indptr, dl_int32); at function entry matches the IdType* usage in DecodePlan and aligns with other newly‑added decode checks. No issues here.

csrc/batch_mla_plan.cu (1)

32-35: Good: MLA plan now validates all index/length tensors as int32

Adding CHECK_INPUT_TYPE for qo_indptr, kv_indptr, and kv_len up front is aligned with how these buffers are used (IdType*) and with the broader move of validation into C++/CUDA. Change looks correct and non‑disruptive.

csrc/batch_mla_run.cu (1)

40-41: Int32 validation for kv_indices matches MLA run expectations

The new CHECK_INPUT_TYPE(kv_indices, dl_int32); is appropriately placed before plan/materialization and consistent with the IdType* usage in params.kv_indices. This should surface dtype mistakes early without affecting correct callers.

@raayandhar raayandhar force-pushed the users/rdhar/cpp_side_validation branch from abc3a39 to 6c2d7c9 Compare December 2, 2025 17:51
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
csrc/sampling.cu (1)

24-42: check_tensor_param matches renorm.cu logic; consider de‑duplicating into a shared header.

The helper correctly enforces that sampling parameter tensors are scalar or 1D and have batch size matching the reference tensor, with error messages aligned to the updated tests. However, an essentially identical check_tensor_param already exists in csrc/renorm.cu; moving this helper into a shared header like tvm_ffi_utils.h and reusing it in both places would avoid future drift and centralize the sampling‑parameter contract.

🧹 Nitpick comments (3)
tests/utils/test_sampling.py (2)

644-666: top‑p tensor‑param tests align with C++ checks; consider raw string for regex.

The move to RuntimeError and the regex patterns for 2D, 0D, and batch‑size mismatch cases line up with the new check_tensor_param helper in C++.

To satisfy Ruff’s RUF043 and make it explicit that match is a regex, you can switch the last pattern to a raw string:

with pytest.raises(
    RuntimeError,
-   match="Sampling parameter.*batch size mismatch",
+   match=r"Sampling parameter.*batch size mismatch",
):
    ...

694-716: top‑k tensor‑param tests mirror the new C++ validation; same raw‑string tweak applies.

The updated RuntimeError expectations and regexes for shape/batch mismatches are consistent with the C++ check_tensor_param behavior.

As with the top‑p tests, consider making the final match a raw string to quiet RUF043 and clarify intent:

with pytest.raises(
    RuntimeError,
-   match="Sampling parameter.*batch size mismatch",
+   match=r"Sampling parameter.*batch size mismatch",
):
    ...
csrc/sampling.cu (1)

71-71: New C++‑side dtype and shape validations for sampling look correct and align with Python/tests.

  • CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32) in sampling_from_logits, sampling_from_probs, and all sampling variants correctly enforces that optional indices are int32, matching how they’re cast to int* in the kernels and how Python tests now assert RuntimeError for non‑int32 indices.
  • check_tensor_param uses in top‑p, top‑k, min‑p, and joint top‑k/top‑p functions properly restrict parameter tensors to 1D (or scalar via the separate scalar argument) and enforce param.size(0) == probs.size(0), which matches the new tests around 2D, 0D, and batch‑size mismatches.

One small consistency nit: top_k_sampling_from_probs, min_p_sampling_from_probs, and top_k_top_p_sampling_from_probs validate both probs and output (including device consistency), whereas top_p_sampling_from_probs only checks probs. For symmetry and slightly better diagnostics, you might also add CHECK_INPUT(output) / CHECK_DEVICE(output, probs) there, though this is a pre‑existing gap and not introduced by this change.

Also applies to: 89-89, 109-113, 136-139, 161-165, 190-195

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between abc3a392e8a11a4badbcfa278621de723509dbef and 6c2d7c987ba754925f30948a61868bb050d3de20.

📒 Files selected for processing (12)
  • csrc/batch_decode.cu (2 hunks)
  • csrc/batch_decode_mla_plan.cu (1 hunks)
  • csrc/batch_decode_mla_run.cu (1 hunks)
  • csrc/batch_mla_plan.cu (1 hunks)
  • csrc/batch_mla_run.cu (1 hunks)
  • csrc/renorm.cu (3 hunks)
  • csrc/sampling.cu (7 hunks)
  • csrc/tvm_ffi_utils.h (1 hunks)
  • flashinfer/decode.py (0 hunks)
  • flashinfer/mla.py (0 hunks)
  • flashinfer/sampling.py (1 hunks)
  • tests/utils/test_sampling.py (10 hunks)
💤 Files with no reviewable changes (2)
  • flashinfer/mla.py
  • flashinfer/decode.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • csrc/batch_mla_plan.cu
  • csrc/batch_decode_mla_run.cu
  • csrc/tvm_ffi_utils.h
  • csrc/renorm.cu
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/sampling.cu (1)
csrc/renorm.cu (2)
  • check_tensor_param (25-42)
  • check_tensor_param (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py

665-665: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


715-715: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
csrc/batch_decode_mla_plan.cu (1)

18-18: Early int32 check for indptr is correct and well‑placed.

Validating indptr with CHECK_INPUT_TYPE(indptr, dl_int32) before passing it as IdType* into DecodePlan gives a clear, early failure for wrong dtypes and matches the kernel’s expectations. No further changes needed here.

csrc/batch_mla_run.cu (1)

40-40: Int32 validation on kv_indices matches kernel expectations.

The new CHECK_INPUT_TYPE(kv_indices, dl_int32) is consistent with using kv_indices as IdType* in the dispatched MLA kernel and will now fail fast on incorrect dtypes instead of producing undefined behavior.

csrc/batch_decode.cu (1)

45-45: Dtype checks for indptr and paged‑KV index tensors are appropriate safeguards.

The CHECK_INPUT_TYPE(..., dl_int32) calls on indptr, paged_kv_indptr, paged_kv_indices, and paged_kv_last_page_len match their use as IdType* in the decode plan/run paths and will surface dtype mismatches as clear runtime errors rather than silent corruption. This is a solid move of validation into the C++ layer.

Also applies to: 91-93

tests/utils/test_sampling.py (1)

585-618: Min‑p validation tests correctly target C++ RuntimeError behavior.

Switching these expectations to RuntimeError and broadening the regex to allow both legacy and new message forms (including the int32 index checks) matches the new C++/TVM‑FFI validation path. The test coverage for 2D/0D/mismatched‑batch and non‑int32 indices looks comprehensive.

flashinfer/sampling.py (1)

20-20: Typing imports now accurately reflect usage after validation refactor.

Switching to from typing import Optional, Tuple, Union is correct given the current annotations, and the removal of Python‑side validation in favor of C++ checks keeps the wrapper thin while tests cover the new error surface.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
csrc/tvm_ffi_utils.h (1)

241-260: Validation logic is sound; minor clarity improvement possible.

The guard if (param.ndim() == 0) before accessing param.size(0) correctly prevents out-of-bounds access (addressing the prior review concern). The logic properly rejects invalid shapes.

One minor observation: the error message mentions "or scalar" as acceptable, but the function rejects 0-dimensional tensors. This is likely intentional (users should use the scalar value parameter like top_p_val instead of a 0-dim tensor), but could be clarified:

-          << "Expected a 1D tensor of shape (batch_size,) or scalar for the sampling parameter, "
+          << "Expected a 1D tensor of shape (batch_size,) for the sampling parameter array, "
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6c2d7c987ba754925f30948a61868bb050d3de20 and 61b9af59fdc78e543edc327da7f29a21b89ea7fd.

📒 Files selected for processing (3)
  • csrc/renorm.cu (3 hunks)
  • csrc/sampling.cu (6 hunks)
  • csrc/tvm_ffi_utils.h (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/sampling.cu (1)
csrc/tvm_ffi_utils.h (1)
  • check_tensor_param (241-260)
csrc/renorm.cu (1)
csrc/tvm_ffi_utils.h (1)
  • check_tensor_param (241-260)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
csrc/tvm_ffi_utils.h (1)

279-282: LGTM!

The macro correctly wraps the existing CHECK_INPUT_TYPE with an optional guard, providing consistent dtype validation for optional tensor inputs.

csrc/renorm.cu (3)

30-30: LGTM!

Early validation of maybe_top_p_arr against probs ensures shape and batch size consistency before any CUDA kernel execution, providing clear error messages for misconfigured inputs.


49-49: LGTM!

Consistent application of check_tensor_param for maybe_top_k_arr validation.


69-69: LGTM!

Correct validation of maybe_top_k_arr against logits reference tensor.

csrc/sampling.cu (6)

51-51: LGTM!

Proper dtype validation for maybe_indices ensures the optional index tensor is int32 before being cast and passed to the CUDA kernel.


69-69: LGTM!

Consistent int32 type validation for the indices tensor in sampling_from_probs.


89-92: LGTM!

Both validations are correctly placed: dtype check for maybe_indices and shape/batch validation for maybe_top_p_arr against the probs reference tensor.


115-118: LGTM!

Consistent validation pattern for top_k_sampling_from_probs: indices dtype check and parameter tensor shape validation.


141-144: LGTM!

Proper validation for min_p_sampling_from_probs matching the pattern used in other sampling functions.


170-174: LGTM!

The combined top_k_top_p_sampling_from_probs correctly validates all optional inputs: indices dtype and both parameter arrays for shape consistency.

@raayandhar raayandhar changed the title feat: move tensor validation to c++ feat: device-side tensor validation Dec 2, 2025
@yzh119 yzh119 changed the title feat: device-side tensor validation feat: C++ side tensor validation Dec 2, 2025
Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed from "device-side" to "C++ side", when we say device-side, it usually refers to something happened inside kernels.

cc @cyx-6 to comment on the error handling.

@raayandhar
Copy link
Copy Markdown
Contributor Author

Renamed from "device-side" to "C++ side", when we say device-side, it usually refers to something happened inside kernels.

cc @cyx-6 to comment on the error handling.

Got it, sorry about that. Makes sense.

Comment thread csrc/tvm_ffi_utils.h Outdated
Comment thread tests/utils/test_sampling.py
Comment thread csrc/tvm_ffi_utils.h Outdated
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
@raayandhar raayandhar force-pushed the users/rdhar/cpp_side_validation branch from c3a3c48 to 52e562d Compare December 2, 2025 22:28
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
@raayandhar raayandhar requested review from tqchen and yzh119 December 2, 2025 22:34
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
csrc/sampling.cu (1)

24-42: check_tensor_param helper looks correct but is duplicated across .cu files

The shape/batch-size validation logic here is sound and, importantly, only reads param.size(0) when ndim() == 1, which addresses the earlier risk of touching size(0) on a 0D tensor. However, this helper is now duplicated in both csrc/sampling.cu and csrc/renorm.cu; consider moving it into tvm_ffi_utils.h (or another shared header) to avoid divergence in future edits.

csrc/renorm.cu (1)

24-42: Shared check_tensor_param logic now safe for 0D tensors but should be centralized

This version correctly throws on 0D and >1D tensors and only inspects size(0) for the 1D case, avoiding the earlier out-of-bounds risk mentioned in prior review comments. Since the same helper now exists in both csrc/renorm.cu and csrc/sampling.cu, it would be cleaner to move it into a shared header (e.g., tvm_ffi_utils.h) and include it from both sites.

🧹 Nitpick comments (1)
tests/utils/test_sampling.py (1)

664-667: Use raw strings for regex patterns to satisfy Ruff and clarify intent

These match="Sampling parameter.*batch size mismatch" patterns intentionally use .* as regex, so they should be raw strings to both silence RUF043 and make the intent explicit.

-        with pytest.raises(ValueError, match="Sampling parameter.*batch size mismatch"):
+        with pytest.raises(ValueError, match=r"Sampling parameter.*batch size mismatch"):
...
-        with pytest.raises(ValueError, match="Sampling parameter.*batch size mismatch"):
+        with pytest.raises(ValueError, match=r"Sampling parameter.*batch size mismatch"):

Also applies to: 712-715

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 61b9af59fdc78e543edc327da7f29a21b89ea7fd and c3a3c48a762891858ae04ac64db9b4ed9623e409.

📒 Files selected for processing (4)
  • csrc/renorm.cu (3 hunks)
  • csrc/sampling.cu (7 hunks)
  • csrc/tvm_ffi_utils.h (1 hunks)
  • tests/utils/test_sampling.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/tvm_ffi_utils.h
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/renorm.cu (1)
csrc/sampling.cu (2)
  • check_tensor_param (25-42)
  • check_tensor_param (25-25)
csrc/sampling.cu (2)
csrc/nvshmem_binding.cu (4)
  • tensor (59-63)
  • tensor (59-59)
  • tensor (64-64)
  • tensor (64-64)
csrc/renorm.cu (2)
  • check_tensor_param (25-42)
  • check_tensor_param (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py

664-664: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


712-712: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)

🔇 Additional comments (4)
tests/utils/test_sampling.py (1)

603-612: RuntimeError expectation and flexible regex for non-int32 indices look correct

Switching to RuntimeError here matches how TVM-style CHECK_* failures surface, and the regex covering multiple possible messages (indices must have dtype..., Inconsistency of Tensor type..., int64 vs. int32) makes the test robust to backend wording changes.

csrc/sampling.cu (2)

67-72: New dtype and shape checks on optional tensors align with C++-side validation

Adding CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32) and routing the optional sampling-parameter tensors through check_tensor_param(...) gives consistent, early validation for indices/top-p/top-k/min-p across these entry points, and matches the updated Python tests (RuntimeError for non-int32 indices, ValueError for bad shapes/batch sizes).

Also applies to: 85-90, 103-112, 126-139, 152-165, 179-195


143-147: Verify expected dtype for maybe_top_k_arr in top_k_sampling_from_probs

Here maybe_top_k_arr is passed to TopKSamplingFromProb<float, int> as float*, whereas in top_k_top_p_sampling_from_probs (this file) and the renorm kernels (csrc/renorm.cu) the same per-batch top-k tensor is treated as int*. That inconsistency strongly suggests maybe_top_k_arr is meant to be an int32 tensor everywhere, and this call should likely cast to int* instead:

-      has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
+      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,

Please double-check the signature of TopKSamplingFromProb in include/flashinfer/sampling.cuh and the expected dtype of any Python-facing top_k tensor before changing.

csrc/renorm.cu (1)

44-61: Renorm and mask kernels now correctly validate optional sampling-parameter tensors

The added check_tensor_param calls in top_p_renorm_probs, top_k_renorm_probs, and top_k_mask_logits enforce the same 0D/2D/batch-size rules as the sampling kernels, which lines up with the new Python tests around tensor-parameter validation and ensures kernels never see inconsistent per-batch parameters.

Also applies to: 63-81, 83-101

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (5)
csrc/tvm_ffi_utils.h (2)

246-248: CHECK_LAST_DIM_CONTIGUOUS macro looks malformed

CHECK_LAST_DIM_CONTIGUOUS is missing a stream insertion before the string literal, so expansion will not compile once used. Consider aligning it with other macros:

-#define CHECK_LAST_DIM_CONTIGUOUS(x) \
-  TVM_FFI_ICHECK_EQ(x.stride(-1), 1) \
-  #x "must be contiguous at last dimension";
+#define CHECK_LAST_DIM_CONTIGUOUS(x)                                             \
+  TVM_FFI_ICHECK_EQ(x.stride(-1), 1)                                             \
+      << #x " must be contiguous at last dimension";

258-261: Optional-input dtype helper is fine; just keep usage as a standalone statement

CHECK_MAYBE_INPUT_TYPE correctly gates CHECK_INPUT_TYPE on has_value() and matches Optional<TensorView> usage in the CUDA entry points. Since it expands to a bare if block, it should only be used as a standalone statement (not directly under another if without braces), which is consistent with how other macros here are used.

csrc/batch_decode_mla_plan.cu (1)

18-18: Good to add dtype check; consider also enforcing device/contiguity

Adding CHECK_INPUT_TYPE(indptr, dl_int32); is the right guard before casting indptr.data_ptr() to IdType*. You might also consider CHECK_INPUT_AND_TYPE(indptr, dl_int32); (or equivalent CUDA/contiguity checks) so we fail fast if indptr is on the wrong device or non-contiguous before passing it into DecodePlan.

tests/utils/test_sampling.py (1)

664-664: Use raw regex strings for the match patterns (RUF043)

The "Sampling parameter.*batch size mismatch" patterns intentionally use .* as a regex wildcard. To satisfy Ruff and make the intent explicit, consider raw strings:

-        with pytest.raises(ValueError, match="Sampling parameter.*batch size mismatch"):
+        with pytest.raises(ValueError, match=r"Sampling parameter.*batch size mismatch"):

Apply the same change at both call sites.

Also applies to: 712-712

csrc/sampling.cu (1)

112-112: Good use of check_tensor_param; consider also validating device if needed

Calling check_tensor_param on maybe_top_p_arr, maybe_top_k_arr, and maybe_min_p_arr before launching the kernels gives consistent, centralized validation for parameter tensor rank and batch-size alignment.

If there’s any chance these tensors could be created on a different device than probs, you might also add a lightweight device check (e.g., CHECK_DEVICE(param, tensor) inside check_tensor_param) to catch CPU/GPU mismatches early. Otherwise this looks solid.

Also applies to: 138-138, 164-164, 193-194

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c3a3c48a762891858ae04ac64db9b4ed9623e409 and 5abee0a.

📒 Files selected for processing (12)
  • csrc/batch_decode.cu (2 hunks)
  • csrc/batch_decode_mla_plan.cu (1 hunks)
  • csrc/batch_decode_mla_run.cu (1 hunks)
  • csrc/batch_mla_plan.cu (1 hunks)
  • csrc/batch_mla_run.cu (1 hunks)
  • csrc/renorm.cu (3 hunks)
  • csrc/sampling.cu (7 hunks)
  • csrc/tvm_ffi_utils.h (1 hunks)
  • flashinfer/decode.py (0 hunks)
  • flashinfer/mla.py (0 hunks)
  • flashinfer/sampling.py (1 hunks)
  • tests/utils/test_sampling.py (3 hunks)
💤 Files with no reviewable changes (2)
  • flashinfer/decode.py
  • flashinfer/mla.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • flashinfer/sampling.py
  • csrc/renorm.cu
  • csrc/batch_mla_run.cu
  • csrc/batch_decode.cu
  • csrc/batch_decode_mla_run.cu
  • csrc/batch_mla_plan.cu
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/sampling.cu (1)
csrc/renorm.cu (2)
  • check_tensor_param (25-42)
  • check_tensor_param (25-25)
🪛 Ruff (0.14.7)
tests/utils/test_sampling.py

664-664: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


712-712: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
tests/utils/test_sampling.py (1)

604-607: Exception expectation now matches C++-side dtype validation

Switching this case to expect a RuntimeError with the generic "Inconsistency of Tensor type.*maybe_indices" message is consistent with the C++ CHECK_INPUT_TYPE / CHECK_MAYBE_INPUT_TYPE path, which surfaces as a runtime error rather than a Python ValueError.

csrc/sampling.cu (2)

24-42: Centralized sampling-parameter shape checks look correct

check_tensor_param cleanly enforces:

  • only scalar-or-1D parameter tensors are allowed, with clear error messages for 0D and >1D, and
  • 1D length must match the reference tensor batch size.

Using TVM_FFI_THROW(ValueError) here aligns with treating these as user argument errors and matches the updated tests.


71-71: Int32 dtype checks for optional indices are a solid safety improvement

Adding CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32); across all sampling entry points ensures we fail fast on mis-typed index tensors before casting to int*, avoiding UB and surfacing a consistent error message instead.

Also applies to: 89-89, 109-109, 135-135, 161-161, 190-190

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
csrc/sampling_utils.h (1)

24-41: LGTM: Robust validation logic with clear error messages.

The validation correctly checks dimensionality before accessing size(0), addressing the concern raised in past reviews. The batch size validation will catch mismatches early with descriptive error messages.

Optional: Consider clarifying the error message on lines 29-30.

The phrase "or scalar" might be slightly confusing since PyTorch treats 0-dimensional tensors as scalars. However, since the API design expects scalar values to be passed as primitives (e.g., top_p_val, top_k_val) rather than 0-dim tensors, rejecting 0-dim tensors is correct behavior. Consider rephrasing to:

-          << "Expected a 1D tensor of shape (batch_size,) or scalar for the sampling parameter, "
+          << "Expected a 1D tensor of shape (batch_size,) for per-batch sampling parameter, "
           << "but got a 0-dimensional tensor.";

This makes it clearer that the function expects either a 1D batch-specific tensor or no tensor at all (with scalar values passed separately).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5abee0a and a362384.

📒 Files selected for processing (3)
  • csrc/renorm.cu (4 hunks)
  • csrc/sampling.cu (7 hunks)
  • csrc/sampling_utils.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/sampling.cu (2)
flashinfer/logits_processor/types.py (1)
  • probs (81-85)
csrc/sampling_utils.h (1)
  • check_tensor_param (24-41)
csrc/renorm.cu (2)
csrc/sampling_utils.h (1)
  • check_tensor_param (24-41)
flashinfer/logits_processor/types.py (2)
  • probs (81-85)
  • logits (74-78)
🪛 Clang (14.0.6)
csrc/sampling_utils.h

[error] 17-17: 'tvm/ffi/container/tensor.h' file not found

(clang-diagnostic-error)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
csrc/renorm.cu (2)

18-18: LGTM: Clean dependency addition.

The include for sampling_utils.h appropriately supports the new validation calls added in this file.


31-31: LGTM: Early validation prevents silent failures.

The check_tensor_param calls appropriately validate optional sampling parameters before kernel launches, catching shape and batch size mismatches early with clear error messages.

Also applies to: 50-50, 70-70

csrc/sampling.cu (3)

18-18: LGTM: Include addition for validation utilities.

Appropriately includes the shared validation utilities used throughout this file.


52-52: LGTM: Consistent type validation for optional indices.

The CHECK_MAYBE_INPUT_TYPE macro appropriately validates that optional maybe_indices parameters have int32 dtype when present, catching type mismatches before kernel execution.

Also applies to: 70-70, 90-90, 116-116, 142-142, 171-171


93-93: LGTM: Comprehensive shape validation for sampling parameters.

The check_tensor_param calls validate that optional per-batch sampling parameters (top_p, top_k, min_p arrays) are 1D tensors with batch sizes matching the reference tensor, preventing shape mismatches before kernel launches.

Also applies to: 119-119, 145-145, 174-175

csrc/sampling_utils.h (2)

16-22: LGTM: Clean header structure.

The header guard, includes, and type aliases are appropriate for a shared validation utility.


17-17: Static analysis false positive: TVM header is valid.

The static analyzer reports the TVM FFI header as not found, but this is a false positive—the code compiles and tests pass. The analyzer simply lacks the TVM headers in its include path.

Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@yzh119 yzh119 merged commit 9ac59e5 into flashinfer-ai:main Dec 3, 2025
4 checks passed
juju812 pushed a commit to juju812/flashinfer that referenced this pull request Dec 4, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Originally in this PR
flashinfer-ai#1652 and
flashinfer-ai#2127 we added better
error messaging / prevent silent failures for wrong dtype / shape of
tensors. However, it was on the python side when we can instead move to
the C++ side (actually the `.cu` side, I guess). We already have various
checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Native modules now perform stricter runtime validation of int index
inputs and sampling-parameter shapes; some Python-level preflight checks
were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
* Added a centralized, reusable validation helper for optional sampling
parameters to unify checks.

* **Tests**
* Updated tests to expect different error types and more general
error-message matching for the adjusted validation behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Originally in this PR
flashinfer-ai#1652 and
flashinfer-ai#2127 we added better
error messaging / prevent silent failures for wrong dtype / shape of
tensors. However, it was on the python side when we can instead move to
the C++ side (actually the `.cu` side, I guess). We already have various
checks here via macros, so it is somewhat natural.

## 🔍 Related Issues

See the issues in the PRs above.

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Native modules now perform stricter runtime validation of int index
inputs and sampling-parameter shapes; some Python-level preflight checks
were removed, deferring certain dtype/shape errors to lower-level code.

* **Refactor**
* Added a centralized, reusable validation helper for optional sampling
parameters to unify checks.

* **Tests**
* Updated tests to expect different error types and more general
error-message matching for the adjusted validation behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants