[gdn] cuTile GDN prefill on Blackwell (SM100)#2729
Conversation
…(SM100) - flashinfer/gdn_kernels/cutile_gdn_prefill.py: cuTile-based GDN chunked prefill kernel (4 fused kernels: cumsum, prepare, recurrence, output), replacing the 6-stage Triton pipeline. Targets SM100 (Blackwell B200). - flashinfer/gdn_kernels/__init__.py: export chunk_gated_delta_rule_cutile. - tests/gdn/test_prefill_cutile_blackwell.py: pytest accuracy tests comparing cuTile vs FLA Triton GDN prefill on B200. 33 configs (B=1..16, T=64..4096, K/V=64/128/256, w/wo initial state). All 33 PASS (bfloat16, atol=rtol=1e-2). - benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py: kernel latency comparison. Results on B200: 1.38x–1.83x speedup vs Triton baseline. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the GDN prefill functionality by introducing a new cuTile-based kernel optimized for NVIDIA Blackwell GPUs. It includes comprehensive accuracy tests to ensure correctness against an existing baseline and provides benchmarks showcasing substantial performance gains. The overall impact is improved efficiency and broader hardware compatibility for GDN prefill operations. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
📝 WalkthroughWalkthroughAdds a cuTile-based GDN prefill implementation for Blackwell GPUs, plus a benchmark script comparing cuTile vs FLA Triton, test coverage validating outputs against the FLA baseline, and an optional cuTile import exported from gdn_kernels. Changes
Sequence DiagramsequenceDiagram
participant Caller
participant Entrypoint as chunk_gated_delta_rule_cutile
participant L2Norm
participant CacheInit
participant Scheduler as ExecutionPath
participant Kernels
participant Profiler
participant Output
Caller->>Entrypoint: call(q,k,v,g,beta, initial_state, ...)
Entrypoint->>L2Norm: optionally normalize q,k
L2Norm-->>Entrypoint: normalized tensors
Entrypoint->>CacheInit: allocate/init caches & tensors
CacheInit-->>Entrypoint: cache handles
Entrypoint->>Scheduler: choose execution path (fused vs cached WY/REC)
Scheduler->>Kernels: launch cuTile kernels per chunk
Kernels->>Kernels: compute WY / CKKT solve / recurrence / output
Kernels-->>Entrypoint: kernel outputs (o, h, optional aux)
Entrypoint->>Profiler: record kernel & wall-clock timing (benchmarks)
Profiler-->>Caller: timing results (when run)
Entrypoint->>Output: finalize and return (o, None, h)
Output-->>Caller: outputs
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
⚔️ Resolve merge conflicts
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces benchmarks and accuracy tests for the new cuTile-based GDN prefill kernel targeting NVIDIA's Blackwell (SM100) architecture. The code is well-structured, with clear separation between tests and benchmarks. My review includes a suggestion to make the benchmark script more robust by ensuring it exits on non-SM100 hardware, which will prevent misleading performance results. The accuracy tests are comprehensive and the changes to the package's __init__.py are correct.
| is_sm100 = _check_sm100() | ||
| print(f"GPU: {torch.cuda.get_device_name(0)}") | ||
| print(f"Device capability: SM{torch.cuda.get_device_capability()}") | ||
| print() |
There was a problem hiding this comment.
The _check_sm100 function's return value is stored in is_sm100 but never used. The benchmark proceeds even on non-SM100 hardware, printing a warning but potentially producing misleading or incorrect performance numbers. Since the cuTile kernel is specifically for SM100+, the benchmark should exit if the hardware is not supported to avoid confusion.
| is_sm100 = _check_sm100() | |
| print(f"GPU: {torch.cuda.get_device_name(0)}") | |
| print(f"Device capability: SM{torch.cuda.get_device_capability()}") | |
| print() | |
| if not _check_sm100(): | |
| print("\nExiting benchmark: This benchmark is for SM100+ (Blackwell) GPUs.") | |
| return | |
| print(f"GPU: {torch.cuda.get_device_name(0)}") | |
| print(f"Device capability: SM{torch.cuda.get_device_capability()}") | |
| print() |
Key optimizations over the initial kernel submission: - Adaptive output kernel occupancy: use occ2 for NT*BH>=256 (was >512), saving ~4us on medium-grid configs (B=2/4/8 with T=1024-2048) - Adaptive recurrence BV dispatch: tune CTA count per BH range (BH>=64→512 CTAs, BH<=8→128, BH<=4+long→64) for optimal wave count - Fused ckkt+KKT+solve v2 with 10-step squaring trick (occ1/2/3 variants) - Q-cached output kernel: load Q once in registers, reuse across V-blocks - Many experimental kernels explored (fused ckkt+WY, no-WY rec, ILP-2 split, higher-occ rec/output) — kept only the ones that improved perf AI-assisted (Claude Code) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Update benchmark configs to match Qwen3.5 linear attention parameters: K=128, V=128, H=32 (was K=256, V=256, H=4) - Change initial_state (SSM state) dtype from bf16 to fp32, matching sglang's default mamba2_state_dtype - Update accuracy tests with Qwen3.5 parametrization - Add performance summary markdown with kernel breakdown AI-assisted (Claude Code) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add TP8 benchmark configs: H=4 (4B model) and H=8 (397B model) - Optimize rec BV for small V: cap max_rec_CTAs=256 when BH>=64 and V<=128, using BV=32 (0.87 waves) instead of BV=16 (1.73 waves). This improves large-batch H=8 configs by 0.10-0.21x. - Key improvements: B=16,T=1024,H=4: 1.53x -> 1.77x B=8,T=1024,H=8: 1.42x -> 1.60x B=8,T=2048,H=8: 1.38x -> 1.59x B=16,T=1024,H=8: 1.62x -> 1.69x AI-assisted (Claude Code) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
f524be1 to
99ce4d8
Compare
Add cutile_fused_wy_rec_kernel that computes w and u on-the-fly per chunk inside the recurrence loop, eliminating w/u HBM write+read. For B=8,T=2048,H=8,K=128: saves ~128MB bandwidth (~16us at 8TB/s). Dispatch condition: K<=128, BH>=64, BV<=32, NT*BH>=512. This targets large-batch TP8 configs where bandwidth savings outweigh the extra per-chunk A_inv WGMMA compute. Key improvements on Qwen3.5 TP8: B=16,T=1024,H=4: 1.74x -> 1.90x B=8 ,T=1024,H=8: 1.60x -> 1.76x B=8 ,T=2048,H=8: 1.59x -> 1.72x AI-assisted (Claude Code) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When fused WY+rec dispatch would skip a config due to rec_BV>32 (e.g. B=16,T=1024,H=8 with BV=64), fall back to BV=32 with more CTAs for the fused path. This ensures all BH>=64 K<=128 configs use the bandwidth-optimized fused kernel. No perf change for B=16 configs (BV=32/512CTAs ≈ BV=64/256CTAs), but enables future optimizations on the unified fused path. AI-assisted (Claude Code) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove unused experimental kernels that were explored during optimization
but never activated in the dispatch path. No functional change.
Removed: cutile_chunk_local_cumsum, cutile_cumsum_kkt_kernel,
cutile_fused_ckkt_solve_kernel, cutile_prepare_kernel,
cutile_output_multiV_kernel, cutile_output_qcached_occ3{,_v2}_kernel,
cutile_kkt_solve_kernel, cutile_fused_kkt_solve_wy_kernel_v2,
cutile_kkt_kernel, cutile_wy_occ6_kernel, cutile_solve_kernel_occ2,
cutile_wy_kernel_occ2, cutile_recurrence_kernel_bv64,
cutile_recurrence_kernel_v{2,3}, cutile_recurrence_kernel_bv16_occ4,
cutile_recurrence_kernel_bv16_split2, cutile_rec_nowy_bv16_occ2_kernel,
cutile_recompute_wu_occ4_kernel, cutile_fused_wy_rec_output_kernel,
cutile_fused_ckkt_solve_v2_occ4_kernel, cutile_fused_ckkt_wy_occ2_kernel
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove standalone cumsum, solve, WY, rec, and fused ckkt+WY kernels that are superseded by the fused ckkt_solve_v2 and fused_wy_rec variants. Only dispatch-active kernels remain. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ines) Eliminate occupancy variants by keeping only the best single variant: - ckkt: keep occ3 only (was occ1/2/3) - output_qcached: keep occ2 only (was occ1/2) - output_tiled: keep occ2 only (was occ1/2) - recurrence: keep occ2 only (was occ2/3) Small grid configs (B=1) lose 0.03-0.05x from ckkt occ3 register pressure at <512 CTAs, acceptable tradeoff for code simplicity. Also share kernel bodies via factory pattern: define body once, create occupancy variant via ct.kernel(occupancy=N)(body). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… lines) For Qwen3.5 V=128 with BV=128, there's only 1 V-block iteration, so Q-caching has zero overhead. Remove the separate 3D-grid output kernel and simplify dispatch to always use the qcached variant. 5 kernels remain: ckkt_solve (occ3), output_qcached (occ2), recurrence (occ2), fused_wy_rec (occ2), l2norm. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove unused _solve_tril_cached and its imports (solve_tril_16x16,
merge_16x16 kernels) — solve is fused into ckkt kernel
- Remove unused numpy import
- Remove stale cache entries ('A', 'Ad') from _init_cache
- Deduplicate comments in dispatch logic
- Update module docstring to reflect current 5-kernel architecture
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py`:
- Around line 41-71: Move the cuTile/sglang bootstrap out of top-level import so
it only runs when cuTile is actually available: read the root from an env/config
variable (use _SGLANG_ROOT from os.environ or a config fallback) instead of the
hard-coded Path, check CT_AVAILABLE (or validate that the configured
_SGLANG_ROOT exists and contains the expected files) before calling _stub/_load,
and wrap the for-loop that imports "sglang.srt.layers.attention.fla.*" inside
that availability guard; also make _load/_stub resilient by catching and
silencing file-not-found/import errors so the module import will cleanly report
cuTile unavailable rather than crashing.
- Around line 223-226: The call to _check_sm100() assigns is_sm100 but the
script continues even when it returns False, leaving an unused-variable (Ruff
F841) and allowing unsupported GPUs to proceed; change the logic after calling
_check_sm100() so that if it returns False you immediately stop the run (e.g.,
call sys.exit(1) or raise SystemExit) with a clear error message, otherwise
continue and print the GPU info—this both uses the return value and prevents
running the Blackwell-only path on unsupported hardware.
In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py`:
- Around line 249-253: The optional initial_state/initial_state_indices path is
broken because USE_INITIAL only gates the load but both recurrence kernels
always write back into initial_state/initial_state_indices; fix by allocating
internal zero tensors when initial_state or initial_state_indices is None
instead of assuming callers supply them. In the functions that call the CUDA
kernels (e.g., chunk_gated_delta_rule and the other entrypoints around the shown
regions), detect if initial_state is None or initial_state_indices is None and
create zero tensors with the correct shape, device and dtype (initial_state:
[N,H,K,V]; initial_state_indices matching expected index shape), then pass those
buffers into the kernels so the write-back has a valid destination;
alternatively, remove Optional typing and require callers to pass buffers, but
prefer the internal-allocation approach so semantics match existing API. Ensure
allocations respect cu_seqlens / batch size and preserve device/dtype.
- Around line 196-205: The module-level scratch state (_tensor_cache,
_cache_config, _cached_stream) makes the exported kernel non-reentrant; change
the caching to be scoped per-call or at least per-device+stream. Update
_init_cache and callers (e.g., functions that use
_tensor_cache/_cache_config/_cached_stream) to accept a stream identifier and
include it in the cache key (for example key =
(B,T,H,K,V,device,dtype_k,dtype_v,stream)) or replace the globals with a
per-call local cache returned by _init_cache; ensure every read/write of
_tensor_cache/_cache_config/_cached_stream is replaced with lookups into the new
keyed dict or the returned local cache so concurrent invocations on the same
device but different streams do not collide.
- Around line 23-30: The import fallback sets ct = None which causes
AttributeError when module-level code later calls or uses `@ct.kernel`; instead
create a lightweight sentinel that raises ImportError when its kernel attribute
is accessed/used so the optional-import pattern behaves correctly. Replace the
except ImportError block so that CUTILE_AVAILABLE=False and ct is an object
(e.g., a small class instance or SimpleNamespace) that defines a kernel
attribute (callable/decorator) which, when invoked, raises
ImportError("cuda.tile not available"), and set Const to a similar sentinel or
None-consistent value; this ensures all module-level uses of ct.kernel and
references to Const (the symbols ct.kernel and Const) raise ImportError rather
than AttributeError during import.
In `@tests/gdn/test_prefill_cutile_blackwell.py`:
- Around line 29-30: Replace the local CUDA wrapper and manual major-version
check with the repo helpers: import flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, remove the local
get_compute_capability(device) that calls torch.cuda.get_device_capability, and
replace any checks like cc[0] != 10 with a call to
is_sm100a_supported(get_compute_capability(device)) (or its negation) to decide
skips; apply the same replacement for the other skip block referenced (the logic
around the cc checks at the later block).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b3847a0e-f13a-4a16-a856-32ed41d1f151
📒 Files selected for processing (4)
benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.pyflashinfer/gdn_kernels/__init__.pyflashinfer/gdn_kernels/cutile_gdn_prefill.pytests/gdn/test_prefill_cutile_blackwell.py
| import importlib.util, sys, types, pathlib | ||
| _kernel_path = pathlib.Path(__file__).parent.parent / "flashinfer/gdn_kernels/cutile_gdn_prefill.py" | ||
| _SGLANG_ROOT = pathlib.Path("/home/scratch.xutingz_wwfo/gitsrc/sglang") | ||
| # Bootstrap sglang FLA modules needed by cutile_gdn_prefill at runtime | ||
| def _stub(name): | ||
| parts = name.split('.') | ||
| for i in range(1, len(parts) + 1): | ||
| pkg = '.'.join(parts[:i]) | ||
| if pkg not in sys.modules: | ||
| m = types.ModuleType(pkg); m.__path__ = []; m.__package__ = pkg | ||
| sys.modules[pkg] = m | ||
| def _load(name, path): | ||
| _stub(name) | ||
| spec = importlib.util.spec_from_file_location(name, path) | ||
| mod = importlib.util.module_from_spec(spec) | ||
| mod.__package__ = name.rsplit('.', 1)[0] | ||
| sys.modules[name] = mod | ||
| spec.loader.exec_module(mod) | ||
| return mod | ||
| import torch as _torch | ||
| _stub("sglang.srt.utils.common") | ||
| sys.modules["sglang.srt.utils.common"].torch_release = tuple( | ||
| int(x) for x in _torch.__version__.split(".")[:2] if x.isdigit() | ||
| ) | ||
| _sgl_python = str(_SGLANG_ROOT / "python") | ||
| if _sgl_python not in sys.path: | ||
| sys.path.insert(0, _sgl_python) | ||
| _fla = _SGLANG_ROOT / "python/sglang/srt/layers/attention/fla" | ||
| for _n in ["utils", "l2norm", "op", "index", "cumsum", | ||
| "chunk_scaled_dot_kkt", "solve_tril", "wy_fast", "chunk_delta_h", "chunk_o"]: | ||
| _load(f"sglang.srt.layers.attention.fla.{_n}", str(_fla / f"{_n}.py")) |
There was a problem hiding this comment.
Make the cuTile/sglang bootstrap portable and lazy.
This block hard-codes a workstation-local checkout path and eagerly loads files from it before CT_AVAILABLE is set. On any machine without that exact tree, the benchmark fails during import with a file-loading error instead of cleanly reporting that cuTile is unavailable. This same block is also the current Ruff E401/E702 failure.
Please move the bootstrap behind a real availability check and source the root from config/env instead of an absolute path.
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 41-41: Ruff: E401 Multiple imports on one line. Split imports across lines.
[error] 50-50: Ruff: E702 Multiple statements on one line (semicolon). Split into separate statements.
[error] 50-52: Ruff: E702 Multiple statements on one line (semicolon) in _load helper block.
[error] 41-43: Ruff: E401 Multiple imports on one line. Split imports across lines.
[warning] 50-50: Ruff: E702 Multiple statements on one line (semicolon) detected; consider breaking into separate lines.
🪛 Ruff (0.15.5)
[error] 50-50: Multiple statements on one line (semicolon)
(E702)
[error] 50-50: Multiple statements on one line (semicolon)
(E702)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 41 -
71, Move the cuTile/sglang bootstrap out of top-level import so it only runs
when cuTile is actually available: read the root from an env/config variable
(use _SGLANG_ROOT from os.environ or a config fallback) instead of the
hard-coded Path, check CT_AVAILABLE (or validate that the configured
_SGLANG_ROOT exists and contains the expected files) before calling _stub/_load,
and wrap the for-loop that imports "sglang.srt.layers.attention.fla.*" inside
that availability guard; also make _load/_stub resilient by catching and
silencing file-not-found/import errors so the module import will cleanly report
cuTile unavailable rather than crashing.
| is_sm100 = _check_sm100() | ||
| print(f"GPU: {torch.cuda.get_device_name(0)}") | ||
| print(f"Device capability: SM{torch.cuda.get_device_capability()}") | ||
| print() |
There was a problem hiding this comment.
Stop the run when _check_sm100() fails.
The return value is ignored, so the script still executes the Blackwell-only path on unsupported GPUs after printing a warning. That produces misleading numbers at best and runtime failures at worst, and it leaves the current Ruff F841 error unresolved.
Suggested fix
- is_sm100 = _check_sm100()
+ if not _check_sm100():
+ return
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Device capability: SM{torch.cuda.get_device_capability()}")🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 223-223: Ruff: F841 Local variable is_sm100 is assigned to but never used.
[error] 223-223: Ruff: Remove assignment to unused variable is_sm100.
[warning] 223-223: Ruff: F841 Local variable is_sm100 is assigned to but never used (suggest removal).
🪛 Ruff (0.15.5)
[error] 223-223: Local variable is_sm100 is assigned to but never used
Remove assignment to unused variable is_sm100
(F841)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 223 -
226, The call to _check_sm100() assigns is_sm100 but the script continues even
when it returns False, leaving an unused-variable (Ruff F841) and allowing
unsupported GPUs to proceed; change the logic after calling _check_sm100() so
that if it returns False you immediately stop the run (e.g., call sys.exit(1) or
raise SystemExit) with a clear error message, otherwise continue and print the
GPU info—this both uses the return value and prevents running the Blackwell-only
path on unsupported hardware.
| try: | ||
| import cuda.tile as ct | ||
| from cuda.tile import Constant as Const | ||
| CUTILE_AVAILABLE = True | ||
| except ImportError: | ||
| CUTILE_AVAILABLE = False | ||
| ct = None | ||
| Const = None |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n -A2 -B2 'ct = None|@ct\.kernel|ct\.kernel\(' flashinfer/gdn_kernels/cutile_gdn_prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 931
🏁 Script executed:
cat -n flashinfer/gdn_kernels/__init__.py | head -50Repository: flashinfer-ai/flashinfer
Length of output: 2005
🏁 Script executed:
rg -n "cutile_gdn_prefill|ImportError" tests/gdn/ -A2 -B2Repository: flashinfer-ai/flashinfer
Length of output: 2291
🏁 Script executed:
rg -n "CUTILE_AVAILABLE" flashinfer/ -A2 -B2Repository: flashinfer-ai/flashinfer
Length of output: 582
🏁 Script executed:
rg -n "cutile_gdn_prefill" flashinfer/gdn_kernels/__init__.pyRepository: flashinfer-ai/flashinfer
Length of output: 136
🏁 Script executed:
sed -n '23,160p' flashinfer/gdn_kernels/cutile_gdn_prefill.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 6562
🏁 Script executed:
sed -n '48,55p' flashinfer/gdn_kernels/__init__.pyRepository: flashinfer-ai/flashinfer
Length of output: 273
Module-level ct.kernel() calls will raise AttributeError, not ImportError, when cuda.tile is unavailable.
When ct = None is set on import failure (line 29), the unconditional module-level code at lines 125, 131, 436, 446, and 627 will attempt to call ct.kernel(...) and use @ct.kernel decorator during import. This raises AttributeError instead of the expected ImportError, which breaks the optional-import pattern used in flashinfer/gdn_kernels/__init__.py (line 51-52) and expected by tests (tests/gdn/test_prefill_cutile_blackwell.py lines 35-40).
Suggested fix
-try:
- import cuda.tile as ct
- from cuda.tile import Constant as Const
- CUTILE_AVAILABLE = True
-except ImportError:
- CUTILE_AVAILABLE = False
- ct = None
- Const = None
+try:
+ import cuda.tile as ct
+ from cuda.tile import Constant as Const
+ CUTILE_AVAILABLE = True
+except ImportError as e:
+ raise ImportError("cuda.tile is required for cutile_gdn_prefill") from e🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py` around lines 23 - 30, The
import fallback sets ct = None which causes AttributeError when module-level
code later calls or uses `@ct.kernel`; instead create a lightweight sentinel that
raises ImportError when its kernel attribute is accessed/used so the
optional-import pattern behaves correctly. Replace the except ImportError block
so that CUTILE_AVAILABLE=False and ct is an object (e.g., a small class instance
or SimpleNamespace) that defines a kernel attribute (callable/decorator) which,
when invoked, raises ImportError("cuda.tile not available"), and set Const to a
similar sentinel or None-consistent value; this ensures all module-level uses of
ct.kernel and references to Const (the symbols ct.kernel and Const) raise
ImportError rather than AttributeError during import.
| _tensor_cache = {} | ||
| _cache_config = None # (B, T, H, K, V, device, dtype) tuple for invalidation | ||
|
|
||
| def _init_cache(B, T, H, K, V, device, dtype_k, dtype_v): | ||
| """Pre-allocate all intermediate tensors for the given config.""" | ||
| global _tensor_cache, _cache_config | ||
| key = (B, T, H, K, V, device, dtype_k, dtype_v) | ||
| if _cache_config == key: | ||
| return # Already initialized for this config | ||
| _cache_config = key |
There was a problem hiding this comment.
The module-level scratch state makes this exported kernel non-reentrant.
_tensor_cache, _cache_config, and _cached_stream are shared across all invocations. Two overlapping calls can overwrite the same intermediates/output buffers, and later calls inherit launch state from earlier ones. That is risky for an exported library API even if it helps single-call benchmark numbers.
Please scope scratch buffers and stream selection to the call, or at least shard them by device/stream instead of keeping a single global instance.
Also applies to: 240-289
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py` around lines 196 - 205, The
module-level scratch state (_tensor_cache, _cache_config, _cached_stream) makes
the exported kernel non-reentrant; change the caching to be scoped per-call or
at least per-device+stream. Update _init_cache and callers (e.g., functions that
use _tensor_cache/_cache_config/_cached_stream) to accept a stream identifier
and include it in the cache key (for example key =
(B,T,H,K,V,device,dtype_k,dtype_v,stream)) or replace the globals with a
per-call local cache returned by _init_cache; ensure every read/write of
_tensor_cache/_cache_config/_cached_stream is replaced with lookups into the new
keyed dict or the returned local cache so concurrent invocations on the same
device but different streams do not collide.
| scale: Optional[float] = None, | ||
| initial_state: Optional[torch.Tensor] = None, # [N, H, K, V] | ||
| initial_state_indices: Optional[torch.Tensor] = None, | ||
| cu_seqlens: Optional[torch.Tensor] = None, | ||
| use_qk_l2norm_in_kernel: bool = False, |
There was a problem hiding this comment.
The advertised initial_state=None path is still broken.
USE_INITIAL only controls the initial load. Both recurrence kernels always write back through initial_state and initial_state_indices at the end, so the default initial_state=None path cannot work as declared. The new tests avoid tripping this only by always fabricating zero h0/idx tensors when use_initial_state=False.
If this entry point is meant to preserve the existing flashinfer.gdn_prefill.chunk_gated_delta_rule semantics, it needs an internal zero-state/indices allocation when initial_state is omitted, or the parameter should stop being optional.
Suggested fix
USE_INITIAL = initial_state is not None
+
+ if initial_state is None:
+ initial_state = torch.zeros(
+ B, H, K, V, dtype=torch.float32, device=q.device
+ )
+ initial_state_indices = torch.arange(
+ B, dtype=torch.int32, device=q.device
+ )
+ elif initial_state_indices is None:
+ raise ValueError("initial_state_indices is required when initial_state is provided")Also applies to: 278-325, 425-432, 535-542
🧰 Tools
🪛 Ruff (0.15.5)
[warning] 252-252: Unused function argument: cu_seqlens
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py` around lines 249 - 253, The
optional initial_state/initial_state_indices path is broken because USE_INITIAL
only gates the load but both recurrence kernels always write back into
initial_state/initial_state_indices; fix by allocating internal zero tensors
when initial_state or initial_state_indices is None instead of assuming callers
supply them. In the functions that call the CUDA kernels (e.g.,
chunk_gated_delta_rule and the other entrypoints around the shown regions),
detect if initial_state is None or initial_state_indices is None and create zero
tensors with the correct shape, device and dtype (initial_state: [N,H,K,V];
initial_state_indices matching expected index shape), then pass those buffers
into the kernels so the write-back has a valid destination; alternatively,
remove Optional typing and require callers to pass buffers, but prefer the
internal-allocation approach so semantics match existing API. Ensure allocations
respect cu_seqlens / batch size and preserve device/dtype.
| def get_compute_capability(device): | ||
| return torch.cuda.get_device_capability(device) |
There was a problem hiding this comment.
Use the repo-standard SM100 helpers here.
The local get_compute_capability() wrapper and cc[0] != 10 check sidestep the repository’s architecture-gating helpers and can misclassify unsupported 10.x variants as valid. Please switch this file to flashinfer.utils.get_compute_capability() / is_sm100a_supported() for the skip path.
Suggested fix
-import pytest
-import torch
-import torch.nn.functional as F
-
-def get_compute_capability(device):
- return torch.cuda.get_device_capability(device)
+import pytest
+import torch
+import torch.nn.functional as F
+
+from flashinfer.utils import get_compute_capability, is_sm100a_supported
@@
def _skip_if_not_sm100():
"""Skip test if not SM100 architecture (Blackwell B200)."""
- cc = get_compute_capability(torch.device("cuda"))
- if cc[0] != 10:
+ cc = get_compute_capability(torch.device("cuda"))
+ if not is_sm100a_supported():
pytest.skip(
f"cuTile GDN prefill requires SM100 (Blackwell), but got SM{cc[0]}{cc[1]}"
)As per coding guidelines, "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures".
Also applies to: 57-63
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gdn/test_prefill_cutile_blackwell.py` around lines 29 - 30, Replace the
local CUDA wrapper and manual major-version check with the repo helpers: import
flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, remove the local
get_compute_capability(device) that calls torch.cuda.get_device_capability, and
replace any checks like cc[0] != 10 with a call to
is_sm100a_supported(get_compute_capability(device)) (or its negation) to decide
skips; apply the same replacement for the other skip block referenced (the logic
around the cc checks at the later block).
- Add wall-clock timing alongside GPU kernel time to show true latency FLA Triton launches O(NT) kernels/call; wall-clock >> kernel time for small configs (~2-3x overhead from Python dispatch). cuTile launches ~3 kernels total, so wall-clock ≈ kernel time. - Print cuda.tile/triton versions and CT/FLA availability at startup to help diagnose environment issues. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
There was a problem hiding this comment.
🧹 Nitpick comments (4)
benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py (4)
299-318: Loop variable capture in lambda and redundant code.Two issues in the breakdown section:
B023: The lambda at line 312 captures loop variables (
q_n,k_n,v,g,beta,K,h0,idx) by reference. While it works here because_profiler_breakdownis called immediately, this is fragile and can cause subtle bugs if the code is refactored.Redundant import:
F2re-importstorch.nn.functionalwhich is already imported asFat the top.Code duplication: The tensor creation logic duplicates
bench_config. Consider extracting a helper or reusing the tensors.Suggested fix for the lambda issue
- fn_ct = lambda: ct_fwd(q_n, k_n, v, g, beta, K**-0.5, h0.clone(), idx, use_qk_l2norm_in_kernel=False) + def fn_ct(q_n=q_n, k_n=k_n, v=v, g=g, beta=beta, scale=K**-0.5, h0=h0, idx=idx): + return ct_fwd(q_n, k_n, v, g, beta, scale, h0.clone(), idx, use_qk_l2norm_in_kernel=False)Or use
functools.partialto bind values at definition time.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 299 - 318, The breakdown block has three issues: the inline lambda fn_ct captures variables by reference (fragile), it re-imports torch.nn.functional as F2 (redundant with F), and duplicates tensor setup already present in bench_config; fix by replacing the lambda with a bound callable (use functools.partial or a small local def that takes no args) that binds q_n, k_n, v, g, beta, h0, idx and the scalar K**-0.5 at definition time (reference: fn_ct and _profiler_breakdown), remove the redundant import F2 and use the existing F, and DRY the tensor construction by reusing or extracting the tensor creation into a helper referenced by the benchmark config so the same tensors are built once and reused for the breakdown run.
222-238: Lambda expressions assigned to variables (E731).Ruff flags assigning lambdas to variables. While common in benchmark code, using
defis preferred for clarity and stack traces.Suggested refactor using def
# ---- cuTile ---- if CT_AVAILABLE: - fn_ct = lambda: ct_fwd( - q_n, k_n, v, g, beta, scale, - h0.clone(), idx, - use_qk_l2norm_in_kernel=False, - ) + def fn_ct(): + return ct_fwd( + q_n, k_n, v, g, beta, scale, + h0.clone(), idx, + use_qk_l2norm_in_kernel=False, + ) results["ct_prof_us"] = _profiler_time_us(fn_ct) results["ct_wall_us"] = _wallclock_time_us(fn_ct) # ---- FLA Triton ---- if FLA_AVAILABLE: - fn_fla = lambda: fla_fwd( - q_n, k_n, v, g, beta, scale, - initial_state=h0.clone(), - output_final_state=False, - ) + def fn_fla(): + return fla_fwd( + q_n, k_n, v, g, beta, scale, + initial_state=h0.clone(), + output_final_state=False, + ) results["fla_prof_us"] = _profiler_time_us(fn_fla) results["fla_wall_us"] = _wallclock_time_us(fn_fla)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 222 - 238, The assigned lambda functions fn_ct and fn_fla trigger lint E731; replace them with named functions to improve clarity and stack traces: create a def (e.g., def fn_ct(): ...) that calls ct_fwd with the same arguments (q_n, k_n, v, g, beta, scale, h0.clone(), idx, use_qk_l2norm_in_kernel=False) and another def fn_fla(): ... that calls fla_fwd with the same keyword args (initial_state=h0.clone(), output_final_state=False), then pass those function names into _profiler_time_us and _wallclock_time_us instead of the lambdas.
262-266: Silent exception swallowing loses diagnostic information.The
try-except-passpattern (S110) silently ignores errors when loadingcuda.tile. Consider logging the exception to aid debugging:Suggested fix
try: import cuda.tile as _ct print(f"cuda.tile version: {_ct.__version__}") - except Exception: - pass + except ImportError: + print("cuda.tile: not available")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 262 - 266, The current module-level import of cuda.tile silently swallows exceptions (try: import cuda.tile as _ct; print(f"cuda.tile version: {_ct.__version__}") except Exception: pass); change the except to capture the exception (except Exception as e:) and log the error (e.g., using the logging module with logging.exception or logging.error plus traceback) so failures loading cuda.tile are visible; keep the existing print of _ct.__version__ when import succeeds and reference the same symbols (_ct, __version__, cuda.tile) when implementing the logging.
141-167: Consider usingflashinfer.testing.bench_gpu_time()for consistency with repo patterns.The repo provides
flashinfer.testing.bench_gpu_time()which uses CUPTI timing with auto-fallback to CUDA events. Using it would maintain consistency and potentially leverage more accurate timing infrastructure.That said, the current implementation is correct and the wall-clock timing mode specifically measures Python dispatch overhead, which is a valid use case for comparing launch-heavy vs launch-light kernels.
Based on learnings: "Use
flashinfer.testing.bench_gpu_time()for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 141 - 167, Replace the local timing helpers _profiler_time_us and _wallclock_time_us with the repository utility flashinfer.testing.bench_gpu_time(): call bench_gpu_time(fn, ...) for the profiler-style measurement using its CUPTI timing with auto-fallback to CUDA events, and call bench_gpu_time(fn, ...) in the wall-clock/dispatch mode for Python dispatch overhead comparisons; update any callers to use bench_gpu_time instead of these two functions so timing is consistent with repo patterns.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py`:
- Around line 299-318: The breakdown block has three issues: the inline lambda
fn_ct captures variables by reference (fragile), it re-imports
torch.nn.functional as F2 (redundant with F), and duplicates tensor setup
already present in bench_config; fix by replacing the lambda with a bound
callable (use functools.partial or a small local def that takes no args) that
binds q_n, k_n, v, g, beta, h0, idx and the scalar K**-0.5 at definition time
(reference: fn_ct and _profiler_breakdown), remove the redundant import F2 and
use the existing F, and DRY the tensor construction by reusing or extracting the
tensor creation into a helper referenced by the benchmark config so the same
tensors are built once and reused for the breakdown run.
- Around line 222-238: The assigned lambda functions fn_ct and fn_fla trigger
lint E731; replace them with named functions to improve clarity and stack
traces: create a def (e.g., def fn_ct(): ...) that calls ct_fwd with the same
arguments (q_n, k_n, v, g, beta, scale, h0.clone(), idx,
use_qk_l2norm_in_kernel=False) and another def fn_fla(): ... that calls fla_fwd
with the same keyword args (initial_state=h0.clone(), output_final_state=False),
then pass those function names into _profiler_time_us and _wallclock_time_us
instead of the lambdas.
- Around line 262-266: The current module-level import of cuda.tile silently
swallows exceptions (try: import cuda.tile as _ct; print(f"cuda.tile version:
{_ct.__version__}") except Exception: pass); change the except to capture the
exception (except Exception as e:) and log the error (e.g., using the logging
module with logging.exception or logging.error plus traceback) so failures
loading cuda.tile are visible; keep the existing print of _ct.__version__ when
import succeeds and reference the same symbols (_ct, __version__, cuda.tile)
when implementing the logging.
- Around line 141-167: Replace the local timing helpers _profiler_time_us and
_wallclock_time_us with the repository utility
flashinfer.testing.bench_gpu_time(): call bench_gpu_time(fn, ...) for the
profiler-style measurement using its CUPTI timing with auto-fallback to CUDA
events, and call bench_gpu_time(fn, ...) in the wall-clock/dispatch mode for
Python dispatch overhead comparisons; update any callers to use bench_gpu_time
instead of these two functions so timing is consistent with repo patterns.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ec502255-9de7-488a-b0d3-ef57a778ddbe
📒 Files selected for processing (1)
benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py
Summary
Adds accuracy tests and benchmarks for the cuTile-based GDN prefill kernel on
NVIDIA Blackwell (SM100 / B200), comparing against the FLA Triton baseline.
What's included
tests/gdn/test_prefill_cutile_blackwell.py— 33 parametrized accuracy testsbenchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py— kernel latencycomparison via
torch.profilerCUDA events.Qwen3.5 GDN Parameters
linear_key_head_dimlinear_value_head_dimlinear_num_value_headslinear_num_key_headsmamba2_state_dtypedefaultQwen3.5-397B TP8 (H=8, K=128, V=128)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Chores