Skip to content

feat(gdn): separate input and output pool indices#2905

Merged
kahyunnam merged 1 commit into
flashinfer-ai:mainfrom
feldsherov:gdn-decode-separate-input-and-output-indices
Apr 17, 2026
Merged

feat(gdn): separate input and output pool indices#2905
kahyunnam merged 1 commit into
flashinfer-ai:mainfrom
feldsherov:gdn-decode-separate-input-and-output-indices

Conversation

@feldsherov

@feldsherov feldsherov commented Mar 28, 2026

Copy link
Copy Markdown
Contributor

📌 Description

Introduce separate output indices parameter for gated_delta_rule_decode_pretranspose.

This addresses decoded part of feature request in #2873

🔍 Related Issues

#2873

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

I've checked only tests/gdn/test_decode_delta_rule.py on H200. I need help with running whole testsuite.

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Optional control to write updated recurrent state into caller-specified output slots (separate from read/input slots). Enabled only in pool (initial-state) mode and requires initial-state info; validates indices shape and integer dtype. Preserves existing behavior when not used.
  • Tests

    • Added tests covering separate read/write state indexing and the case where output indices equal input indices, validating outputs and pool mutations.

@coderabbitai

coderabbitai Bot commented Mar 28, 2026

Copy link
Copy Markdown
Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds an optional output_state_indices argument across the gated-delta-rule decode path and kernels so updated recurrent states can be written to caller-specified pool slots different from the read indices. The argument is validated for pool mode and plumbed through BF16 MTP and pretranspose kernels; tests cover separate read/write indexing.

Changes

Cohort / File(s) Summary
Public API
flashinfer/gdn_decode.py
Added output_state_indices: Optional[torch.Tensor] = None to gated_delta_rule_decode_pretranspose(); validated for pool mode and integer shape [B].
BF16 MTP kernel + wrapper
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
Added h0_out_indices parameter to kernel and launcher; writeback now uses h0_out_indices (write mapping) while h0_indices remains read mapping. gated_delta_rule_mtp(...) accepts and forwards output_state_indices.
Pretranspose kernels + launcher
flashinfer/gdn_kernels/gdn_decode_pretranspose.py
Threaded h0_out_indices through small/big pretranspose kernels and launchers; kernels compute separate read (h0_indices) and write (h0_out_indices) pool indices. run_pretranspose_decode() accepts output_state_indices.
Tests
tests/gdn/test_decode_delta_rule.py
Added SM90+ tests test_output_state_indices and test_output_state_indices_same_as_input to validate distinct read/write behavior and equivalence when output equals input.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant API as gated_delta_rule_decode_pretranspose
    participant Runner as run_pretranspose_decode / gated_delta_rule_mtp
    participant Kernel as CUDA Kernel (pretranspose / bf16 MTP)
    participant Pool as State Pool

    User->>API: call with initial_state, initial_state_indices (read), output_state_indices (write)
    API->>API: validate output_state_indices (pool-mode, shape [B], int dtype)
    API->>Runner: forward tensors and indices
    Runner->>Kernel: launch with h0_indices (read) and h0_out_indices (write)
    Kernel->>Pool: read state from Pool[h0_indices[b]]
    Kernel->>Kernel: compute gated-delta updates
    Kernel->>Pool: write updated state to Pool[h0_out_indices[b]]
    Pool-->>User: outputs and mutated pool
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • bkryu
  • yongwww
  • kahyunnam
  • saltyminty

Poem

🐰 I hop and I map each index with care,
Read from one burrow, write to another fair.
Pools split their roads, no more accidental blends,
Small rabbit hops make sure each state finds friends. 🥕

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: introducing separate output indices for pool-based state management in gated_delta_rule_decode_pretranspose.
Description check ✅ Passed The description provides the feature purpose and related issue, though pre-commit check status and full test suite status are incomplete as noted by the author.
Docstring Coverage ✅ Passed Docstring coverage is 90.91% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the output_state_indices parameter to the Gated Delta Rule decode kernels, enabling the updated state to be written to a different pool slot than the one read from. The changes include updates to the high-level Python API, the underlying CUDA kernels in gdn_decode_bf16_state.py and gdn_decode_pretranspose.py, and the addition of verification tests. Review feedback recommends grouping index reads within the kernels to improve consistency and instruction-level parallelism, as well as simplifying redundant logic in the pretranspose runner.

