[gdn] cuTile GDN prefill on Blackwell (SM100) #2729

Open
xutizhou wants to merge 12 commits into flashinfer-ai:main from xutizhou:dev/cutile-gdn-prefill-blackwell

@xutizhou xutizhou commented Mar 9, 2026

Summary

Adds accuracy tests and benchmarks for the cuTile-based GDN prefill kernel on
NVIDIA Blackwell (SM100 / B200), comparing against the FLA Triton baseline.

What's included

  • tests/gdn/test_prefill_cutile_blackwell.py — 33 parametrized accuracy tests
  • benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py — kernel latency
    comparison via torch.profiler CUDA events.

Qwen3.5 GDN Parameters

| Parameter | Value | HF config field |
|---|---|---|
| K (key head dim) | 128 | linear_key_head_dim |
| V (value head dim) | 128 | linear_value_head_dim |
| H (value heads) | 32 (4B) / 64 (397B-A17B) | linear_num_value_heads |
| H_k (key heads, GQA) | 16 | linear_num_key_heads |
| BT (chunk size) | 64 | FlashInfer/FLA default |
| q/k/v dtype | bfloat16 | model weight dtype |
| SSM state dtype | float32 | sglang mamba2_state_dtype default |
| g, beta dtype | float32 | numerical stability |
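To make the parameters concrete, here is a minimal sketch of how the per-rank head count and the per-sequence SSM state size follow from them. It assumes value heads are split evenly across tensor-parallel ranks and that each sequence keeps one float32 state of shape [H, K, V]; the helper names are hypothetical and not part of this PR.

```python
# Hypothetical helpers for illustration only; not part of the PR.

def heads_per_rank(num_value_heads: int, tp_size: int) -> int:
    """Value heads handled by one tensor-parallel rank (even split assumed)."""
    if num_value_heads % tp_size != 0:
        raise ValueError("value heads must divide evenly across TP ranks")
    return num_value_heads // tp_size


def state_bytes_per_sequence(h: int, k: int = 128, v: int = 128, itemsize: int = 4) -> int:
    """Bytes of one [H, K, V] recurrent state; itemsize=4 corresponds to float32."""
    return h * k * v * itemsize


for model, total_heads in (("4B", 32), ("397B-A17B", 64)):
    h = heads_per_rank(total_heads, tp_size=8)
    kib = state_bytes_per_sequence(h) // 1024
    print(f"{model}: H per rank = {h}, state per sequence = {kib} KiB")
```

Under these assumptions TP8 gives H=4 for the 4B model and H=8 for the 397B model, which matches the head counts used in the TP8 benchmark configs.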

Qwen3.5-397B TP8 (H=8, K=128, V=128)

| Config | FLA (ms) | cuTile (ms) | Speedup |
|---|---|---|---|
| B=1, T=2048 | 0.092 | 0.056 | 1.65x |
| B=2, T=2048 | 0.125 | 0.080 | 1.55x |
| B=4, T=2048 | 0.184 | 0.120 | 1.54x |
| B=8, T=1024 | 0.171 | 0.107 | 1.60x |
| B=1, T=4096 | 0.172 | 0.100 | 1.71x |
| B=4, T=4096 | 0.333 | 0.218 | 1.53x |
| B=8, T=2048 | 0.304 | 0.191 | 1.59x |
| B=16, T=1024 | 0.310 | 0.183 | 1.69x |
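The speedup column is the FLA latency divided by the cuTile latency. The sketch below recomputes it from the rounded latencies in the table. Recomputed values can differ from the reported ones in the last digit, presumably because the reported speedups were derived from unrounded timings (that is an assumption, not something the PR states).

```python
# (config, FLA ms, cuTile ms, reported speedup), copied from the table above.
ROWS = [
    ("B=1, T=2048", 0.092, 0.056, 1.65),
    ("B=2, T=2048", 0.125, 0.080, 1.55),
    ("B=4, T=2048", 0.184, 0.120, 1.54),
    ("B=8, T=1024", 0.171, 0.107, 1.60),
    ("B=1, T=4096", 0.172, 0.100, 1.71),
    ("B=4, T=4096", 0.333, 0.218, 1.53),
    ("B=8, T=2048", 0.304, 0.191, 1.59),
    ("B=16, T=1024", 0.310, 0.183, 1.69),
]


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Speedup of the candidate over the baseline (values above 1 mean faster)."""
    return baseline_ms / candidate_ms


recomputed = [(cfg, speedup(fla, cutile), rep) for cfg, fla, cutile, rep in ROWS]
max_gap = max(abs(value - rep) for _, value, rep in recomputed)

for cfg, value, rep in recomputed:
    print(f"{cfg:<14} recomputed {value:.3f}x   reported {rep:.2f}x")
print(f"largest gap between recomputed and reported: {max_gap:.4f}")
```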

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Optimized GDN prefill for Blackwell GPUs with optional L2 normalization and initial-state support.
  • Tests

    • Added accuracy tests against a Triton baseline across small to large configurations, with hardware guards and dtype-aware tolerances.
  • Chores

    • Added a benchmarking tool that reports kernel and end-to-end timings, speedups, and an optional per-kernel breakdown.

…(SM100)

- flashinfer/gdn_kernels/cutile_gdn_prefill.py: cuTile-based GDN chunked
  prefill kernel (4 fused kernels: cumsum, prepare, recurrence, output),
  replacing the 6-stage Triton pipeline. Targets SM100 (Blackwell B200).

- flashinfer/gdn_kernels/__init__.py: export chunk_gated_delta_rule_cutile.

- tests/gdn/test_prefill_cutile_blackwell.py: pytest accuracy tests comparing
  cuTile vs FLA Triton GDN prefill on B200. 33 configs (B=1..16, T=64..4096,
  K/V=64/128/256, w/wo initial state). All 33 PASS (bfloat16, atol=rtol=1e-2).

- benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py: kernel latency
  comparison. Results on B200: 1.38x–1.83x speedup vs Triton baseline.
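For readers unfamiliar with the `atol=rtol=1e-2` notation above, this is a minimal sketch of the usual allclose-style criterion, |actual - expected| <= atol + rtol * |expected|. The PR text does not say which comparison function the tests call, so treat the helper names here as hypothetical.

```python
# Hypothetical pure-Python version of an allclose-style check; the real tests
# compare GPU tensors, not Python floats.

def within_tolerance(actual: float, expected: float,
                     atol: float = 1e-2, rtol: float = 1e-2) -> bool:
    """True when |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def all_within(actual, expected, atol: float = 1e-2, rtol: float = 1e-2) -> bool:
    """Element-wise check over two equally long sequences."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return all(within_tolerance(a, e, atol, rtol) for a, e in zip(actual, expected))


print(within_tolerance(1.015, 1.0))  # allowed error is 0.01 + 0.01 * 1.0 = 0.02
print(within_tolerance(1.030, 1.0))  # error 0.03 exceeds the 0.02 budget
```

Because the relative term scales with the magnitude of the reference value, larger outputs get a proportionally larger error budget, which is the usual reason to pair `atol` with `rtol` for bfloat16 comparisons.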

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the GDN prefill functionality by introducing a new cuTile-based kernel optimized for NVIDIA Blackwell GPUs. It includes comprehensive accuracy tests to ensure correctness against an existing baseline and provides benchmarks showcasing substantial performance gains. The overall impact is improved efficiency and broader hardware compatibility for GDN prefill operations.

Highlights

  • Blackwell (SM100) Support: Introduced a cuTile-based GDN prefill kernel specifically optimized for NVIDIA Blackwell (SM100 / B200) GPUs.
  • Accuracy Tests: Added 33 parametrized accuracy tests for the cuTile GDN prefill kernel, comparing its output against the FLA Triton baseline on Blackwell, with all tests passing on B200.
  • Performance Benchmarks: Included benchmarks demonstrating significant performance improvements, with the cuTile kernel achieving a 1.38–1.83x speedup over the FLA Triton baseline on B200 for various workloads.
  • Integration: Integrated the new chunk_gated_delta_rule_cutile function into the flashinfer.gdn_kernels module, making it accessible for use.

Changelog
  • benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py
    • Added a new benchmark script to compare cuTile GDN prefill with FLA Triton on Blackwell GPUs.
  • flashinfer/gdn_kernels/__init__.py
    • Updated to import and expose the new chunk_gated_delta_rule_cutile function.
  • tests/gdn/test_prefill_cutile_blackwell.py
    • Added new accuracy tests for the cuTile GDN prefill kernel on Blackwell, comparing its output against the FLA Triton baseline.
Activity
  • The author has completed pre-commit checks and confirmed that all tests are passing, as indicated in the pull request checklist.

@coderabbitai

coderabbitai Bot commented Mar 9, 2026

Walkthrough

Adds a cuTile-based GDN prefill implementation for Blackwell GPUs, plus a benchmark script comparing cuTile vs FLA Triton, test coverage validating outputs against the FLA baseline, and an optional cuTile import exported from gdn_kernels.

Changes

| Cohort / File(s) | Summary |
|---|---|
| cuTile Kernel Implementation: flashinfer/gdn_kernels/cutile_gdn_prefill.py, flashinfer/gdn_kernels/__init__.py | New cuTile-based GDN prefill module with public entry chunk_gated_delta_rule_cutile, L2-normalization helpers, multiple cuTile kernels (output, recurrence, fused WY+REC, CKKT solve), internal caching/streams, and __init__ export guard for optional import. Review kernels, tiling/config heuristics, and cache lifecycle. |
| Benchmarks: benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py | New standalone benchmark comparing cuTile vs FLA on Blackwell. Adds CONFIGS, CLI flags (--large-only, --breakdown), profiler and wall-clock timing utilities, dynamic imports for FLA and cuTile, and formatted speedup output. Check profiler usage and device synchronization. |
| Tests (Blackwell-only): tests/gdn/test_prefill_cutile_blackwell.py | New tests that compare cuTile outputs to the FLA Triton baseline under SM100 guard. Parametrized accuracy tests (bf16) and large-config tests, with availability guards and tolerances. Confirm skip conditions, dtype handling, and tolerances. |
| Integrations / Exports: flashinfer/gdn_kernels/__init__.py | Adds optional import/export of chunk_gated_delta_rule_cutile with ImportError fallback; verify public API surface and import semantics. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant Entrypoint as chunk_gated_delta_rule_cutile
    participant L2Norm
    participant CacheInit
    participant Scheduler as ExecutionPath
    participant Kernels
    participant Profiler
    participant Output

    Caller->>Entrypoint: call(q,k,v,g,beta, initial_state, ...)
    Entrypoint->>L2Norm: optionally normalize q,k
    L2Norm-->>Entrypoint: normalized tensors
    Entrypoint->>CacheInit: allocate/init caches & tensors
    CacheInit-->>Entrypoint: cache handles
    Entrypoint->>Scheduler: choose execution path (fused vs cached WY/REC)
    Scheduler->>Kernels: launch cuTile kernels per chunk
    Kernels->>Kernels: compute WY / CKKT solve / recurrence / output
    Kernels-->>Entrypoint: kernel outputs (o, h, optional aux)
    Entrypoint->>Profiler: record kernel & wall-clock timing (benchmarks)
    Profiler-->>Caller: timing results (when run)
    Entrypoint->>Output: finalize and return (o, None, h)
    Output-->>Caller: outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

run-ci, ready

Suggested reviewers

  • cyx-6
  • yzh119
  • kahyunnam
  • bkryu
  • nvmbreughe
  • kaixih
  • jimmyzho

Poem

🐰 I hopped through kernels, tiling with delight,

bf16 dreams on Blackwell nights,
Benchmarks hum and tests align,
CuTile dances, outputs shine—
A little rabbit cheers the speed in flight.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 64.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The PR title clearly and specifically describes the main change: adding cuTile GDN prefill support for Blackwell (SM100) GPUs, which is the primary objective of this changeset. |
| Description check | ✅ Passed | The PR description includes a summary of changes, details about what's included, related parameter tables, benchmark results, and follows the template structure with completed sections. |

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces benchmarks and accuracy tests for the new cuTile-based GDN prefill kernel targeting NVIDIA's Blackwell (SM100) architecture. The code is well-structured, with clear separation between tests and benchmarks. My review includes a suggestion to make the benchmark script more robust by ensuring it exits on non-SM100 hardware, which will prevent misleading performance results. The accuracy tests are comprehensive and the changes to the package's __init__.py are correct.

Comment on lines +157 to +160
is_sm100 = _check_sm100()
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Device capability: SM{torch.cuda.get_device_capability()}")
print()

medium

The _check_sm100 function's return value is stored in is_sm100 but never used. The benchmark proceeds even on non-SM100 hardware, printing a warning but potentially producing misleading or incorrect performance numbers. Since the cuTile kernel is specifically for SM100+, the benchmark should exit if the hardware is not supported to avoid confusion.

Suggested change
-    is_sm100 = _check_sm100()
-    print(f"GPU: {torch.cuda.get_device_name(0)}")
-    print(f"Device capability: SM{torch.cuda.get_device_capability()}")
-    print()
+    if not _check_sm100():
+        print("\nExiting benchmark: This benchmark is for SM100+ (Blackwell) GPUs.")
+        return
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
+    print(f"Device capability: SM{torch.cuda.get_device_capability()}")
+    print()

xutizhou and others added 3 commits March 10, 2026 21:45
Key optimizations over the initial kernel submission:
- Adaptive output kernel occupancy: use occ2 for NT*BH>=256 (was >512),
  saving ~4us on medium-grid configs (B=2/4/8 with T=1024-2048)
- Adaptive recurrence BV dispatch: tune CTA count per BH range
  (BH>=64→512 CTAs, BH<=8→128, BH<=4+long→64) for optimal wave count
- Fused ckkt+KKT+solve v2 with 10-step squaring trick (occ1/2/3 variants)
- Q-cached output kernel: load Q once in registers, reuse across V-blocks
- Many experimental kernels explored (fused ckkt+WY, no-WY rec, ILP-2 split,
  higher-occ rec/output) — kept only the ones that improved perf

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

- Update benchmark configs to match Qwen3.5 linear attention parameters:
  K=128, V=128, H=32 (was K=256, V=256, H=4)
- Change initial_state (SSM state) dtype from bf16 to fp32, matching
  sglang's default mamba2_state_dtype
- Update accuracy tests with Qwen3.5 parametrization
- Add performance summary markdown with kernel breakdown

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

- Add TP8 benchmark configs: H=4 (4B model) and H=8 (397B model)
- Optimize rec BV for small V: cap max_rec_CTAs=256 when BH>=64 and
  V<=128, using BV=32 (0.87 waves) instead of BV=16 (1.73 waves).
  This improves large-batch H=8 configs by 0.10-0.21x.
- Key improvements:
  B=16,T=1024,H=4: 1.53x -> 1.77x
  B=8,T=1024,H=8:  1.42x -> 1.60x
  B=8,T=2048,H=8:  1.38x -> 1.59x
  B=16,T=1024,H=8: 1.62x -> 1.69x

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
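The wave counts quoted in the commit above (0.87 for BV=32, 1.73 for BV=16) can be reproduced to within rounding under a few assumptions that the commit message does not spell out: the recurrence grid has BH x ceil(V/BV) CTAs, the B200 exposes 148 SMs, and two CTAs are resident per SM (occupancy 2). The function name is hypothetical.

```python
import math


def recurrence_waves(bh: int, v: int, bv: int,
                     num_sms: int = 148, occupancy: int = 2):
    """Return (CTA count, waves) for a BH x ceil(V/BV) grid.

    num_sms=148 and occupancy=2 are assumptions about the B200 setup,
    not values stated in the PR.
    """
    ctas = bh * math.ceil(v / bv)
    return ctas, ctas / (num_sms * occupancy)


# B=8, H=8 gives BH=64; V=128 as in the Qwen3.5 config.
for bv in (32, 16):
    ctas, waves = recurrence_waves(bh=64, v=128, bv=bv)
    print(f"BV={bv}: {ctas} CTAs, {waves:.3f} waves")
```

With BV=32 the whole grid fits in a single partial wave, whereas BV=16 needs a second wave that is only about three-quarters full, which is consistent with the motivation given for capping the CTA count at 256.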
@xutizhou xutizhou force-pushed the dev/cutile-gdn-prefill-blackwell branch from f524be1 to 99ce4d8 Compare March 11, 2026 06:24
xutizhou and others added 7 commits March 10, 2026 23:31
Add cutile_fused_wy_rec_kernel that computes w and u on-the-fly per
chunk inside the recurrence loop, eliminating w/u HBM write+read.
For B=8,T=2048,H=8,K=128: saves ~128MB bandwidth (~16us at 8TB/s).
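The ~128MB and ~16us figures can be checked with a short back-of-the-envelope calculation. This sketch assumes w has shape [B, T, H, K] and u has shape [B, T, H, V], both stored in bfloat16 (2 bytes), that V=128 as in the Qwen3.5 config, and that each tensor is written once and read once in the unfused path. None of these layout details is stated in the commit message.

```python
def wu_traffic_bytes(b: int, t: int, h: int, k: int, v: int, itemsize: int = 2) -> int:
    """HBM bytes for writing and then reading w [B,T,H,K] and u [B,T,H,V]."""
    w_bytes = b * t * h * k * itemsize
    u_bytes = b * t * h * v * itemsize
    return 2 * (w_bytes + u_bytes)  # one write plus one read of each tensor


bytes_saved = wu_traffic_bytes(b=8, t=2048, h=8, k=128, v=128)
mib_saved = bytes_saved / 2**20
micros_at_8tbs = bytes_saved / 8e12 * 1e6  # 8 TB/s bandwidth, as quoted above

print(f"traffic avoided: {mib_saved:.0f} MiB, about {micros_at_8tbs:.1f} us at 8 TB/s")
```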

Dispatch condition: K<=128, BH>=64, BV<=32, NT*BH>=512.
This targets large-batch TP8 configs where bandwidth savings
outweigh the extra per-chunk A_inv WGMMA compute.
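The dispatch condition above can be written as a small predicate. This is a sketch rather than the PR's actual dispatch code: the function name is hypothetical, and it assumes BH = B * H and NT = ceil(T / BT) with the chunk size BT = 64 from the parameter table.

```python
import math


def use_fused_wy_rec(b: int, t: int, h: int, k: int, bv: int, chunk_size: int = 64) -> bool:
    """True when K<=128, BH>=64, BV<=32 and NT*BH>=512 all hold."""
    bh = b * h                      # assumed: one unit of work per (batch, head)
    nt = math.ceil(t / chunk_size)  # assumed: number of chunks along the sequence
    return k <= 128 and bh >= 64 and bv <= 32 and nt * bh >= 512


print(use_fused_wy_rec(b=8, t=2048, h=8, k=128, bv=32))  # large-batch TP8 config
print(use_fused_wy_rec(b=1, t=2048, h=8, k=128, bv=32))  # BH=8, too few heads
```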

Key improvements on Qwen3.5 TP8:
  B=16,T=1024,H=4: 1.74x -> 1.90x
  B=8 ,T=1024,H=8: 1.60x -> 1.76x
  B=8 ,T=2048,H=8: 1.59x -> 1.72x

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

When fused WY+rec dispatch would skip a config due to rec_BV>32
(e.g. B=16,T=1024,H=8 with BV=64), fall back to BV=32 with more CTAs
for the fused path. This ensures all BH>=64 K<=128 configs use the
bandwidth-optimized fused kernel.

No perf change for B=16 configs (BV=32/512CTAs ≈ BV=64/256CTAs),
but enables future optimizations on the unified fused path.

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Remove unused experimental kernels that were explored during optimization
but never activated in the dispatch path. No functional change.

Removed: cutile_chunk_local_cumsum, cutile_cumsum_kkt_kernel,
cutile_fused_ckkt_solve_kernel, cutile_prepare_kernel,
cutile_output_multiV_kernel, cutile_output_qcached_occ3{,_v2}_kernel,
cutile_kkt_solve_kernel, cutile_fused_kkt_solve_wy_kernel_v2,
cutile_kkt_kernel, cutile_wy_occ6_kernel, cutile_solve_kernel_occ2,
cutile_wy_kernel_occ2, cutile_recurrence_kernel_bv64,
cutile_recurrence_kernel_v{2,3}, cutile_recurrence_kernel_bv16_occ4,
cutile_recurrence_kernel_bv16_split2, cutile_rec_nowy_bv16_occ2_kernel,
cutile_recompute_wu_occ4_kernel, cutile_fused_wy_rec_output_kernel,
cutile_fused_ckkt_solve_v2_occ4_kernel, cutile_fused_ckkt_wy_occ2_kernel

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Remove standalone cumsum, solve, WY, rec, and fused ckkt+WY kernels
that are superseded by the fused ckkt_solve_v2 and fused_wy_rec
variants. Only dispatch-active kernels remain.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

…ines)

Eliminate occupancy variants by keeping only the best single variant:
- ckkt: keep occ3 only (was occ1/2/3)
- output_qcached: keep occ2 only (was occ1/2)
- output_tiled: keep occ2 only (was occ1/2)
- recurrence: keep occ2 only (was occ2/3)

Small grid configs (B=1) lose 0.03-0.05x from ckkt occ3 register
pressure at <512 CTAs, acceptable tradeoff for code simplicity.

Also share kernel bodies via factory pattern: define body once,
create occupancy variant via ct.kernel(occupancy=N)(body).
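The factory pattern described above (one kernel body, several occupancy variants) can be illustrated without a GPU. In the sketch below, `kernel` is a plain-Python stand-in written for this example. It is NOT the cuda.tile API and only mimics the `ct.kernel(occupancy=N)(body)` call shape mentioned in the commit message.

```python
def kernel(occupancy: int):
    """Stand-in decorator factory; mimics the call shape only, not cuda.tile."""
    def wrap(body):
        def launch(*args, **kwargs):
            # A real kernel decorator would compile and launch on the GPU;
            # this stand-in just calls the shared body on the host.
            return body(*args, **kwargs)
        launch.occupancy = occupancy
        launch.__name__ = f"{body.__name__}_occ{occupancy}"
        return launch
    return wrap


def recurrence_body(x):
    """Shared kernel body, defined once (placeholder computation)."""
    return x * 2


# Variants differ only in the occupancy hint, not in the body.
recurrence_occ2 = kernel(occupancy=2)(recurrence_body)
recurrence_occ3 = kernel(occupancy=3)(recurrence_body)

print(recurrence_occ2.__name__, recurrence_occ2.occupancy)
print(recurrence_occ3.__name__, recurrence_occ3.occupancy)
```

Applying the decorator as an ordinary function call instead of with `@` syntax is what lets one body back several variants without duplicating code.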

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

… lines)

For Qwen3.5 V=128 with BV=128, there's only 1 V-block iteration,
so Q-caching has zero overhead. Remove the separate 3D-grid output
kernel and simplify dispatch to always use the qcached variant.

5 kernels remain: ckkt_solve (occ3), output_qcached (occ2),
recurrence (occ2), fused_wy_rec (occ2), l2norm.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

- Remove unused _solve_tril_cached and its imports (solve_tril_16x16,
  merge_16x16 kernels) — solve is fused into ckkt kernel
- Remove unused numpy import
- Remove stale cache entries ('A', 'Ad') from _init_cache
- Deduplicate comments in dispatch logic
- Update module docstring to reflect current 5-kernel architecture

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@xutizhou xutizhou marked this pull request as ready for review March 11, 2026 09:40

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py`:
- Around line 41-71: Move the cuTile/sglang bootstrap out of top-level import so
it only runs when cuTile is actually available: read the root from an env/config
variable (use _SGLANG_ROOT from os.environ or a config fallback) instead of the
hard-coded Path, check CT_AVAILABLE (or validate that the configured
_SGLANG_ROOT exists and contains the expected files) before calling _stub/_load,
and wrap the for-loop that imports "sglang.srt.layers.attention.fla.*" inside
that availability guard; also make _load/_stub resilient by catching and
silencing file-not-found/import errors so the module import will cleanly report
cuTile unavailable rather than crashing.
- Around line 223-226: The call to _check_sm100() assigns is_sm100 but the
script continues even when it returns False, leaving an unused-variable (Ruff
F841) and allowing unsupported GPUs to proceed; change the logic after calling
_check_sm100() so that if it returns False you immediately stop the run (e.g.,
call sys.exit(1) or raise SystemExit) with a clear error message, otherwise
continue and print the GPU info—this both uses the return value and prevents
running the Blackwell-only path on unsupported hardware.

In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py`:
- Around line 249-253: The optional initial_state/initial_state_indices path is
broken because USE_INITIAL only gates the load but both recurrence kernels
always write back into initial_state/initial_state_indices; fix by allocating
internal zero tensors when initial_state or initial_state_indices is None
instead of assuming callers supply them. In the functions that call the CUDA
kernels (e.g., chunk_gated_delta_rule and the other entrypoints around the shown
regions), detect if initial_state is None or initial_state_indices is None and
create zero tensors with the correct shape, device and dtype (initial_state:
[N,H,K,V]; initial_state_indices matching expected index shape), then pass those
buffers into the kernels so the write-back has a valid destination;
alternatively, remove Optional typing and require callers to pass buffers, but
prefer the internal-allocation approach so semantics match existing API. Ensure
allocations respect cu_seqlens / batch size and preserve device/dtype.
- Around line 196-205: The module-level scratch state (_tensor_cache,
_cache_config, _cached_stream) makes the exported kernel non-reentrant; change
the caching to be scoped per-call or at least per-device+stream. Update
_init_cache and callers (e.g., functions that use
_tensor_cache/_cache_config/_cached_stream) to accept a stream identifier and
include it in the cache key (for example key =
(B,T,H,K,V,device,dtype_k,dtype_v,stream)) or replace the globals with a
per-call local cache returned by _init_cache; ensure every read/write of
_tensor_cache/_cache_config/_cached_stream is replaced with lookups into the new
keyed dict or the returned local cache so concurrent invocations on the same
device but different streams do not collide.
- Around line 23-30: The import fallback sets ct = None which causes
AttributeError when module-level code later calls or uses `@ct.kernel`; instead
create a lightweight sentinel that raises ImportError when its kernel attribute
is accessed/used so the optional-import pattern behaves correctly. Replace the
except ImportError block so that CUTILE_AVAILABLE=False and ct is an object
(e.g., a small class instance or SimpleNamespace) that defines a kernel
attribute (callable/decorator) which, when invoked, raises
ImportError("cuda.tile not available"), and set Const to a similar sentinel or
None-consistent value; this ensures all module-level uses of ct.kernel and
references to Const (the symbols ct.kernel and Const) raise ImportError rather
than AttributeError during import.

In `@tests/gdn/test_prefill_cutile_blackwell.py`:
- Around line 29-30: Replace the local CUDA wrapper and manual major-version
check with the repo helpers: import flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, remove the local
get_compute_capability(device) that calls torch.cuda.get_device_capability, and
replace any checks like cc[0] != 10 with a call to
is_sm100a_supported(get_compute_capability(device)) (or its negation) to decide
skips; apply the same replacement for the other skip block referenced (the logic
around the cc checks at the later block).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b3847a0e-f13a-4a16-a856-32ed41d1f151

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and ef40d21.

📒 Files selected for processing (4)
  • benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/cutile_gdn_prefill.py
  • tests/gdn/test_prefill_cutile_blackwell.py

Comment on lines +41 to +71
import importlib.util, sys, types, pathlib
_kernel_path = pathlib.Path(__file__).parent.parent / "flashinfer/gdn_kernels/cutile_gdn_prefill.py"
_SGLANG_ROOT = pathlib.Path("/home/scratch.xutingz_wwfo/gitsrc/sglang")
# Bootstrap sglang FLA modules needed by cutile_gdn_prefill at runtime
def _stub(name):
    parts = name.split('.')
    for i in range(1, len(parts) + 1):
        pkg = '.'.join(parts[:i])
        if pkg not in sys.modules:
            m = types.ModuleType(pkg); m.__path__ = []; m.__package__ = pkg
            sys.modules[pkg] = m
def _load(name, path):
    _stub(name)
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    mod.__package__ = name.rsplit('.', 1)[0]
    sys.modules[name] = mod
    spec.loader.exec_module(mod)
    return mod
import torch as _torch
_stub("sglang.srt.utils.common")
sys.modules["sglang.srt.utils.common"].torch_release = tuple(
    int(x) for x in _torch.__version__.split(".")[:2] if x.isdigit()
)
_sgl_python = str(_SGLANG_ROOT / "python")
if _sgl_python not in sys.path:
    sys.path.insert(0, _sgl_python)
_fla = _SGLANG_ROOT / "python/sglang/srt/layers/attention/fla"
for _n in ["utils", "l2norm", "op", "index", "cumsum",
           "chunk_scaled_dot_kkt", "solve_tril", "wy_fast", "chunk_delta_h", "chunk_o"]:
    _load(f"sglang.srt.layers.attention.fla.{_n}", str(_fla / f"{_n}.py"))

⚠️ Potential issue | 🟠 Major

Make the cuTile/sglang bootstrap portable and lazy.

This block hard-codes a workstation-local checkout path and eagerly loads files from it before CT_AVAILABLE is set. On any machine without that exact tree, the benchmark fails during import with a file-loading error instead of cleanly reporting that cuTile is unavailable. This same block is also the current Ruff E401/E702 failure.

Please move the bootstrap behind a real availability check and source the root from config/env instead of an absolute path.

🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 41-41: Ruff: E401 Multiple imports on one line. Split imports across lines.


[error] 50-50: Ruff: E702 Multiple statements on one line (semicolon). Split into separate statements.


[error] 50-52: Ruff: E702 Multiple statements on one line (semicolon) in _load helper block.


[error] 41-43: Ruff: E401 Multiple imports on one line. Split imports across lines.


[warning] 50-50: Ruff: E702 Multiple statements on one line (semicolon) detected; consider breaking into separate lines.

🪛 Ruff (0.15.5)

[error] 50-50: Multiple statements on one line (semicolon)

(E702)


[error] 50-50: Multiple statements on one line (semicolon)

(E702)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 41 -
71, Move the cuTile/sglang bootstrap out of top-level import so it only runs
when cuTile is actually available: read the root from an env/config variable
(use _SGLANG_ROOT from os.environ or a config fallback) instead of the
hard-coded Path, check CT_AVAILABLE (or validate that the configured
_SGLANG_ROOT exists and contains the expected files) before calling _stub/_load,
and wrap the for-loop that imports "sglang.srt.layers.attention.fla.*" inside
that availability guard; also make _load/_stub resilient by catching and
silencing file-not-found/import errors so the module import will cleanly report
cuTile unavailable rather than crashing.
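One way to follow this suggestion is to resolve the sglang checkout from the environment and validate it before loading anything. The sketch below is an illustration only: the environment variable name `SGLANG_ROOT` and the helper `resolve_sglang_root` are hypothetical, and the module list is copied from the loop in the quoted benchmark code.

```python
import os
from pathlib import Path

# FLA module names taken from the loop in the quoted benchmark code.
FLA_MODULES = ("utils", "l2norm", "op", "index", "cumsum",
               "chunk_scaled_dot_kkt", "solve_tril", "wy_fast",
               "chunk_delta_h", "chunk_o")


def resolve_sglang_root(env=None, var: str = "SGLANG_ROOT"):
    """Return the sglang root as a Path, or None if unset or incomplete.

    `var` is a hypothetical environment variable name chosen for this sketch.
    """
    env = os.environ if env is None else env
    raw = env.get(var)
    if not raw:
        return None
    fla_dir = Path(raw) / "python" / "sglang" / "srt" / "layers" / "attention" / "fla"
    if all((fla_dir / f"{name}.py").is_file() for name in FLA_MODULES):
        return Path(raw)
    return None


root = resolve_sglang_root()
if root is None:
    print("sglang FLA sources not found; skipping the FLA baseline")
else:
    print(f"using sglang checkout at {root}")
```

Returning None instead of raising lets the caller report "baseline unavailable" cleanly and decide whether to skip or exit.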

Comment on lines +223 to +226
is_sm100 = _check_sm100()
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Device capability: SM{torch.cuda.get_device_capability()}")
print()

⚠️ Potential issue | 🟠 Major

Stop the run when _check_sm100() fails.

The return value is ignored, so the script still executes the Blackwell-only path on unsupported GPUs after printing a warning. That produces misleading numbers at best and runtime failures at worst, and it leaves the current Ruff F841 error unresolved.

Suggested fix
-    is_sm100 = _check_sm100()
+    if not _check_sm100():
+        return
     print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"Device capability: SM{torch.cuda.get_device_capability()}")
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 223-223: Ruff: F841 Local variable is_sm100 is assigned to but never used.


[error] 223-223: Ruff: Remove assignment to unused variable is_sm100.


[warning] 223-223: Ruff: F841 Local variable is_sm100 is assigned to but never used (suggest removal).

🪛 Ruff (0.15.5)

[error] 223-223: Local variable is_sm100 is assigned to but never used

Remove assignment to unused variable is_sm100

(F841)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 223 -
226, The call to _check_sm100() assigns is_sm100 but the script continues even
when it returns False, leaving an unused-variable (Ruff F841) and allowing
unsupported GPUs to proceed; change the logic after calling _check_sm100() so
that if it returns False you immediately stop the run (e.g., call sys.exit(1) or
raise SystemExit) with a clear error message, otherwise continue and print the
GPU info—this both uses the return value and prevents running the Blackwell-only
path on unsupported hardware.

Comment on lines +23 to +30
try:
    import cuda.tile as ct
    from cuda.tile import Constant as Const
    CUTILE_AVAILABLE = True
except ImportError:
    CUTILE_AVAILABLE = False
    ct = None
    Const = None

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -n -A2 -B2 'ct = None|@ct\.kernel|ct\.kernel\(' flashinfer/gdn_kernels/cutile_gdn_prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 931


🏁 Script executed:

cat -n flashinfer/gdn_kernels/__init__.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 2005


🏁 Script executed:

rg -n "cutile_gdn_prefill|ImportError" tests/gdn/ -A2 -B2

Repository: flashinfer-ai/flashinfer

Length of output: 2291


🏁 Script executed:

rg -n "CUTILE_AVAILABLE" flashinfer/ -A2 -B2

Repository: flashinfer-ai/flashinfer

Length of output: 582


🏁 Script executed:

rg -n "cutile_gdn_prefill" flashinfer/gdn_kernels/__init__.py

Repository: flashinfer-ai/flashinfer

Length of output: 136


🏁 Script executed:

sed -n '23,160p' flashinfer/gdn_kernels/cutile_gdn_prefill.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 6562


🏁 Script executed:

sed -n '48,55p' flashinfer/gdn_kernels/__init__.py

Repository: flashinfer-ai/flashinfer

Length of output: 273


Module-level ct.kernel() calls will raise AttributeError, not ImportError, when cuda.tile is unavailable.

When ct = None is set on import failure (line 29), the unconditional module-level code at lines 125, 131, 436, 446, and 627 calls ct.kernel(...) and applies the @ct.kernel decorator during import. This raises AttributeError instead of the expected ImportError, which breaks the optional-import pattern used in flashinfer/gdn_kernels/__init__.py (lines 51-52) and expected by the tests (tests/gdn/test_prefill_cutile_blackwell.py, lines 35-40).

Suggested fix
-try:
-    import cuda.tile as ct
-    from cuda.tile import Constant as Const
-    CUTILE_AVAILABLE = True
-except ImportError:
-    CUTILE_AVAILABLE = False
-    ct = None
-    Const = None
+try:
+    import cuda.tile as ct
+    from cuda.tile import Constant as Const
+    CUTILE_AVAILABLE = True
+except ImportError as e:
+    raise ImportError("cuda.tile is required for cutile_gdn_prefill") from e
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py` around lines 23 - 30, The
import fallback sets ct = None which causes AttributeError when module-level
code later calls or uses `@ct.kernel`; instead create a lightweight sentinel that
raises ImportError when its kernel attribute is accessed/used so the
optional-import pattern behaves correctly. Replace the except ImportError block
so that CUTILE_AVAILABLE=False and ct is an object (e.g., a small class instance
or SimpleNamespace) that defines a kernel attribute (callable/decorator) which,
when invoked, raises ImportError("cuda.tile not available"), and set Const to a
similar sentinel or None-consistent value; this ensures all module-level uses of
ct.kernel and references to Const (the symbols ct.kernel and Const) raise
ImportError rather than AttributeError during import.
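
A minimal sketch of the sentinel approach described in the prompt above. The names `_MissingModule` and `optional_import` are hypothetical and are not part of FlashInfer or cuda.tile; the sketch only assumes that a failed optional import should surface as ImportError the first time the module is used.

```python
import importlib


class _MissingModule:
    """Placeholder for an optional module that failed to import (hypothetical helper)."""

    def __init__(self, name):
        self._name = name

    def __getattr__(self, attr):
        # Only reached for attributes that do not exist on the instance, e.g. `.kernel`.
        raise ImportError(f"{self._name} is not available; cannot access '{attr}'")


def optional_import(name):
    """Return (module, True) if importable, otherwise (sentinel, False)."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        return _MissingModule(name), False


# In the kernel module this would read:
#   ct, CUTILE_AVAILABLE = optional_import("cuda.tile")
# A deliberately missing module name is used here so the sketch runs anywhere.
ct, CUTILE_AVAILABLE = optional_import("module_that_is_not_installed")

try:
    ct.kernel  # mimics a module-level `@ct.kernel` use
except ImportError as exc:
    error_message = str(exc)
```

With a sentinel like this, a module-level `@ct.kernel` fails with ImportError during import, which is the exception type the guard in `flashinfer/gdn_kernels/__init__.py` and the test skip logic are reported to expect.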

Comment on lines +196 to +205
_tensor_cache = {}
_cache_config = None  # (B, T, H, K, V, device, dtype) tuple for invalidation

def _init_cache(B, T, H, K, V, device, dtype_k, dtype_v):
    """Pre-allocate all intermediate tensors for the given config."""
    global _tensor_cache, _cache_config
    key = (B, T, H, K, V, device, dtype_k, dtype_v)
    if _cache_config == key:
        return  # Already initialized for this config
    _cache_config = key


⚠️ Potential issue | 🟠 Major

The module-level scratch state makes this exported kernel non-reentrant.

_tensor_cache, _cache_config, and _cached_stream are shared across all invocations. Two overlapping calls can overwrite the same intermediates/output buffers, and later calls inherit launch state from earlier ones. That is risky for an exported library API even if it helps single-call benchmark numbers.

Please scope scratch buffers and stream selection to the call, or at least shard them by device/stream instead of keeping a single global instance.

Also applies to: 240-289

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py` around lines 196 - 205, The
module-level scratch state (_tensor_cache, _cache_config, _cached_stream) makes
the exported kernel non-reentrant; change the caching to be scoped per-call or
at least per-device+stream. Update _init_cache and callers (e.g., functions that
use _tensor_cache/_cache_config/_cached_stream) to accept a stream identifier
and include it in the cache key (for example key =
(B,T,H,K,V,device,dtype_k,dtype_v,stream)) or replace the globals with a
per-call local cache returned by _init_cache; ensure every read/write of
_tensor_cache/_cache_config/_cached_stream is replaced with lookups into the new
keyed dict or the returned local cache so concurrent invocations on the same
device but different streams do not collide.
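
The keyed-cache variant suggested above could look roughly like the following sketch. `ScratchCache` and `fake_allocate` are hypothetical stand-ins: the real code would allocate torch tensors, and the stream identifier is assumed to be something like the raw handle of the current CUDA stream.

```python
from typing import Callable, Dict, Hashable, Tuple


class ScratchCache:
    """Scratch-buffer cache keyed by (config, device, stream); hypothetical helper."""

    def __init__(self, allocate: Callable[[Tuple, str], object]):
        self._allocate = allocate
        self._buffers: Dict[Hashable, object] = {}

    def get(self, config: Tuple, device: str, stream_id: int):
        # Including the stream in the key keeps concurrent streams from
        # sharing (and overwriting) the same intermediates.
        key = (config, device, stream_id)
        if key not in self._buffers:
            self._buffers[key] = self._allocate(config, device)
        return self._buffers[key]


def fake_allocate(config, device):
    # Stand-in for the real tensor allocations; returns a fresh mutable buffer.
    return {"config": config, "device": device, "data": []}


cache = ScratchCache(fake_allocate)
# (B, T, H, K, V, dtype_k, dtype_v)
config = (1, 2048, 8, 128, 128, "bfloat16", "bfloat16")

buf_stream0 = cache.get(config, "cuda:0", stream_id=0)
buf_stream1 = cache.get(config, "cuda:0", stream_id=1)
buf_stream0_again = cache.get(config, "cuda:0", stream_id=0)
```

This only addresses collisions between streams. Buffers handed back to the caller as outputs would still be overwritten by the next call on the same key, so outputs probably need to be allocated per call, and a lock would be needed if several Python threads share the cache.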

Comment on lines +249 to +253
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,  # [N, H, K, V]
    initial_state_indices: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,


⚠️ Potential issue | 🔴 Critical

The advertised initial_state=None path is still broken.

USE_INITIAL only controls the initial load. Both recurrence kernels always write back through initial_state and initial_state_indices at the end, so the default initial_state=None path cannot work as declared. The new tests avoid tripping this only by always fabricating zero h0/idx tensors when use_initial_state=False.

If this entry point is meant to preserve the existing flashinfer.gdn_prefill.chunk_gated_delta_rule semantics, it needs an internal zero-state/indices allocation when initial_state is omitted, or the parameter should stop being optional.

Suggested fix
     USE_INITIAL = initial_state is not None
+
+    if initial_state is None:
+        initial_state = torch.zeros(
+            B, H, K, V, dtype=torch.float32, device=q.device
+        )
+        initial_state_indices = torch.arange(
+            B, dtype=torch.int32, device=q.device
+        )
+    elif initial_state_indices is None:
+        raise ValueError("initial_state_indices is required when initial_state is provided")

Also applies to: 278-325, 425-432, 535-542

🧰 Tools
🪛 Ruff (0.15.5)

[warning] 252-252: Unused function argument: cu_seqlens

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/cutile_gdn_prefill.py` around lines 249 - 253, The
optional initial_state/initial_state_indices path is broken because USE_INITIAL
only gates the load but both recurrence kernels always write back into
initial_state/initial_state_indices; fix by allocating internal zero tensors
when initial_state or initial_state_indices is None instead of assuming callers
supply them. In the functions that call the CUDA kernels (e.g.,
chunk_gated_delta_rule and the other entrypoints around the shown regions),
detect if initial_state is None or initial_state_indices is None and create zero
tensors with the correct shape, device and dtype (initial_state: [N,H,K,V];
initial_state_indices matching expected index shape), then pass those buffers
into the kernels so the write-back has a valid destination; alternatively,
remove Optional typing and require callers to pass buffers, but prefer the
internal-allocation approach so semantics match existing API. Ensure allocations
respect cu_seqlens / batch size and preserve device/dtype.

Comment on lines +29 to +30
def get_compute_capability(device):
    return torch.cuda.get_device_capability(device)


⚠️ Potential issue | 🟠 Major

Use the repo-standard SM100 helpers here.

The local get_compute_capability() wrapper and cc[0] != 10 check sidestep the repository’s architecture-gating helpers and can misclassify unsupported 10.x variants as valid. Please switch this file to flashinfer.utils.get_compute_capability() / is_sm100a_supported() for the skip path.

Suggested fix
-import pytest
-import torch
-import torch.nn.functional as F
-
-def get_compute_capability(device):
-    return torch.cuda.get_device_capability(device)
+import pytest
+import torch
+import torch.nn.functional as F
+
+from flashinfer.utils import get_compute_capability, is_sm100a_supported
@@
 def _skip_if_not_sm100():
     """Skip test if not SM100 architecture (Blackwell B200)."""
-    cc = get_compute_capability(torch.device("cuda"))
-    if cc[0] != 10:
+    cc = get_compute_capability(torch.device("cuda"))
+    if not is_sm100a_supported(torch.device("cuda")):
         pytest.skip(
             f"cuTile GDN prefill requires SM100 (Blackwell), but got SM{cc[0]}{cc[1]}"
         )

As per coding guidelines, "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures".

Also applies to: 57-63

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_cutile_blackwell.py` around lines 29 - 30, Replace the
local CUDA wrapper and manual major-version check with the repo helpers: import
flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, remove the local
get_compute_capability(device) that calls torch.cuda.get_device_capability, and
replace any checks like cc[0] != 10 with a call to
is_sm100a_supported(device) (or its negation), passing the torch.device rather
than a capability tuple, to decide skips; apply the same replacement for the
other skip block referenced (the logic around the cc checks at the later block).

@kaixih

kaixih commented Apr 1, 2026

Collaborator

@xutizhou, thx for the pr. do we compare the perf with this PR: #2742?

@vadiklyutiy vadiklyutiy mentioned this pull request Apr 1, 2026
5 tasks
@xutizhou

xutizhou commented Apr 7, 2026

Contributor Author

> @xutizhou, thx for the pr. do we compare the perf with this PR: #2742?

the two PRs are complementary — ours wins the TP8 serving case (small H, FLA wall-clock overhead dominates), theirs wins the large-H full-model case.

- Add wall-clock timing alongside GPU kernel time to show true latency
  FLA Triton launches O(NT) kernels/call; wall-clock >> kernel time for
  small configs (~2-3x overhead from Python dispatch).
  cuTile launches ~3 kernels total, so wall-clock ≈ kernel time.
- Print cuda.tile/triton versions and CT/FLA availability at startup
  to help diagnose environment issues.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@xutizhou xutizhou requested a review from yongwww as a code owner April 7, 2026 06:36

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (4)
benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py (4)

299-318: Loop variable capture in lambda and redundant code.

Three issues in the breakdown section:

  1. B023: The lambda at line 312 captures loop variables (q_n, k_n, v, g, beta, K, h0, idx) by reference. While it works here because _profiler_breakdown is called immediately, this is fragile and can cause subtle bugs if the code is refactored.

  2. Redundant import: F2 re-imports torch.nn.functional which is already imported as F at the top.

  3. Code duplication: The tensor creation logic duplicates bench_config. Consider extracting a helper or reusing the tensors.

Suggested fix for the lambda issue
-            fn_ct = lambda: ct_fwd(q_n, k_n, v, g, beta, K**-0.5, h0.clone(), idx, use_qk_l2norm_in_kernel=False)
+            def fn_ct(q_n=q_n, k_n=k_n, v=v, g=g, beta=beta, scale=K**-0.5, h0=h0, idx=idx):
+                return ct_fwd(q_n, k_n, v, g, beta, scale, h0.clone(), idx, use_qk_l2norm_in_kernel=False)

Or use functools.partial to bind values at definition time.
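
A small, self-contained illustration of the late-binding behavior behind B023 and of the `functools.partial` alternative. The `scaled` function is a hypothetical stand-in for `ct_fwd`; nothing here comes from the benchmark itself.

```python
from functools import partial


def scaled(x, scale):
    # Hypothetical stand-in for the benchmarked call.
    return x * scale


late_bound = []
bound = []
for s in (1, 2, 3):
    # The lambda looks up `s` when it is called, not when it is created.
    late_bound.append(lambda: scaled(10, s))
    # partial stores the current value of `s` at definition time.
    bound.append(partial(scaled, 10, s))

late_results = [fn() for fn in late_bound]
bound_results = [fn() for fn in bound]
```

Because the lambdas are only called after the loop has finished, every one of them sees the final value of `s`, while the `partial` objects keep the value from their own iteration. In the benchmark the lambda is invoked inside the same iteration, which is why the current code works but would break if the calls were deferred.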

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 299 -
318, The breakdown block has three issues: the inline lambda fn_ct captures
variables by reference (fragile), it re-imports torch.nn.functional as F2
(redundant with F), and duplicates tensor setup already present in bench_config;
fix by replacing the lambda with a bound callable (use functools.partial or a
small local def that takes no args) that binds q_n, k_n, v, g, beta, h0, idx and
the scalar K**-0.5 at definition time (reference: fn_ct and
_profiler_breakdown), remove the redundant import F2 and use the existing F, and
DRY the tensor construction by reusing or extracting the tensor creation into a
helper referenced by the benchmark config so the same tensors are built once and
reused for the breakdown run.

222-238: Lambda expressions assigned to variables (E731).

Ruff flags assigning lambdas to variables. While common in benchmark code, using def is preferred for clarity and stack traces.

Suggested refactor using def
     # ---- cuTile ----
     if CT_AVAILABLE:
-        fn_ct = lambda: ct_fwd(
-            q_n, k_n, v, g, beta, scale,
-            h0.clone(), idx,
-            use_qk_l2norm_in_kernel=False,
-        )
+        def fn_ct():
+            return ct_fwd(
+                q_n, k_n, v, g, beta, scale,
+                h0.clone(), idx,
+                use_qk_l2norm_in_kernel=False,
+            )
         results["ct_prof_us"] = _profiler_time_us(fn_ct)
         results["ct_wall_us"] = _wallclock_time_us(fn_ct)
 
     # ---- FLA Triton ----
     if FLA_AVAILABLE:
-        fn_fla = lambda: fla_fwd(
-            q_n, k_n, v, g, beta, scale,
-            initial_state=h0.clone(),
-            output_final_state=False,
-        )
+        def fn_fla():
+            return fla_fwd(
+                q_n, k_n, v, g, beta, scale,
+                initial_state=h0.clone(),
+                output_final_state=False,
+            )
         results["fla_prof_us"] = _profiler_time_us(fn_fla)
         results["fla_wall_us"] = _wallclock_time_us(fn_fla)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 222 -
238, The assigned lambda functions fn_ct and fn_fla trigger lint E731; replace
them with named functions to improve clarity and stack traces: create a def
(e.g., def fn_ct(): ...) that calls ct_fwd with the same arguments (q_n, k_n, v,
g, beta, scale, h0.clone(), idx, use_qk_l2norm_in_kernel=False) and another def
fn_fla(): ... that calls fla_fwd with the same keyword args
(initial_state=h0.clone(), output_final_state=False), then pass those function
names into _profiler_time_us and _wallclock_time_us instead of the lambdas.

262-266: Silent exception swallowing loses diagnostic information.

The try-except-pass pattern (S110) silently ignores errors when loading cuda.tile. Consider logging the exception to aid debugging:

Suggested fix
     try:
         import cuda.tile as _ct
         print(f"cuda.tile version: {_ct.__version__}")
-    except Exception:
-        pass
+    except ImportError:
+        print("cuda.tile: not available")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 262 -
266, The current module-level import of cuda.tile silently swallows exceptions
(try: import cuda.tile as _ct; print(f"cuda.tile version: {_ct.__version__}")
except Exception: pass); change the except to capture the exception (except
Exception as e:) and log the error (e.g., using the logging module with
logging.exception or logging.error plus traceback) so failures loading cuda.tile
are visible; keep the existing print of _ct.__version__ when import succeeds and
reference the same symbols (_ct, __version__, cuda.tile) when implementing the
logging.
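
A sketch of the logging variant described in the prompt above, as opposed to the narrower `except ImportError` in the suggested fix. `module_version` and the logger name are hypothetical; the sketch uses `importlib` so it can run without cuda.tile installed.

```python
import importlib
import logging

logger = logging.getLogger("bench_gdn_prefill")  # hypothetical logger name


def module_version(module_name):
    """Return the module's version string, or None if the import fails."""
    try:
        module = importlib.import_module(module_name)
    except Exception:
        # logger.exception records the message together with the traceback.
        logger.exception("failed to import %s", module_name)
        return None
    # Not every module defines __version__, so fall back to a marker value.
    return getattr(module, "__version__", "unknown")


missing = module_version("module_that_is_not_installed")
present = module_version("json")
```

In the benchmark this would be called as `module_version("cuda.tile")` at startup, so a broken installation shows up in the log instead of being skipped silently.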

141-167: Consider using flashinfer.testing.bench_gpu_time() for consistency with repo patterns.

The repo provides flashinfer.testing.bench_gpu_time() which uses CUPTI timing with auto-fallback to CUDA events. Using it would maintain consistency and potentially leverage more accurate timing infrastructure.

That said, the current implementation is correct and the wall-clock timing mode specifically measures Python dispatch overhead, which is a valid use case for comparing launch-heavy vs launch-light kernels.

Based on learnings: "Use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events."
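
For reference, the wall-clock mode discussed above can be sketched as below. This is not the PR's `_wallclock_time_us`; it is a hypothetical helper that assumes a synchronization callback is supplied by the caller (on a GPU that would typically be `torch.cuda.synchronize`).

```python
import statistics
import time


def wallclock_time_us(fn, sync=lambda: None, warmup=3, iters=20):
    """Median host-side latency of fn() in microseconds (hypothetical helper)."""
    for _ in range(warmup):
        fn()
    sync()

    samples = []
    for _ in range(iters):
        sync()  # make sure earlier work has drained before starting the clock
        start = time.perf_counter()
        fn()
        sync()  # include launch overhead plus device execution
        samples.append((time.perf_counter() - start) * 1e6)
    return statistics.median(samples)


calls = []
latency_us = wallclock_time_us(lambda: calls.append(1), warmup=2, iters=5)
```

Timing this way captures Python dispatch and launch overhead on top of kernel execution, which is the quantity the wall-clock comparison is meant to expose; a kernel-only number still needs profiler or event timing.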

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py` around lines 141 -
167, Replace the local timing helpers _profiler_time_us and _wallclock_time_us
with the repository utility flashinfer.testing.bench_gpu_time(): call
bench_gpu_time(fn, ...) for the profiler-style measurement using its CUPTI
timing with auto-fallback to CUDA events, and call bench_gpu_time(fn, ...) in
the wall-clock/dispatch mode for Python dispatch overhead comparisons; update
any callers to use bench_gpu_time instead of these two functions so timing is
consistent with repo patterns.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ec502255-9de7-488a-b0d3-ef57a778ddbe

📥 Commits

Reviewing files that changed from the base of the PR and between ef40d21 and 38f3cc5.

📒 Files selected for processing (1)
  • benchmarks/bench_gdn_prefill_cutile_vs_fla_blackwell.py

