
Vendor CCCL v3.3.2 from GitHub instead of relying on CTK-bundled copy #3091

Merged: kahyunnam merged 7 commits into flashinfer-ai:main from kahyunnam:knam/cccl-submodule on Apr 22, 2026
Conversation

@kahyunnam (Member) commented Apr 16, 2026

📌 Description

Pin CCCL (CUB, libcudacxx, Thrust) to a specific release tag as a git submodule under 3rdparty/cccl, replacing the implicit dependency on whatever version ships with the user's CUDA Toolkit. This enables the CCCL team to land improvements (e.g., TopK, DeviceTransform, fast_mod_div) independently of CTK releases, and lets FlashInfer adopt new CCCL features immediately.

Vendoring infrastructure (c2cf0708)

  • Add 3rdparty/cccl submodule at CCCL v3.3.2 (maps to CTK 13.2)
  • Wire vendored CCCL into JIT include paths using -I (not -isystem) so it takes precedence over CTK headers, per CCCL guidelines
  • Remove $cuda_home/include/cccl from system includes
  • Package CCCL headers (cub, libcudacxx, thrust) into the wheel via pyproject.toml and build_backend.py
  • Update modal_runner.py fallback clone for CCCL
  • Remove dead #if CUDA_VERSION guards for cub::Max/cub::Min which no longer exist in CCCL 3.x; unconditionally use cuda::maximum/cuda::minimum
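For illustration, the last bullet affects block-level reductions like the one below. This is a hedged sketch (the kernel name, shapes, and buffer layout are made up, not code from sampling.cuh); it only shows cuda::maximum<> used directly where cub::Max used to be selected behind a CUDA_VERSION guard.

#include <cub/block/block_reduce.cuh>
#include <cuda/functional>  // cuda::maximum / cuda::minimum

template <int BLOCK_THREADS>
__global__ void BlockMaxKernel(const float* in, float* out) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float thread_val = in[blockIdx.x * BLOCK_THREADS + threadIdx.x];
  // With vendored CCCL 3.x, cuda::maximum<> is always available, so no
  // #if CUDA_VERSION fallback to cub::Max() is needed.
  float block_max = BlockReduce(temp_storage).Reduce(thread_val, cuda::maximum<>{});
  if (threadIdx.x == 0) out[blockIdx.x] = block_max;
}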

Adopt cub::DeviceTransform for LSE computation (9fda09f2)

  • Replace ComputeLSEFromMDKernel (hand-rolled element-wise kernel with manual PDL asm, launch config, bounds checking) with a single cub::DeviceTransform::Transform call
  • DeviceTransform automatically provides PDL on SM90+, vectorized loads, software prefetch, auto-tuned occupancy, and bulk copy (TMA) on Hopper+
  • Uses a named functor (MDToLSE) instead of a lambda to work around an nvcc name-mangling bug
  • log2f replaces the PTX-only math::ptx_log2 since the functor must be __host__ __device__; with -use_fast_math, nvcc emits the same lg2.approx.ftz.f32 instruction on device
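A minimal sketch of the resulting shape (the functor body is illustrative; the exact (m, d) to LSE expression lives in lse.cuh):

#include <cmath>
#include <cuda_runtime.h>
#include <cub/device/device_transform.cuh>

// Named functor rather than a lambda, per the nvcc name-mangling workaround above.
struct MDToLSE {
  __host__ __device__ float operator()(float2 md) const {
    // Illustrative only: md.x holds the running max, md.y the running sum;
    // the real expression is defined in lse.cuh.
    return md.x + log2f(md.y);
  }
};

inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool /*launch_with_pdl*/,
                                    cudaStream_t stream) {
  // One call replaces the hand-rolled kernel; CUB picks the launch configuration
  // and applies PDL, vectorized loads, and TMA where the hardware supports them.
  return cub::DeviceTransform::Transform(md, lse, n, MDToLSE{}, stream);
}

The unused launch_with_pdl parameter is kept only for API stability; as noted in the review discussion further down, DeviceTransform manages PDL internally.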

Adopt cuda::fast_mod_div for fast integer division (ac916940, ddd77d1c)

  • Replace the FastModDivInt32 polyfill in kernelParams.h with a type alias to cuda::fast_mod_div<int32_t> — the struct it was explicitly polyfilling (binary layout is identical)
  • Reimplement flashinfer::uint_fastdiv in fastdiv.cuh as a thin API-compatible wrapper around cuda::fast_mod_div<uint32_t>, preserving the default constructor, implicit conversions, and divmod() method used by ~30 call sites in the attention, page, MLA, and RoPE kernels
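A rough sketch of the wrapper shape described in the bullet above (not the actual fastdiv.cuh code; the member names, the divisor-of-1 default, and the host/device annotations are assumptions):

#include <cstdint>
#include <cuda/cmath>  // cuda::fast_mod_div

struct uint_fastdiv {
  cuda::fast_mod_div<uint32_t> impl;  // precomputed multiply-and-shift state
  uint32_t d;                         // raw divisor, returned by the implicit conversion

  // cuda::fast_mod_div deletes its default constructor, so the wrapper supplies
  // one (divisor 1 is just a placeholder in this sketch).
  uint_fastdiv() : impl(1), d(1) {}
  uint_fastdiv(uint32_t divisor) : impl(divisor), d(divisor) {}

  __host__ __device__ operator uint32_t() const { return d; }

  __host__ __device__ void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
    q = n / impl;   // fast division via CCCL's operator/
    r = n - q * d;  // remainder recovered from the quotient
  }
};

Call sites that construct a uint_fastdiv from a divisor and then call divmod(i, q, r) keep compiling unchanged.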

Not changed

  • trtllm::dev::IntFastDiv (MoE routing) — vendored TRT-LLM code with different binary layout and negative divisor support that cuda::fast_mod_div doesn't provide
  • NV-internal / cuDNN / CUTLASS fast-divmod implementations — external codebases, best left alone

📊 Performance

Benchmarked top_k (BF16, non-deterministic, random input) across 162 configurations (batch ∈ {1, 16, 64, 256, 2048, 4096}, seq_len ∈ {256..524288}, k ∈ {256..4096}) on a B200 with CUDA 13.0.
CCCL 3.3.2 (vendored) vs CCCL 3.0 (CTK-bundled):

  • Mean speedup: 1.00x | Median: 1.00x
  • No regressions or improvements beyond measurement noise (±2% at small problem sizes)
  • Scatter plot shows all 162 points on the y=x diagonal
    This is expected — CCCL 3.0 → 3.3.2 is only 3 minor versions apart, and FlashInfer's TopK uses block-level CUB primitives which are stable across minor releases. Users on older CTKs (e.g., 12.6 shipping CCCL 2.5) would see larger gains from the version jump. The primary value of this PR is infrastructure: FlashInfer can now adopt new CCCL features (TopK improvements, DeviceTransform, fast_mod_div, segmented scan) independently of CTK releases, and the CCCL team can land targeted optimizations by bumping the submodule tag.
[image: scatter plot of per-configuration top_k latency, vendored CCCL 3.3.2 vs CTK-bundled CCCL 3.0]

🔍 Related Issues

#3096

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • test_trtllm_gen_prefill passes on SM100 (exercises DeviceTransform LSE, FastModDivInt32, and uint_fastdiv code paths)
  • test_batch_prefill_with_paged_kv_cache passes on SM80+ (exercises uint_fastdiv in core attention)
  • All pre-commit hooks pass (clang-format, ruff, mypy)

Reviewer Notes

  • The CCCL submodule adds to wheel size. Only the header directories needed at JIT time (cub/cub/, libcudacxx/include/, thrust/thrust/) are packaged — not the full CCCL repo.
  • cuda::fast_mod_div has a deleted default constructor, so uint_fastdiv is kept as a thin wrapper to preserve the existing API contract (default-constructible, implicit conversions, .divmod() method) without touching ~30 call sites.
  • The DeviceTransform change uses a named functor instead of a lambda due to an nvcc name-mangling bug with __host__ __device__ lambdas in inline functions used as kernel template arguments.
  • Bumping the CCCL submodule to a newer tag in the future is a one-line git submodule update — no code changes needed unless new APIs are deprecated.

Summary by CodeRabbit

  • Chores
    • Vendored NVIDIA CCCL added and packaged with distributed headers.
    • Build/packaging updated to include CCCL header trees.
    • Runtime tooling now ensures CCCL is present (clones vendored sources if missing).
    • CUDA-related kernel and math handling unified for improved compatibility and correctness.

@coderabbitai (Bot, Contributor) commented Apr 16, 2026

📝 Walkthrough

Adds NVIDIA CCCL as a vendored Git submodule, updates packaging and build to include CCCL headers, changes JIT include resolution to prefer vendored CCCL, and replaces several CUDA/CUB fallback implementations with CCCL/CUDA utilities across kernel and header code. Also auto-clones CCCL in runner scripts when missing.

Changes

  • Git Submodule (.gitmodules, 3rdparty/cccl): Added the 3rdparty/cccl submodule pointing at https://github.com/NVIDIA/cccl.git and updated the submodule commit reference.
  • Packaging & Build Backend (pyproject.toml, build_backend.py): The package now ships flashinfer.data.cccl with selected cub, libcudacxx/include, and thrust paths; the build backend copies/links 3rdparty/cccl into the package _data_dir for wheel/sdist/editable installs.
  • JIT include handling (flashinfer/jit/env.py, flashinfer/jit/cpp_ext.py): Added the CCCL_INCLUDE_DIRS vendored include list and get_cccl_includes(); JIT build flags now append vendored CCCL -I paths instead of using the hardcoded system CCCL include.
  • Runtime/Startup script (scripts/modal_runner.py): The pre-run step now clones https://github.com/NVIDIA/cccl.git into 3rdparty/cccl if missing (depth=1, fail-fast).
  • CUDA/C++ kernel & util changes (include/flashinfer/sampling.cuh, csrc/.../trtllm_fused_moe_dev_kernel.cu, include/flashinfer/fastdiv.cuh, include/flashinfer/trtllm/fmha/kernelParams.h, include/flashinfer/trtllm/fmha/lse.cuh): Removed CUDA-version CUB fallbacks and now unconditionally use cuda::maximum<>/cuda::minimum<>; replaced several local fast-mod/div implementations with cuda::fast_mod_div<T> and updated related constructors/operators; LSE compute now uses CUB DeviceTransform; added #include <cuda/cmath> where needed.

Sequence Diagram

sequenceDiagram
    participant Dev as Developer (CI / Local)
    participant Script as modal_runner.py
    participant Build as build_backend.py
    participant JIT as JIT Compiler (cpp_ext)
    participant Kern as CUDA Kernels / CUDA Toolchain
    participant Pkg as Package Assembler

    Dev->>Script: run command / CI job
    Script->>Script: ensure `3rdparty/cccl` present (git clone if missing)
    Script->>Build: invoke build
    Build->>Build: copy/link `3rdparty/cccl` -> _data_dir/cccl
    Build->>JIT: initialize JIT env
    JIT->>JIT: load CCCL_INCLUDE_DIRS (vendored -I paths)
    JIT->>Kern: compile kernels with vendored CCCL includes
    Kern->>Kern: use cuda::fast_mod_div / cuda::maximum<> etc.
    Kern-->>JIT: compiled artifacts
    Build->>Pkg: assemble distribution with cccl data paths
    Pkg-->>Dev: produced wheel/sdist/editable package

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

op: attention

Suggested reviewers

  • yzh119
  • aleozlx
  • cyx-6
  • jimmyzho
  • sricketts
  • samuellees
  • nv-yunzheq

Poem

🐇 A rabbit hops, with code in tow,
CCCL bundled, headers ready to show,
No fallbacks now, the includes align,
Fast-div and kernels hum in time,
Hooray — the vendored hop went fine! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 66.67%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title accurately summarizes the main objective of the PR: vendoring CCCL v3.3.2 as a Git submodule instead of relying on the CTK-bundled version.
  • Description check (✅ Passed): The PR description is comprehensive and well-structured, covering all required template sections with substantial detail.


@kahyunnam (Member, Author) commented:

/bot run

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request integrates the NVIDIA CCCL library as a vendored submodule, replacing the dependency on CUDA Toolkit-bundled versions of CUB and Thrust. Key changes include updating the JIT compilation logic to ensure vendored headers take precedence, simplifying CUDA kernels by removing version-specific conditional compilation for reduction operators, and updating build configurations. Review feedback recommends adding specific type hints to the new include retrieval function and addressing a hardcoded version string in the runner script to improve maintainability.

  • Comment thread: flashinfer/jit/cpp_ext.py
  • Comment thread: scripts/modal_runner.py (Outdated)
@flashinfer-bot (Collaborator) commented:

GitLab MR !561 has been created, and the CI pipeline #48720071 is currently running. I'll report back once the pipeline job completes.

…d copy

Pin CCCL (CUB, libcudacxx, Thrust) to a specific release tag as a git
submodule under 3rdparty/cccl, replacing the implicit dependency on
whatever version ships with the user's CUDA Toolkit. This enables the
CCCL team to land TopK improvements independently of CTK releases.

Key changes:
- Add 3rdparty/cccl submodule at CCCL v3.3.2 (maps to CTK 13.2)
- Wire vendored CCCL into JIT include paths using -I (not -isystem)
  so it takes precedence over CTK headers, per CCCL guidelines
- Remove $cuda_home/include/cccl from system includes
- Package CCCL headers (cub, libcudacxx, thrust) into the wheel
- Update build_backend.py symlink/copy logic for editable/wheel/sdist
- Update modal_runner.py fallback clone with --branch=v3.3.2
- Remove dead #if CUDA_VERSION guards for cub::Max/Min which no longer
  exist in CCCL 3.x; unconditionally use cuda::maximum/minimum

Made-with: Cursor

Replace ComputeLSEFromMDKernel (a hand-rolled element-wise CUDA kernel
with manual PDL asm, launch config, and bounds checking) with a single
cub::DeviceTransform::Transform call.

DeviceTransform automatically provides:
- PDL (griddepcontrol) on SM90+ via _CCCL_PDL macros in the kernel
- Vectorized loads when alignment permits
- Software prefetch on Hopper+
- Auto-tuned occupancy and grid sizing
- Bulk copy (TMA) on SM90+

Uses a named functor (MDToLSE) instead of a lambda to avoid an nvcc
name-mangling bug with __host__ __device__ lambdas in inline functions
used as kernel template arguments.

log2f replaces the PTX-only math::ptx_log2 since the functor must be
__host__ __device__; with -use_fast_math, nvcc emits the same
lg2.approx.ftz.f32 instruction on device.

Made-with: Cursor

The FastModDivInt32 struct was an explicit polyfill for
cuda::fast_mod_div<int32_t>, with a comment referencing the CCCL
source. Now that CCCL v3.3.2 is vendored, replace the 25-line
hand-rolled struct with a type alias to the real thing.

The memory layout is identical (divisor, multiplier, add, shift as
4 consecutive 32-bit fields), so this is ABI-compatible with the
precompiled FMHA cubins that consume KernelParams on device.

Made-with: Cursor
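A minimal illustration of the alias-plus-layout-check shape this commit describes (the static_assert is a hypothetical sanity check, not necessarily present in kernelParams.h):

#include <cstdint>
#include <cuda/cmath>  // cuda::fast_mod_div

// The hand-rolled polyfill becomes a type alias to the CCCL utility.
using FastModDivInt32 = cuda::fast_mod_div<int32_t>;

// The precompiled FMHA cubins read four consecutive 32-bit fields
// (divisor, multiplier, add, shift), so the size must not change.
static_assert(sizeof(FastModDivInt32) == 4 * sizeof(int32_t),
              "FastModDivInt32 must remain layout-compatible with the old polyfill");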
Replace the hand-rolled Hacker's Delight fast-division code with
a thin API-compatible wrapper around cuda::fast_mod_div<uint32_t>
from the vendored CCCL v3.3.2.

The wrapper preserves the existing interface (default constructor,
implicit conversions, divmod() method, operator/ and operator%)
so the ~30 call sites in attention, page, MLA, and RoPE kernels
require no changes.

Made-with: Cursor
@kahyunnam (Member, Author) commented:

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !561 has been created, and the CI pipeline #49034669 is currently running. I'll report back once the pipeline job completes.

@kahyunnam kahyunnam changed the title from "[draft] Vendor CCCL v3.3.2 from GitHub instead of relying on CTK-bundled copy" to "Vendor CCCL v3.3.2 from GitHub instead of relying on CTK-bundled copy" on Apr 20, 2026
@coderabbitai (Bot, Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
build_backend.py (1)

108-108: Wheel/sdist copies the entire 3rdparty/cccl tree, not just packaged subdirs.

ln("3rdparty/cccl", "cccl") uses shutil.copytree for wheel/sdist (non-symlink) paths, which copies the whole CCCL checkout (CMake files, docs, tests, .github, c/parallel, cudax, python/, etc.) into flashinfer/data/cccl/. Setuptools' package-data globs in pyproject.toml will still filter what ends up inside the built wheel to just cub/cub/**, libcudacxx/include/**, and thrust/thrust/**, so the final .whl isn't bloated — but the intermediate source tree staged during builds grows by hundreds of MB and every _prepare_for_* call does a full rmtree + recopy.

Consider copying only the three header subtrees that are actually packaged (matching the pyproject.toml globs) for the wheel/sdist case, e.g.:

♻️ Suggested refinement
-    ln("3rdparty/cccl", "cccl")
+    if use_symlinks:
+        ln("3rdparty/cccl", "cccl")
+    else:
+        # Only stage the header subtrees that are declared as package-data in pyproject.toml
+        for sub in ("cub/cub", "libcudacxx/include", "thrust/thrust"):
+            ln(f"3rdparty/cccl/{sub}", f"cccl/{sub}")

Note ln's current dst.exists() handling may need a small tweak to create parent directories for nested targets.

flashinfer/jit/cpp_ext.py (1)

98-146: Precedence over CTK CCCL via -I is correct; minor note on extra_include_dirs ordering.

Emitting vendored CCCL as -I before -isystem $cuda_home/include ensures it shadows the CTK-bundled CCCL during header search, which is the stated goal. The #pragma GCC system_header inside CCCL suppresses warnings, so dropping -isystem here is fine. Inline comment with the issue link is helpful.

One subtle ordering note: extra_include_dirs are emitted as -I before the vendored CCCL -I entries (Line 135-142). If a caller ever passes an extra_include_dirs entry that itself contains a (possibly stale/system) cub/, thrust/, or cuda/ header tree, it will now shadow the vendored CCCL. In practice this is unlikely for current callers, but worth keeping in mind — putting CCCL first among the -I flags would be strictly safer:

♻️ Optional: put vendored CCCL ahead of user extras
-    if extra_include_dirs is not None:
-        for extra_dir in extra_include_dirs:
-            common_cflags.append(f"-I{extra_dir.resolve()}")
-    # Vendored CCCL headers use -I (not -isystem) so they take precedence
-    # over the CTK-bundled copy. CCCL headers use `#pragma` system_header
-    # internally to suppress warnings. See https://github.com/NVIDIA/cccl/issues/527
-    for cccl_dir in cccl_includes:
-        common_cflags.append(f"-I{cccl_dir}")
+    # Vendored CCCL headers use -I (not -isystem) so they take precedence
+    # over the CTK-bundled copy. CCCL headers use `#pragma` system_header
+    # internally to suppress warnings. See https://github.com/NVIDIA/cccl/issues/527
+    for cccl_dir in cccl_includes:
+        common_cflags.append(f"-I{cccl_dir}")
+    if extra_include_dirs is not None:
+        for extra_dir in extra_include_dirs:
+            common_cflags.append(f"-I{extra_dir.resolve()}")
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@scripts/modal_runner.py`:
- Around line 116-128: The fallback clone uses a floating tag (--branch=v3.3.2), which can diverge from the submodule's pinned commit; instead, clone the repo normally and then explicitly check out the submodule SHA (the pinned commit 876867684f7fac130e0f5911236e0a92a970d4fd), e.g. git clone ... "3rdparty/cccl" followed by git -C 3rdparty/cccl checkout 876867684f7fac130e0f5911236e0a92a970d4fd (or programmatically supply that commit instead of the v3.3.2 tag), so Modal's fallback matches the submodule commit referenced in the repo.

📥 Commits

Reviewing files that changed from the base of the PR and between a99ee72 and 29c79d6cddaa06f84aa4e6023d0719585bc33e55.

📒 Files selected for processing (9)
  • .gitmodules
  • 3rdparty/cccl
  • build_backend.py
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu
  • flashinfer/jit/cpp_ext.py
  • flashinfer/jit/env.py
  • include/flashinfer/sampling.cuh
  • pyproject.toml
  • scripts/modal_runner.py
💤 Files with no reviewable changes (2)
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu
  • include/flashinfer/sampling.cuh

  • Comment thread: scripts/modal_runner.py
@kahyunnam (Member, Author) commented:

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !561 has been updated with latest changes, and the CI pipeline #49035860 is currently running. I'll report back once the pipeline job completes.

@coderabbitai (Bot, Contributor) left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

844-844: ⚠️ Potential issue | 🟡 Minor

Add explicit validation that options.mNumHeadsQPerKv > 0 before constructing FastModDivInt32 at line 844.

cuda::fast_mod_div<int32_t> from CCCL requires a strictly positive divisor; constructing it with zero or negative values triggers an assertion. While the divisibility check at line 92 (num_qo_heads % num_kv_heads != 0) implicitly catches num_kv_heads == 0, there is no explicit guard that num_qo_heads > 0 before computing mNumHeadsQPerKv = num_qo_heads / num_kv_heads. An explicit check at the call site prevents UB if mNumHeadsQPerKv reaches zero.

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/lse.cuh (1)

29-38: Add a brief inline justification for cub::DeviceTransform and the dropped launch_with_pdl.

The kernel used to be hand-rolled and its sole caller (fmhaKernels.cuh:285-291) still passes params.enable_pdl as the 4th argument, which is now silently discarded. The PR description notes that cub::DeviceTransform already uses PDL/vectorized loads/prefetch/TMA internally — that rationale belongs in the header next to the functor so future readers understand (a) why the parameter is intentionally unused rather than wired through, and (b) why the custom launcher was removed. Per coding guidelines, performance-critical hot paths should document special algorithmic choices and alternatives considered.

💡 Proposed comment
+// Use cub::DeviceTransform instead of a hand-rolled kernel: it auto-tunes
+// occupancy and internally applies PDL, vectorized loads, prefetch, and TMA
+// where available. Because CUB manages PDL itself, `launch_with_pdl` from
+// the caller is intentionally ignored (kept for API stability).
 inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool /*launch_with_pdl*/,
                                     cudaStream_t stream) {
   return cub::DeviceTransform::Transform(md, lse, n, MDToLSE{}, stream);
 }

As per coding guidelines: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".

flashinfer/jit/cpp_ext.py (1)

98-142: LGTM on -I placement and rationale; consider surfacing a clearer error when CCCL paths are absent.

The ordering (extra_include_dirs -Icccl -I → system -isystem) correctly gives vendored CCCL precedence over $cuda_home/include, and the inline comment pointing at CCCL #527 explains the -I vs -isystem choice well.

One minor robustness nit: Path.resolve() (non-strict by default) will happily produce a canonicalized path for directories that don't exist. In a source tree where 3rdparty/cccl hasn't been initialized (e.g., git clone without --recurse-submodules and not an installed wheel), JIT compilation will fail inside nvcc with a header-not-found error instead of a clear message that the submodule is missing. Consider validating at least one expected CCCL header path exists in get_cccl_includes() and raising a descriptive error pointing the user to git submodule update --init 3rdparty/cccl.

💡 Optional hardening
 def get_cccl_includes() -> List:
     """Get vendored CCCL include directories (added with -I for CTK override precedence)."""
-    return [p.resolve() for p in jit_env.CCCL_INCLUDE_DIRS]
+    resolved = [p.resolve() for p in jit_env.CCCL_INCLUDE_DIRS]
+    missing = [str(p) for p in resolved if not p.is_dir()]
+    if missing:
+        raise RuntimeError(
+            "Vendored CCCL headers not found at: "
+            + ", ".join(missing)
+            + ". Initialize the submodule with "
+            "`git submodule update --init --recursive 3rdparty/cccl`."
+        )
+    return resolved


📥 Commits

Reviewing files that changed from the base of the PR and between 29c79d6cddaa06f84aa4e6023d0719585bc33e55 and 91fbbb5.

📒 Files selected for processing (12)
  • .gitmodules
  • 3rdparty/cccl
  • build_backend.py
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu
  • flashinfer/jit/cpp_ext.py
  • flashinfer/jit/env.py
  • include/flashinfer/fastdiv.cuh
  • include/flashinfer/sampling.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • include/flashinfer/trtllm/fmha/lse.cuh
  • pyproject.toml
  • scripts/modal_runner.py
💤 Files with no reviewable changes (2)
  • include/flashinfer/sampling.cuh
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu
✅ Files skipped from review due to trivial changes (3)
  • 3rdparty/cccl
  • .gitmodules
  • pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (2)
  • build_backend.py
  • flashinfer/jit/env.py

Inline review thread on include/flashinfer/trtllm/fmha/lse.cuh (signature change under discussion):

- inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl,
+ inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool /*launch_with_pdl*/,
@kahyunnam (Member, Author) commented Apr 21, 2026

Note for reviewers: launch_with_pdl is unused — DeviceTransform enables PDL unconditionally on SM90+ via its internal launcher. On pre-Hopper GPUs, the PDL instructions compile to no-ops. This means callers that pass false will still get PDL when the GPU supports it, which is probably harmless (PDL is a performance hint, not semantic).

@nv-yunzheq (Collaborator) left a comment

Makes sense to me, and the code looks good. Please wait for the unit tests to come back clean before merging.

@kahyunnam kahyunnam enabled auto-merge (squash) April 22, 2026 18:08
@kahyunnam kahyunnam merged commit 6ddbdb0 into flashinfer-ai:main Apr 22, 2026
32 of 44 checks passed