Comment on lines +754 to +759
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
write_pool_batch_idx = cutlass.Int32(0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for handling negative indices is duplicated in gated_delta_rule_decode_kernel_seqlen234_unified and gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk. To improve consistency and potentially instruction-level parallelism, consider grouping the reads together before the checks, as done in gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk.

Additionally, to reduce code duplication across kernels, you could introduce a cute.jit helper function at the module level to handle this pattern.

Suggested change
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
write_pool_batch_idx = cutlass.Int32(0)
pool_batch_idx = gH_slot_indices[batch_idx]
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if pool_batch_idx < 0:
pool_batch_idx = cutlass.Int32(0)
if write_pool_batch_idx < 0:
write_pool_batch_idx = cutlass.Int32(0)

Comment on lines +976 to +979
if use_pool_indexing and output_state_indices is not None:
h0_out_indices = output_state_indices.to(torch.int32)
else:
h0_out_indices = h0_indices

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use_pool_indexing check here is redundant. The public API gated_delta_rule_decode_pretranspose already asserts that output_state_indices can only be provided when use_pool_indexing is true.

You can simplify this logic for better readability.

Suggested change
if use_pool_indexing and output_state_indices is not None:
h0_out_indices = output_state_indices.to(torch.int32)
else:
h0_out_indices = h0_indices
if output_state_indices is not None:
h0_out_indices = output_state_indices.to(torch.int32)
else:
h0_out_indices = h0_indices

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 195-206: The output_state_indices path currently allows in-place
remaps that can alias other batch items' source slots, making final state
CTA-order dependent; in the block that checks output_state_indices (and uses
use_pool and initial_state / initial_state_indices), validate that
output_state_indices contains no duplicate targets and that none of its target
indices overlap any indices in initial_state_indices (or raise a clear error);
alternatively implement a staged fallback: allocate a temporary buffer, gather
sources into temp using initial_state_indices, perform compute, then scatter
results from temp to initial_state using output_state_indices to avoid
read/write races. Ensure checks/reference to output_state_indices,
initial_state, initial_state_indices and use_pool are used so the change locates
the remap logic.
- Around line 195-206: The code currently only checks shape/dtype of
output_state_indices; add validation that output_state_indices is on the same
device as the pool (reject CPU/non-local tensors) and that all values are within
[0, pool_size-1] to prevent out-of-bounds or aliasing when writing into the pool
(when use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1756-1767: The tests currently use torch.testing.assert_close to
check that pool_under_test[read_indices] and pool_under_test[~used_mask] match
pool_orig with nonzero atol/rtol; change these to exact-equality checks (e.g.,
use torch.equal or torch.testing.assert_close(..., atol=0, rtol=0)) for the two
assertions involving pool_under_test, pool_orig, read_indices and the computed
used_mask/write_indices so any stray mutation is caught.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 748e9c86-0f49-4322-be15-919a3f8c3f91

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and 93687a1a152e3e533dad263eb24c82bb7e990619.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py

Comment thread flashinfer/gdn_decode.py
Comment on lines +195 to +206
if output_state_indices is not None:
assert use_pool, (
"output_state_indices can only be used with initial_state (pool mode)"
)
assert output_state_indices.shape == (B,), (
f"Expected output_state_indices shape [{B}], "
f"got {output_state_indices.shape}"
)
assert output_state_indices.dtype in (torch.int32, torch.int64), (
f"output_state_indices must be int32 or int64, "
f"got {output_state_indices.dtype}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Reject in-place remaps that alias another batch item's source slot.

output_state_indices still writes back into the same initial_state buffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading via initial_state_indices, the final state becomes CTA-order dependent and no longer matches gather→compute→scatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The output_state_indices
path currently allows in-place remaps that can alias other batch items' source
slots, making final state CTA-order dependent; in the block that checks
output_state_indices (and uses use_pool and initial_state /
initial_state_indices), validate that output_state_indices contains no duplicate
targets and that none of its target indices overlap any indices in
initial_state_indices (or raise a clear error); alternatively implement a staged
fallback: allocate a temporary buffer, gather sources into temp using
initial_state_indices, perform compute, then scatter results from temp to
initial_state using output_state_indices to avoid read/write races. Ensure
checks/reference to output_state_indices, initial_state, initial_state_indices
and use_pool are used so the change locates the remap logic.

⚠️ Potential issue | 🔴 Critical

Validate output_state_indices against the pool before dispatch.

The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or >= pool_size write index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot 0 on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.

💡 Suggested guard
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
         assert output_state_indices.shape == (B,), (
             f"Expected output_state_indices shape [{B}], "
             f"got {output_state_indices.shape}"
         )
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == initial_state.device, (
+            "output_state_indices must be on the same device as initial_state"
+        )
+        pool_size = int(initial_state.shape[0])
+        in_range = (output_state_indices >= 0) & (output_state_indices < pool_size)
+        assert in_range.all().item(), (
+            f"output_state_indices must be in [0, {pool_size})"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The code currently only
checks shape/dtype of output_state_indices; add validation that
output_state_indices is on the same device as the pool (reject CPU/non-local
tensors) and that all values are within [0, pool_size-1] to prevent
out-of-bounds or aliasing when writing into the pool (when
use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.

Comment on lines +1756 to +1767
# Read slots must be unchanged (we wrote to different slots)
torch.testing.assert_close(
pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
)

# Other slots must be unchanged
used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
used_mask[read_indices] = True
used_mask[write_indices] = True
torch.testing.assert_close(
pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Use exact equality for slots that must stay untouched.

These assertions are checking for no mutation, not numerical closeness. Keeping atol/rtol=1e-3 can hide a small stray write, so the read slots and the untouched remainder should be compared with zero tolerance.

💡 Tighten the unchanged-slot assertions
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )

     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1756 - 1767, The tests
currently use torch.testing.assert_close to check that
pool_under_test[read_indices] and pool_under_test[~used_mask] match pool_orig
with nonzero atol/rtol; change these to exact-equality checks (e.g., use
torch.equal or torch.testing.assert_close(..., atol=0, rtol=0)) for the two
assertions involving pool_under_test, pool_orig, read_indices and the computed
used_mask/write_indices so any stray mutation is caught.

@saltyminty

saltyminty commented Apr 1, 2026

Copy link
Copy Markdown
Collaborator

Approved conditional on CI.

Edit: though it seems I don't have write access so will need another reviewer to take a look

@saltyminty

Copy link
Copy Markdown
Collaborator

/bot run

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

GitLab MR !488 has been created, and the CI pipeline #47476702 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #47476702: 7/20 passed

@feldsherov

feldsherov commented Apr 2, 2026

Copy link
Copy Markdown
Contributor Author

@saltyminty thank you for the review!

I see flashinfer-bot is reporting failed CI pipeline. I am happy to address any issues, but I don't have access to the CI results.
Can you help me in any way?

@feldsherov

Copy link
Copy Markdown
Contributor Author

@kahyunnam thank you for the review!

@kahyunnam @saltyminty what should I do to land this PR?

@feldsherov feldsherov force-pushed the gdn-decode-separate-input-and-output-indices branch from 93687a1 to dc5af70 Compare April 6, 2026 14:26
@feldsherov

Copy link
Copy Markdown
Contributor Author

In the meanwhile, I rebased to the latest main and adjusted gdn_decode_bf16state_mtp_kernel to support the change.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: When defaulting output_state_indices (when
output_state_indices is None), preserve padding/null-buffer semantics by cloning
initial_state_indices but mapping padded markers (-1) back to the fallback write
slot (e.g., 0) before use; specifically, in the block handling
output_state_indices, set output_state_indices = initial_state_indices.clone(),
then replace any entries equal to -1 with 0, and finally ensure dtype is
torch.int32. This keeps the kernel's h0_out_indices behavior correct (padded
reads won't write to -1 locations) while keeping the int32 conversion logic.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 63235c41-a9c7-4f65-a721-9399edaa4c08

📥 Commits

Reviewing files that changed from the base of the PR and between 93687a1a152e3e533dad263eb24c82bb7e990619 and dc5af7063167c46a3114b03b72fe6c3e7ef45723.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py

Comment on lines +2549 to +2553
# Default output indices to read indices
if output_state_indices is None:
output_state_indices = initial_state_indices
elif output_state_indices.dtype != torch.int32:
output_state_indices = output_state_indices.to(torch.int32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Preserve padding/null-buffer semantics when defaulting output_state_indices.

This regresses the existing BF16 negative-index path: padded reads still come in as initial_state_indices == -1, but None now copies that -1 straight onto the write side. The kernel uses h0_out_indices for final writeback, so padded rows now write before h0_source instead of falling back to slot 0.

🐛 Minimal fix
-    if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
-        output_state_indices = output_state_indices.to(torch.int32)
+    if output_state_indices is None:
+        # Preserve the existing slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp_min(0)
+    if output_state_indices.dtype != torch.int32:
+        output_state_indices = output_state_indices.to(torch.int32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices (when output_state_indices is None),
preserve padding/null-buffer semantics by cloning initial_state_indices but
mapping padded markers (-1) back to the fallback write slot (e.g., 0) before
use; specifically, in the block handling output_state_indices, set
output_state_indices = initial_state_indices.clone(), then replace any entries
equal to -1 with 0, and finally ensure dtype is torch.int32. This keeps the
kernel's h0_out_indices behavior correct (padded reads won't write to -1
locations) while keeping the int32 conversion logic.

@saltyminty saltyminty force-pushed the gdn-decode-separate-input-and-output-indices branch from dc5af70 to c8d21f0 Compare April 8, 2026 17:24

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (4)
tests/gdn/test_decode_delta_rule.py (1)

2182-2193: ⚠️ Potential issue | 🟡 Minor

Use exact equality for untouched slots.

These assertions are checking for no mutation, not closeness. Keeping atol/rtol=1e-3 can hide a small stray write in the read slots or the untouched remainder.

Tighten the unchanged-slot checks
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )
 
     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 2182 - 2193, Change the
"unchanged slots" checks to require exact equality instead of approximate
closeness: for the read-only check comparing pool_under_test[read_indices] and
pool_orig[read_indices] and for the unused-slot check comparing
pool_under_test[~used_mask] and pool_orig[~used_mask], replace the current
torch.testing.assert_close usage (which uses atol/rtol) with an exact-equality
assertion (e.g., use torch.testing.assert_close with rtol=0 and atol=0 or
torch.testing.assert_equal / torch.equal) so any stray writes to pool_under_test
(referenced by pool_under_test, pool_orig, read_indices, write_indices,
used_mask) will fail the test.
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)

2549-2553: ⚠️ Potential issue | 🔴 Critical

Preserve null-buffer semantics when defaulting write indices.

When output_state_indices is omitted, this copies -1 padding markers from initial_state_indices onto the write side. The kernel clamps cache_idx, but final writeback still uses write_cache_idx, so padded rows can store before the pool instead of slot 0.

Minimal fix
-    if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
-        output_state_indices = output_state_indices.to(torch.int32)
+    if output_state_indices is None:
+        # Preserve the slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp_min(0)
+    if output_state_indices.dtype != torch.int32:
+        output_state_indices = output_state_indices.to(torch.int32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
The bug is that defaulting output_state_indices to initial_state_indices copies
-1 padding markers to the write side, allowing padded rows to be written;
instead, when output_state_indices is None set it to a new int32 tensor of the
same shape filled with -1 (preserving null-buffer/write-disabled semantics)
rather than aliasing initial_state_indices; ensure subsequent dtype logic still
converts tensors to torch.int32 (use torch.full_like(initial_state_indices, -1,
dtype=torch.int32) or torch.full(initial_state_indices.shape, -1,
dtype=torch.int32) for the assignment to output_state_indices).
flashinfer/gdn_decode.py (2)

199-210: ⚠️ Potential issue | 🔴 Critical

Reject unsafe pool remaps.

output_state_indices still allows duplicate destinations and cross-batch read/write overlap. In the in-place pool path that makes the final state CTA-order dependent instead of equivalent to gather→compute→scatter. Please reject those mappings here, or route remaps through a staged buffer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 199 - 210, The current
output_state_indices validation allows duplicate targets and cross-batch
overlaps which makes the in-place pool path CTA-order dependent; update the
check in the block that currently validates output_state_indices (the code
around the output_state_indices assertions in gdn_decode.py) to reject any
non-permutation mapping: assert that all values are in range [0, B) and that
torch.unique(output_state_indices).numel() == B (i.e., no duplicates and full
bijection), and raise a clear assertion/error message like "output_state_indices
must be a permutation (no duplicates or cross-batch overlaps) when using pool
mode"; alternatively if you prefer to support non-permutations implement a
staged-buffer path (copy to a temporary buffer then scatter) and route
non-permutation remaps through that path instead of the in-place pool path.

199-210: ⚠️ Potential issue | 🔴 Critical

Fail fast on invalid destination slots.

This still only checks shape/dtype. A CPU tensor, a negative destination, or an index >= pool_size can reach the kernels and turn into an invalid state write.

Suggested guard
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
         assert output_state_indices.shape == (B,), (
             f"Expected output_state_indices shape [{B}], "
             f"got {output_state_indices.shape}"
         )
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == initial_state.device, (
+            "output_state_indices must be on the same device as initial_state"
+        )
+        pool_size = int(initial_state.shape[0])
+        in_range = (output_state_indices >= 0) & (
+            output_state_indices < pool_size
+        )
+        assert torch.all(in_range).item(), (
+            f"output_state_indices must be in [0, {pool_size})"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 199 - 210, Add strict runtime guards
for output_state_indices: when output_state_indices is not None (and use_pool is
true), assert it is on the same device as the pool/initial_state (or at least a
CUDA device, not CPU), and assert all indices are within [0, pool_size-1] (no
negatives and none >= pool_size) before passing to kernels; use tensor
operations like output_state_indices.min() and output_state_indices.max() (or
torch.any checks) to detect out-of-range values and raise clear AssertionError
messages referencing output_state_indices, use_pool, and pool_size so invalid
destination slots cannot reach the GPU kernels.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 199-210: The current output_state_indices validation allows
duplicate targets and cross-batch overlaps which makes the in-place pool path
CTA-order dependent; update the check in the block that currently validates
output_state_indices (the code around the output_state_indices assertions in
gdn_decode.py) to reject any non-permutation mapping: assert that all values are
in range [0, B) and that torch.unique(output_state_indices).numel() == B (i.e.,
no duplicates and full bijection), and raise a clear assertion/error message
like "output_state_indices must be a permutation (no duplicates or cross-batch
overlaps) when using pool mode"; alternatively if you prefer to support
non-permutations implement a staged-buffer path (copy to a temporary buffer then
scatter) and route non-permutation remaps through that path instead of the
in-place pool path.
- Around line 199-210: Add strict runtime guards for output_state_indices: when
output_state_indices is not None (and use_pool is true), assert it is on the
same device as the pool/initial_state (or at least a CUDA device, not CPU), and
assert all indices are within [0, pool_size-1] (no negatives and none >=
pool_size) before passing to kernels; use tensor operations like
output_state_indices.min() and output_state_indices.max() (or torch.any checks)
to detect out-of-range values and raise clear AssertionError messages
referencing output_state_indices, use_pool, and pool_size so invalid destination
slots cannot reach the GPU kernels.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: The bug is that defaulting output_state_indices to
initial_state_indices copies -1 padding markers to the write side, allowing
padded rows to be written; instead, when output_state_indices is None set it to
a new int32 tensor of the same shape filled with -1 (preserving
null-buffer/write-disabled semantics) rather than aliasing
initial_state_indices; ensure subsequent dtype logic still converts tensors to
torch.int32 (use torch.full_like(initial_state_indices, -1, dtype=torch.int32)
or torch.full(initial_state_indices.shape, -1, dtype=torch.int32) for the
assignment to output_state_indices).

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 2182-2193: Change the "unchanged slots" checks to require exact
equality instead of approximate closeness: for the read-only check comparing
pool_under_test[read_indices] and pool_orig[read_indices] and for the
unused-slot check comparing pool_under_test[~used_mask] and
pool_orig[~used_mask], replace the current torch.testing.assert_close usage
(which uses atol/rtol) with an exact-equality assertion (e.g., use
torch.testing.assert_close with rtol=0 and atol=0 or torch.testing.assert_equal
/ torch.equal) so any stray writes to pool_under_test (referenced by
pool_under_test, pool_orig, read_indices, write_indices, used_mask) will fail
the test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cd63ccca-d797-4d84-ac61-caebd3460c40

📥 Commits

Reviewing files that changed from the base of the PR and between dc5af7063167c46a3114b03b72fe6c3e7ef45723 and c8d21f0d478a94ff703be8e9ce3eebb40f005823.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py

@saltyminty saltyminty force-pushed the gdn-decode-separate-input-and-output-indices branch from c8d21f0 to 58cc26d Compare April 8, 2026 23:57

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
tests/gdn/test_decode_delta_rule.py (1)

2182-2193: ⚠️ Potential issue | 🟡 Minor

Use exact equality for slots that must remain untouched.

These assertions verify that unmodified pool slots remain unchanged. Using atol=1e-3, rtol=1e-3 can mask small stray writes. For immutability checks, use zero tolerance.

💡 Tighten the unchanged-slot assertions
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )

     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 2182 - 2193, The assertions
that verify unmodified slots use non-zero tolerances which can hide stray
writes; change the checks on pool_under_test vs pool_orig for read_indices and
~used_mask to use exact equality (e.g., torch.equal or
torch.testing.assert_close with atol=0, rtol=0) so that
pool_under_test[read_indices] and pool_under_test[~used_mask] must match
pool_orig exactly; update the two assert calls referencing pool_under_test,
pool_orig, read_indices, write_indices, and used_mask accordingly.
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)

2549-2553: ⚠️ Potential issue | 🔴 Critical

Preserve padding/null-buffer semantics when defaulting output_state_indices.

When initial_state_indices contains -1 (padding markers), directly assigning it to output_state_indices will cause the kernel to compute flat_write_idx = -1 * HV + i_hv (a negative offset) and write to invalid memory at lines 1981-2012. The read path has protection (lines 1227-1229 clamp negative cache_idx to 0), but the write path has no equivalent guard.

🐛 Proposed fix: clamp negative indices to slot 0 for writes
     # Default output indices to read indices
     if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
+        # Preserve the existing slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp(min=0)
+    if output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices to initial_state_indices, avoid assigning
the same tensor with -1 padding because the write path computes flat_write_idx
and will write to negative offsets; instead, in the branch where
output_state_indices is None, create a clone of initial_state_indices, replace
negative values (e.g. -1) with 0 to preserve the null-slot semantics for writes,
and then ensure the tensor is converted to torch.int32 (matching the existing
dtype-check branch). Update the code that sets output_state_indices so it uses
output_state_indices = initial_state_indices.clone();
output_state_indices[output_state_indices < 0] = 0; output_state_indices =
output_state_indices.to(torch.int32) (or equivalent) so flat_write_idx cannot be
negative when used with HV and i_hv.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Line 2593: The assignment to h0_out_idx_ calling from_dlpack is misformatted;
reformat that line to satisfy ruff (apply ruff format or adjust
spacing/punctuation) so it matches the project's formatting rules (e.g., proper
spacing around the = and within the function call) in the h0_out_idx_ =
from_dlpack(...) statement; keep the same variable name h0_out_idx_ and function
call from_dlpack with arguments output_state_indices, assumed_align=32,
enable_tvm_ffi=True.

---

Duplicate comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: When defaulting output_state_indices to
initial_state_indices, avoid assigning the same tensor with -1 padding because
the write path computes flat_write_idx and will write to negative offsets;
instead, in the branch where output_state_indices is None, create a clone of
initial_state_indices, replace negative values (e.g. -1) with 0 to preserve the
null-slot semantics for writes, and then ensure the tensor is converted to
torch.int32 (matching the existing dtype-check branch). Update the code that
sets output_state_indices so it uses output_state_indices =
initial_state_indices.clone(); output_state_indices[output_state_indices < 0] =
0; output_state_indices = output_state_indices.to(torch.int32) (or equivalent)
so flat_write_idx cannot be negative when used with HV and i_hv.

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 2182-2193: The assertions that verify unmodified slots use
non-zero tolerances which can hide stray writes; change the checks on
pool_under_test vs pool_orig for read_indices and ~used_mask to use exact
equality (e.g., torch.equal or torch.testing.assert_close with atol=0, rtol=0)
so that pool_under_test[read_indices] and pool_under_test[~used_mask] must match
pool_orig exactly; update the two assert calls referencing pool_under_test,
pool_orig, read_indices, write_indices, and used_mask accordingly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3392082e-e475-483c-bc9d-773a2995d6ff

📥 Commits

Reviewing files that changed from the base of the PR and between c8d21f0d478a94ff703be8e9ce3eebb40f005823 and 58cc26d2001e4a92f2e1b73b94df9a46386a8082.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/gdn_decode.py

@@ -2552,6 +2590,7 @@ def gated_delta_rule_mtp(
dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix formatting to pass pre-commit checks.

The pipeline failure indicates this line needs reformatting per ruff format.

🔧 Apply ruff formatting
-    h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
+    h0_out_idx_ = from_dlpack(
+        output_state_indices, assumed_align=32, enable_tvm_ffi=True
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(
output_state_indices, assumed_align=32, enable_tvm_ffi=True
)
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 2590-2593: pre-commit failed: ruff format (hook id: ruff-format) reformatted files. Diff shows formatting change in gated_delta_rule_mtp() for h0_out_idx_ = from_dlpack(output_state_indices, ...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` at line 2593, The assignment
to h0_out_idx_ calling from_dlpack is misformatted; reformat that line to
satisfy ruff (apply ruff format or adjust spacing/punctuation) so it matches the
project's formatting rules (e.g., proper spacing around the = and within the
function call) in the h0_out_idx_ = from_dlpack(...) statement; keep the same
variable name h0_out_idx_ and function call from_dlpack with arguments
output_state_indices, assumed_align=32, enable_tvm_ffi=True.

@feldsherov

Copy link
Copy Markdown
Contributor Author

@saltyminty should I do anything here?

@kahyunnam

Copy link
Copy Markdown
Member

@feldsherov this is failing pre-commit tests: https://github.com/flashinfer-ai/flashinfer/actions/runs/24164935099/job/70524339364?pr=2905

Can you please rerun pre-commit and push?

@feldsherov feldsherov force-pushed the gdn-decode-separate-input-and-output-indices branch from 58cc26d to b7986c0 Compare April 16, 2026 10:20
@feldsherov

Copy link
Copy Markdown
Contributor Author

@saltyminty updated the PR, should be green now.

@feldsherov feldsherov force-pushed the gdn-decode-separate-input-and-output-indices branch from b7986c0 to 54be68d Compare April 16, 2026 15:39
@kahyunnam kahyunnam merged commit 24f2032 into flashinfer-ai:main Apr 17, 2026
29 of 37 checks passed
ziang-and pushed a commit to zianglih/flashinfer that referenced this pull request Apr 17, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Introduce separate output indices parameter for
gated_delta_rule_decode_pretranspose.

This addresses decoded part of feature request in flashinfer-ai#2873 

## 🔍 Related Issues

flashinfer-ai#2873 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.). 

I've checked only tests/gdn/test_decode_delta_rule.py on H200. I need
help with running whole testsuite.

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Optional control to write updated recurrent state into
caller-specified output slots (separate from read/input slots). Enabled
only in pool (initial-state) mode and requires initial-state info;
validates indices shape and integer dtype. Preserves existing behavior
when not used.

* **Tests**
* Added tests covering separate read/write state indexing and the case
where output indices equal input indices, validating outputs and pool
mutations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
kahyunnam pushed a commit that referenced this pull request May 6, 2026
<!-- .github/pull_request_template.md -->

  ## Summary

  Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the
  `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new
  **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a
  small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code,
  adds split-pool support (#2905-compatible) to both surviving BF16
  kernels, and ships the OOB fix mirroring upstream PR #3145 — for the
  BF16 kernels that survived the cleanup.

  **Supersedes #3118.** That PR's perf delta (T=1 per-call overhead +
  pool+padding for the ILP kernel) is the first commit on this branch
  (`8a6e9819`). 
  
  ## What changes

  - **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups
    × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per
    thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers
    small/medium/large `B*HV` work-unit sizes uniformly.
  - **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches
    the production serving contract). Wrapper
    `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool
    callers internally — public API unchanged.
  - **Split-pool support** (PR #2905 contract): both surviving BF16
kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`)
    natively support `output_state_indices != initial_state_indices`,
    with bit-equivalent single-pool behavior selected at compile time
    via `Constexpr[bool] same_pool` for zero-overhead dispatch.
  - **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed
    by the per-call batch index `i_n` (not the pool-scoped `cache_idx`),
    so the buffer can be sized `[B, T, HV, V, K]` as production callers
    expect. Regression test catches the bug; pre-fix triggers
    `cudaErrorIllegalAddress` in <2 s.

  ## Removed (~1900 LOC of dead code)

  | Kernel | Why removed |
  |---|---|
| `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by
wide_vec + ILP=4 MTP; had known correctness issues at small batch |
| `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32
with B≥16 — not a Qwen3.5 shape; MTP path covers it |
| `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After wide_vec
extension to split-pool + tile_v=32, mtp_kernel was unreachable |

  End-state BF16 surface = **2 `@cute.kernel`s in one file**:
  - `gdn_wide_vec_kernel` — production hot path
  - `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback

  Both pool-only, both split-pool capable, both indexed batch-scoped.

  ## Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured
  on this same branch by monkey-patching `_select_wide_vec_tile_v` to
  return `None` for every shape). Same harness, same hardware, same
  config — so the comparison isolates the kernel-level speedup that
  wide_vec + the cleanup deliver.

  Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50,
  T=1 invoked with `--update-state`, T≥2 invoked with
  `--cache-intermediate-states`. Kernel time in microseconds (CUPTI).

  ### Speedup (×, baseline / post-PR)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× |
| 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× |
| 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× |
| 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× |
| 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× |
| 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× |
| 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× |

  ### Time reduction (%)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |

|-----|-------|--------|-------|--------|--------|--------|--------|--------|
| 1 | +2.8% | +3.5% | +3.2% | +3.3% | +0.1% | +1.7% | +1.8% | +0.7% |
| 4 | −3.2% | +18.7% | +9.5% | +10.7% | +9.7% | +10.2% | +12.3% | +12.5%
|
| 8 | +7.8% | +10.1% | +10.0%| +10.7% | +11.8% | +12.7% | +12.8% |
+12.4% |
| 16 | +3.4% | +8.5% | +10.0%| +11.2% | +11.4% | +10.7% | +10.3% | +9.4%
|
| 32 | +6.0% | +10.4% | +9.4% | +9.6% | +8.8% | +8.4% | +8.4% | +7.8% |
| 64 | +4.0% | +10.3% | +7.5% | +8.6% | +5.9% | +6.3% | +7.2% | +5.9% |
| 128 | +4.2% | +9.8% | +6.5% | +8.4% | +6.0% | +5.9% | +6.3% | +6.6% |
| 256 | +4.2% | +9.6% | +6.6% | +8.4% | +6.3% | +5.9% | +6.5% | +6.6% |

  ### Headline

- **T=1 production decode (B≥16)**: 4–6 % time reduction across the full
batch sweep — the Qwen3.5 hot path.
- **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape.
Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
- **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't
DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was
already efficient there).

  ### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

  |  B  | T=1  | T=2  | T=3  | T=4  | T=5  | T=6  | T=7  | T=8  |
  |-----|------|------|------|------|------|------|------|------|
  |   1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 |
  |   4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 |
  |   8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 |
  |  16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 |
  |  32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 |
  |  64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 |
  | 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 |
  | 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 |

Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256
production decode shape).

  ### Split-pool

  With wide_vec now supporting split-pool natively, split-pool
  matches single-pool to within ±1 % at every measured shape.

  ## Tests

  > **513 passed, 0 failed in 18m18s** on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP
tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1,
4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, 12 wrapper-level split-pool
tests.

  ## Files changed (4)

- `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool →
pool
- `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined;
dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
- `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression
tests
  - `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added `--pool-mode` option to benchmark tool for configuring state
pool allocation (`single` or `split` modes).

* **Tests**
* Expanded BF16 test coverage with regression tests for split-pool
semantics and out-of-bounds scenarios; improved batch-dimension handling
for intermediate-state comparisons.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai coderabbitai Bot mentioned this pull request May 8, 2026
5 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants