feat(gdn): separate input and output pool indices#2905
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds an optional Changes
Sequence DiagramsequenceDiagram
participant User as User Code
participant API as gated_delta_rule_decode_pretranspose
participant Runner as run_pretranspose_decode / gated_delta_rule_mtp
participant Kernel as CUDA Kernel (pretranspose / bf16 MTP)
participant Pool as State Pool
User->>API: call with initial_state, initial_state_indices (read), output_state_indices (write)
API->>API: validate output_state_indices (pool-mode, shape [B], int dtype)
API->>Runner: forward tensors and indices
Runner->>Kernel: launch with h0_indices (read) and h0_out_indices (write)
Kernel->>Pool: read state from Pool[h0_indices[b]]
Kernel->>Kernel: compute gated-delta updates
Kernel->>Pool: write updated state to Pool[h0_out_indices[b]]
Pool-->>User: outputs and mutated pool
Estimated code review effort🎯 4 (Complex) | ⏱️ ~40 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces the output_state_indices parameter to the Gated Delta Rule decode kernels, enabling the updated state to be written to a different pool slot than the one read from. The changes include updates to the high-level Python API, the underlying CUDA kernels in gdn_decode_bf16_state.py and gdn_decode_pretranspose.py, and the addition of verification tests. Review feedback recommends grouping index reads within the kernels to improve consistency and instruction-level parallelism, as well as simplifying redundant logic in the pretranspose runner.
| pool_batch_idx = gH_slot_indices[batch_idx] | ||
| if pool_batch_idx < 0: | ||
| pool_batch_idx = cutlass.Int32(0) | ||
| write_pool_batch_idx = gH_out_slot_indices[batch_idx] | ||
| if write_pool_batch_idx < 0: | ||
| write_pool_batch_idx = cutlass.Int32(0) |
There was a problem hiding this comment.
This logic for handling negative indices is duplicated in gated_delta_rule_decode_kernel_seqlen234_unified and gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk. To improve consistency and potentially instruction-level parallelism, consider grouping the reads together before the checks, as done in gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk.
Additionally, to reduce code duplication across kernels, you could introduce a cute.jit helper function at the module level to handle this pattern.
| pool_batch_idx = gH_slot_indices[batch_idx] | |
| if pool_batch_idx < 0: | |
| pool_batch_idx = cutlass.Int32(0) | |
| write_pool_batch_idx = gH_out_slot_indices[batch_idx] | |
| if write_pool_batch_idx < 0: | |
| write_pool_batch_idx = cutlass.Int32(0) | |
| pool_batch_idx = gH_slot_indices[batch_idx] | |
| write_pool_batch_idx = gH_out_slot_indices[batch_idx] | |
| if pool_batch_idx < 0: | |
| pool_batch_idx = cutlass.Int32(0) | |
| if write_pool_batch_idx < 0: | |
| write_pool_batch_idx = cutlass.Int32(0) |
| if use_pool_indexing and output_state_indices is not None: | ||
| h0_out_indices = output_state_indices.to(torch.int32) | ||
| else: | ||
| h0_out_indices = h0_indices |
There was a problem hiding this comment.
The use_pool_indexing check here is redundant. The public API gated_delta_rule_decode_pretranspose already asserts that output_state_indices can only be provided when use_pool_indexing is true.
You can simplify this logic for better readability.
| if use_pool_indexing and output_state_indices is not None: | |
| h0_out_indices = output_state_indices.to(torch.int32) | |
| else: | |
| h0_out_indices = h0_indices | |
| if output_state_indices is not None: | |
| h0_out_indices = output_state_indices.to(torch.int32) | |
| else: | |
| h0_out_indices = h0_indices |
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 195-206: The output_state_indices path currently allows in-place
remaps that can alias other batch items' source slots, making final state
CTA-order dependent; in the block that checks output_state_indices (and uses
use_pool and initial_state / initial_state_indices), validate that
output_state_indices contains no duplicate targets and that none of its target
indices overlap any indices in initial_state_indices (or raise a clear error);
alternatively implement a staged fallback: allocate a temporary buffer, gather
sources into temp using initial_state_indices, perform compute, then scatter
results from temp to initial_state using output_state_indices to avoid
read/write races. Ensure checks/reference to output_state_indices,
initial_state, initial_state_indices and use_pool are used so the change locates
the remap logic.
- Around line 195-206: The code currently only checks shape/dtype of
output_state_indices; add validation that output_state_indices is on the same
device as the pool (reject CPU/non-local tensors) and that all values are within
[0, pool_size-1] to prevent out-of-bounds or aliasing when writing into the pool
(when use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1756-1767: The tests currently use torch.testing.assert_close to
check that pool_under_test[read_indices] and pool_under_test[~used_mask] match
pool_orig with nonzero atol/rtol; change these to exact-equality checks (e.g.,
use torch.equal or torch.testing.assert_close(..., atol=0, rtol=0)) for the two
assertions involving pool_under_test, pool_orig, read_indices and the computed
used_mask/write_indices so any stray mutation is caught.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 748e9c86-0f49-4322-be15-919a3f8c3f91
📥 Commits
Reviewing files that changed from the base of the PR and between 904fa8c and 93687a1a152e3e533dad263eb24c82bb7e990619.
📒 Files selected for processing (4)
flashinfer/gdn_decode.pyflashinfer/gdn_kernels/gdn_decode_bf16_state.pyflashinfer/gdn_kernels/gdn_decode_pretranspose.pytests/gdn/test_decode_delta_rule.py
| if output_state_indices is not None: | ||
| assert use_pool, ( | ||
| "output_state_indices can only be used with initial_state (pool mode)" | ||
| ) | ||
| assert output_state_indices.shape == (B,), ( | ||
| f"Expected output_state_indices shape [{B}], " | ||
| f"got {output_state_indices.shape}" | ||
| ) | ||
| assert output_state_indices.dtype in (torch.int32, torch.int64), ( | ||
| f"output_state_indices must be int32 or int64, " | ||
| f"got {output_state_indices.dtype}" | ||
| ) |
There was a problem hiding this comment.
Reject in-place remaps that alias another batch item's source slot.
output_state_indices still writes back into the same initial_state buffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading via initial_state_indices, the final state becomes CTA-order dependent and no longer matches gather→compute→scatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_decode.py` around lines 195 - 206, The output_state_indices
path currently allows in-place remaps that can alias other batch items' source
slots, making final state CTA-order dependent; in the block that checks
output_state_indices (and uses use_pool and initial_state /
initial_state_indices), validate that output_state_indices contains no duplicate
targets and that none of its target indices overlap any indices in
initial_state_indices (or raise a clear error); alternatively implement a staged
fallback: allocate a temporary buffer, gather sources into temp using
initial_state_indices, perform compute, then scatter results from temp to
initial_state using output_state_indices to avoid read/write races. Ensure
checks/reference to output_state_indices, initial_state, initial_state_indices
and use_pool are used so the change locates the remap logic.
Validate output_state_indices against the pool before dispatch.
The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or >= pool_size write index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot 0 on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.
💡 Suggested guard
if output_state_indices is not None:
assert use_pool, (
"output_state_indices can only be used with initial_state (pool mode)"
)
assert output_state_indices.shape == (B,), (
f"Expected output_state_indices shape [{B}], "
f"got {output_state_indices.shape}"
)
assert output_state_indices.dtype in (torch.int32, torch.int64), (
f"output_state_indices must be int32 or int64, "
f"got {output_state_indices.dtype}"
)
+ assert output_state_indices.device == initial_state.device, (
+ "output_state_indices must be on the same device as initial_state"
+ )
+ pool_size = int(initial_state.shape[0])
+ in_range = (output_state_indices >= 0) & (output_state_indices < pool_size)
+ assert in_range.all().item(), (
+ f"output_state_indices must be in [0, {pool_size})"
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_decode.py` around lines 195 - 206, The code currently only
checks shape/dtype of output_state_indices; add validation that
output_state_indices is on the same device as the pool (reject CPU/non-local
tensors) and that all values are within [0, pool_size-1] to prevent
out-of-bounds or aliasing when writing into the pool (when
use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.
| # Read slots must be unchanged (we wrote to different slots) | ||
| torch.testing.assert_close( | ||
| pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol | ||
| ) | ||
|
|
||
| # Other slots must be unchanged | ||
| used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device) | ||
| used_mask[read_indices] = True | ||
| used_mask[write_indices] = True | ||
| torch.testing.assert_close( | ||
| pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol | ||
| ) |
There was a problem hiding this comment.
Use exact equality for slots that must stay untouched.
These assertions are checking for no mutation, not numerical closeness. Keeping atol/rtol=1e-3 can hide a small stray write, so the read slots and the untouched remainder should be compared with zero tolerance.
💡 Tighten the unchanged-slot assertions
# Read slots must be unchanged (we wrote to different slots)
torch.testing.assert_close(
- pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+ pool_under_test[read_indices],
+ pool_orig[read_indices],
+ atol=0.0,
+ rtol=0.0,
)
# Other slots must be unchanged
used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
used_mask[read_indices] = True
used_mask[write_indices] = True
torch.testing.assert_close(
- pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+ pool_under_test[~used_mask],
+ pool_orig[~used_mask],
+ atol=0.0,
+ rtol=0.0,
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gdn/test_decode_delta_rule.py` around lines 1756 - 1767, The tests
currently use torch.testing.assert_close to check that
pool_under_test[read_indices] and pool_under_test[~used_mask] match pool_orig
with nonzero atol/rtol; change these to exact-equality checks (e.g., use
torch.equal or torch.testing.assert_close(..., atol=0, rtol=0)) for the two
assertions involving pool_under_test, pool_orig, read_indices and the computed
used_mask/write_indices so any stray mutation is caught.
|
Approved conditional on CI. Edit: though it seems I don't have write access so will need another reviewer to take a look |
|
/bot run |
|
[FAILED] Pipeline #47476702: 7/20 passed |
|
@saltyminty thank you for the review! I see flashinfer-bot is reporting failed CI pipeline. I am happy to address any issues, but I don't have access to the CI results. |
|
@kahyunnam thank you for the review! @kahyunnam @saltyminty what should I do to land this PR? |
93687a1 to
dc5af70
Compare
|
In the meanwhile, I rebased to the latest main and adjusted gdn_decode_bf16state_mtp_kernel to support the change. |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: When defaulting output_state_indices (when
output_state_indices is None), preserve padding/null-buffer semantics by cloning
initial_state_indices but mapping padded markers (-1) back to the fallback write
slot (e.g., 0) before use; specifically, in the block handling
output_state_indices, set output_state_indices = initial_state_indices.clone(),
then replace any entries equal to -1 with 0, and finally ensure dtype is
torch.int32. This keeps the kernel's h0_out_indices behavior correct (padded
reads won't write to -1 locations) while keeping the int32 conversion logic.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 63235c41-a9c7-4f65-a721-9399edaa4c08
📥 Commits
Reviewing files that changed from the base of the PR and between 93687a1a152e3e533dad263eb24c82bb7e990619 and dc5af7063167c46a3114b03b72fe6c3e7ef45723.
📒 Files selected for processing (4)
flashinfer/gdn_decode.pyflashinfer/gdn_kernels/gdn_decode_bf16_state.pyflashinfer/gdn_kernels/gdn_decode_pretranspose.pytests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/gdn_decode_pretranspose.py
| # Default output indices to read indices | ||
| if output_state_indices is None: | ||
| output_state_indices = initial_state_indices | ||
| elif output_state_indices.dtype != torch.int32: | ||
| output_state_indices = output_state_indices.to(torch.int32) |
There was a problem hiding this comment.
Preserve padding/null-buffer semantics when defaulting output_state_indices.
This regresses the existing BF16 negative-index path: padded reads still come in as initial_state_indices == -1, but None now copies that -1 straight onto the write side. The kernel uses h0_out_indices for final writeback, so padded rows now write before h0_source instead of falling back to slot 0.
🐛 Minimal fix
- if output_state_indices is None:
- output_state_indices = initial_state_indices
- elif output_state_indices.dtype != torch.int32:
- output_state_indices = output_state_indices.to(torch.int32)
+ if output_state_indices is None:
+ # Preserve the existing slot-0 null-buffer behavior for padded rows.
+ output_state_indices = initial_state_indices.clamp_min(0)
+ if output_state_indices.dtype != torch.int32:
+ output_state_indices = output_state_indices.to(torch.int32)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices (when output_state_indices is None),
preserve padding/null-buffer semantics by cloning initial_state_indices but
mapping padded markers (-1) back to the fallback write slot (e.g., 0) before
use; specifically, in the block handling output_state_indices, set
output_state_indices = initial_state_indices.clone(), then replace any entries
equal to -1 with 0, and finally ensure dtype is torch.int32. This keeps the
kernel's h0_out_indices behavior correct (padded reads won't write to -1
locations) while keeping the int32 conversion logic.
dc5af70 to
c8d21f0
Compare
There was a problem hiding this comment.
♻️ Duplicate comments (4)
tests/gdn/test_decode_delta_rule.py (1)
2182-2193:⚠️ Potential issue | 🟡 MinorUse exact equality for untouched slots.
These assertions are checking for no mutation, not closeness. Keeping
atol/rtol=1e-3can hide a small stray write in the read slots or the untouched remainder.Tighten the unchanged-slot checks
# Read slots must be unchanged (we wrote to different slots) torch.testing.assert_close( - pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol + pool_under_test[read_indices], + pool_orig[read_indices], + atol=0.0, + rtol=0.0, ) # Other slots must be unchanged used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device) used_mask[read_indices] = True used_mask[write_indices] = True torch.testing.assert_close( - pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol + pool_under_test[~used_mask], + pool_orig[~used_mask], + atol=0.0, + rtol=0.0, )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 2182 - 2193, Change the "unchanged slots" checks to require exact equality instead of approximate closeness: for the read-only check comparing pool_under_test[read_indices] and pool_orig[read_indices] and for the unused-slot check comparing pool_under_test[~used_mask] and pool_orig[~used_mask], replace the current torch.testing.assert_close usage (which uses atol/rtol) with an exact-equality assertion (e.g., use torch.testing.assert_close with rtol=0 and atol=0 or torch.testing.assert_equal / torch.equal) so any stray writes to pool_under_test (referenced by pool_under_test, pool_orig, read_indices, write_indices, used_mask) will fail the test.flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)
2549-2553:⚠️ Potential issue | 🔴 CriticalPreserve null-buffer semantics when defaulting write indices.
When
output_state_indicesis omitted, this copies-1padding markers frominitial_state_indicesonto the write side. The kernel clampscache_idx, but final writeback still useswrite_cache_idx, so padded rows can store before the pool instead of slot 0.Minimal fix
- if output_state_indices is None: - output_state_indices = initial_state_indices - elif output_state_indices.dtype != torch.int32: - output_state_indices = output_state_indices.to(torch.int32) + if output_state_indices is None: + # Preserve the slot-0 null-buffer behavior for padded rows. + output_state_indices = initial_state_indices.clamp_min(0) + if output_state_indices.dtype != torch.int32: + output_state_indices = output_state_indices.to(torch.int32)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553, The bug is that defaulting output_state_indices to initial_state_indices copies -1 padding markers to the write side, allowing padded rows to be written; instead, when output_state_indices is None set it to a new int32 tensor of the same shape filled with -1 (preserving null-buffer/write-disabled semantics) rather than aliasing initial_state_indices; ensure subsequent dtype logic still converts tensors to torch.int32 (use torch.full_like(initial_state_indices, -1, dtype=torch.int32) or torch.full(initial_state_indices.shape, -1, dtype=torch.int32) for the assignment to output_state_indices).flashinfer/gdn_decode.py (2)
199-210:⚠️ Potential issue | 🔴 CriticalReject unsafe pool remaps.
output_state_indicesstill allows duplicate destinations and cross-batch read/write overlap. In the in-place pool path that makes the final state CTA-order dependent instead of equivalent to gather→compute→scatter. Please reject those mappings here, or route remaps through a staged buffer.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 199 - 210, The current output_state_indices validation allows duplicate targets and cross-batch overlaps which makes the in-place pool path CTA-order dependent; update the check in the block that currently validates output_state_indices (the code around the output_state_indices assertions in gdn_decode.py) to reject any non-permutation mapping: assert that all values are in range [0, B) and that torch.unique(output_state_indices).numel() == B (i.e., no duplicates and full bijection), and raise a clear assertion/error message like "output_state_indices must be a permutation (no duplicates or cross-batch overlaps) when using pool mode"; alternatively if you prefer to support non-permutations implement a staged-buffer path (copy to a temporary buffer then scatter) and route non-permutation remaps through that path instead of the in-place pool path.
199-210:⚠️ Potential issue | 🔴 CriticalFail fast on invalid destination slots.
This still only checks shape/dtype. A CPU tensor, a negative destination, or an index
>= pool_sizecan reach the kernels and turn into an invalid state write.Suggested guard
if output_state_indices is not None: assert use_pool, ( "output_state_indices can only be used with initial_state (pool mode)" ) assert output_state_indices.shape == (B,), ( f"Expected output_state_indices shape [{B}], " f"got {output_state_indices.shape}" ) assert output_state_indices.dtype in (torch.int32, torch.int64), ( f"output_state_indices must be int32 or int64, " f"got {output_state_indices.dtype}" ) + assert output_state_indices.device == initial_state.device, ( + "output_state_indices must be on the same device as initial_state" + ) + pool_size = int(initial_state.shape[0]) + in_range = (output_state_indices >= 0) & ( + output_state_indices < pool_size + ) + assert torch.all(in_range).item(), ( + f"output_state_indices must be in [0, {pool_size})" + )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 199 - 210, Add strict runtime guards for output_state_indices: when output_state_indices is not None (and use_pool is true), assert it is on the same device as the pool/initial_state (or at least a CUDA device, not CPU), and assert all indices are within [0, pool_size-1] (no negatives and none >= pool_size) before passing to kernels; use tensor operations like output_state_indices.min() and output_state_indices.max() (or torch.any checks) to detect out-of-range values and raise clear AssertionError messages referencing output_state_indices, use_pool, and pool_size so invalid destination slots cannot reach the GPU kernels.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 199-210: The current output_state_indices validation allows
duplicate targets and cross-batch overlaps which makes the in-place pool path
CTA-order dependent; update the check in the block that currently validates
output_state_indices (the code around the output_state_indices assertions in
gdn_decode.py) to reject any non-permutation mapping: assert that all values are
in range [0, B) and that torch.unique(output_state_indices).numel() == B (i.e.,
no duplicates and full bijection), and raise a clear assertion/error message
like "output_state_indices must be a permutation (no duplicates or cross-batch
overlaps) when using pool mode"; alternatively if you prefer to support
non-permutations implement a staged-buffer path (copy to a temporary buffer then
scatter) and route non-permutation remaps through that path instead of the
in-place pool path.
- Around line 199-210: Add strict runtime guards for output_state_indices: when
output_state_indices is not None (and use_pool is true), assert it is on the
same device as the pool/initial_state (or at least a CUDA device, not CPU), and
assert all indices are within [0, pool_size-1] (no negatives and none >=
pool_size) before passing to kernels; use tensor operations like
output_state_indices.min() and output_state_indices.max() (or torch.any checks)
to detect out-of-range values and raise clear AssertionError messages
referencing output_state_indices, use_pool, and pool_size so invalid destination
slots cannot reach the GPU kernels.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: The bug is that defaulting output_state_indices to
initial_state_indices copies -1 padding markers to the write side, allowing
padded rows to be written; instead, when output_state_indices is None set it to
a new int32 tensor of the same shape filled with -1 (preserving
null-buffer/write-disabled semantics) rather than aliasing
initial_state_indices; ensure subsequent dtype logic still converts tensors to
torch.int32 (use torch.full_like(initial_state_indices, -1, dtype=torch.int32)
or torch.full(initial_state_indices.shape, -1, dtype=torch.int32) for the
assignment to output_state_indices).
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 2182-2193: Change the "unchanged slots" checks to require exact
equality instead of approximate closeness: for the read-only check comparing
pool_under_test[read_indices] and pool_orig[read_indices] and for the
unused-slot check comparing pool_under_test[~used_mask] and
pool_orig[~used_mask], replace the current torch.testing.assert_close usage
(which uses atol/rtol) with an exact-equality assertion (e.g., use
torch.testing.assert_close with rtol=0 and atol=0 or torch.testing.assert_equal
/ torch.equal) so any stray writes to pool_under_test (referenced by
pool_under_test, pool_orig, read_indices, write_indices, used_mask) will fail
the test.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cd63ccca-d797-4d84-ac61-caebd3460c40
📥 Commits
Reviewing files that changed from the base of the PR and between dc5af7063167c46a3114b03b72fe6c3e7ef45723 and c8d21f0d478a94ff703be8e9ce3eebb40f005823.
📒 Files selected for processing (4)
flashinfer/gdn_decode.pyflashinfer/gdn_kernels/gdn_decode_bf16_state.pyflashinfer/gdn_kernels/gdn_decode_pretranspose.pytests/gdn/test_decode_delta_rule.py
c8d21f0 to
58cc26d
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tests/gdn/test_decode_delta_rule.py (1)
2182-2193:⚠️ Potential issue | 🟡 MinorUse exact equality for slots that must remain untouched.
These assertions verify that unmodified pool slots remain unchanged. Using
atol=1e-3, rtol=1e-3can mask small stray writes. For immutability checks, use zero tolerance.💡 Tighten the unchanged-slot assertions
# Read slots must be unchanged (we wrote to different slots) torch.testing.assert_close( - pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol + pool_under_test[read_indices], + pool_orig[read_indices], + atol=0.0, + rtol=0.0, ) # Other slots must be unchanged used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device) used_mask[read_indices] = True used_mask[write_indices] = True torch.testing.assert_close( - pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol + pool_under_test[~used_mask], + pool_orig[~used_mask], + atol=0.0, + rtol=0.0, )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 2182 - 2193, The assertions that verify unmodified slots use non-zero tolerances which can hide stray writes; change the checks on pool_under_test vs pool_orig for read_indices and ~used_mask to use exact equality (e.g., torch.equal or torch.testing.assert_close with atol=0, rtol=0) so that pool_under_test[read_indices] and pool_under_test[~used_mask] must match pool_orig exactly; update the two assert calls referencing pool_under_test, pool_orig, read_indices, write_indices, and used_mask accordingly.flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)
2549-2553:⚠️ Potential issue | 🔴 CriticalPreserve padding/null-buffer semantics when defaulting
output_state_indices.When
initial_state_indicescontains-1(padding markers), directly assigning it tooutput_state_indiceswill cause the kernel to computeflat_write_idx = -1 * HV + i_hv(a negative offset) and write to invalid memory at lines 1981-2012. The read path has protection (lines 1227-1229 clamp negativecache_idxto 0), but the write path has no equivalent guard.🐛 Proposed fix: clamp negative indices to slot 0 for writes
# Default output indices to read indices if output_state_indices is None: - output_state_indices = initial_state_indices - elif output_state_indices.dtype != torch.int32: + # Preserve the existing slot-0 null-buffer behavior for padded rows. + output_state_indices = initial_state_indices.clamp(min=0) + if output_state_indices.dtype != torch.int32: output_state_indices = output_state_indices.to(torch.int32)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553, When defaulting output_state_indices to initial_state_indices, avoid assigning the same tensor with -1 padding because the write path computes flat_write_idx and will write to negative offsets; instead, in the branch where output_state_indices is None, create a clone of initial_state_indices, replace negative values (e.g. -1) with 0 to preserve the null-slot semantics for writes, and then ensure the tensor is converted to torch.int32 (matching the existing dtype-check branch). Update the code that sets output_state_indices so it uses output_state_indices = initial_state_indices.clone(); output_state_indices[output_state_indices < 0] = 0; output_state_indices = output_state_indices.to(torch.int32) (or equivalent) so flat_write_idx cannot be negative when used with HV and i_hv.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Line 2593: The assignment to h0_out_idx_ calling from_dlpack is misformatted;
reformat that line to satisfy ruff (apply ruff format or adjust
spacing/punctuation) so it matches the project's formatting rules (e.g., proper
spacing around the = and within the function call) in the h0_out_idx_ =
from_dlpack(...) statement; keep the same variable name h0_out_idx_ and function
call from_dlpack with arguments output_state_indices, assumed_align=32,
enable_tvm_ffi=True.
---
Duplicate comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: When defaulting output_state_indices to
initial_state_indices, avoid assigning the same tensor with -1 padding because
the write path computes flat_write_idx and will write to negative offsets;
instead, in the branch where output_state_indices is None, create a clone of
initial_state_indices, replace negative values (e.g. -1) with 0 to preserve the
null-slot semantics for writes, and then ensure the tensor is converted to
torch.int32 (matching the existing dtype-check branch). Update the code that
sets output_state_indices so it uses output_state_indices =
initial_state_indices.clone(); output_state_indices[output_state_indices < 0] =
0; output_state_indices = output_state_indices.to(torch.int32) (or equivalent)
so flat_write_idx cannot be negative when used with HV and i_hv.
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 2182-2193: The assertions that verify unmodified slots use
non-zero tolerances which can hide stray writes; change the checks on
pool_under_test vs pool_orig for read_indices and ~used_mask to use exact
equality (e.g., torch.equal or torch.testing.assert_close with atol=0, rtol=0)
so that pool_under_test[read_indices] and pool_under_test[~used_mask] must match
pool_orig exactly; update the two assert calls referencing pool_under_test,
pool_orig, read_indices, write_indices, and used_mask accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3392082e-e475-483c-bc9d-773a2995d6ff
📥 Commits
Reviewing files that changed from the base of the PR and between c8d21f0d478a94ff703be8e9ce3eebb40f005823 and 58cc26d2001e4a92f2e1b73b94df9a46386a8082.
📒 Files selected for processing (4)
flashinfer/gdn_decode.pyflashinfer/gdn_kernels/gdn_decode_bf16_state.pyflashinfer/gdn_kernels/gdn_decode_pretranspose.pytests/gdn/test_decode_delta_rule.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/gdn_kernels/gdn_decode_pretranspose.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/gdn_decode.py
| @@ -2552,6 +2590,7 @@ def gated_delta_rule_mtp( | |||
| dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) | |||
| o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) | |||
| h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True) | |||
| h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True) | |||
There was a problem hiding this comment.
Fix formatting to pass pre-commit checks.
The pipeline failure indicates this line needs reformatting per ruff format.
🔧 Apply ruff formatting
- h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
+ h0_out_idx_ = from_dlpack(
+ output_state_indices, assumed_align=32, enable_tvm_ffi=True
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True) | |
| h0_out_idx_ = from_dlpack( | |
| output_state_indices, assumed_align=32, enable_tvm_ffi=True | |
| ) |
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 2590-2593: pre-commit failed: ruff format (hook id: ruff-format) reformatted files. Diff shows formatting change in gated_delta_rule_mtp() for h0_out_idx_ = from_dlpack(output_state_indices, ...).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` at line 2593, The assignment
to h0_out_idx_ calling from_dlpack is misformatted; reformat that line to
satisfy ruff (apply ruff format or adjust spacing/punctuation) so it matches the
project's formatting rules (e.g., proper spacing around the = and within the
function call) in the h0_out_idx_ = from_dlpack(...) statement; keep the same
variable name h0_out_idx_ and function call from_dlpack with arguments
output_state_indices, assumed_align=32, enable_tvm_ffi=True.
|
@saltyminty should I do anything here? |
|
@feldsherov this is failing pre-commit tests: https://github.com/flashinfer-ai/flashinfer/actions/runs/24164935099/job/70524339364?pr=2905 Can you please rerun pre-commit and push? |
58cc26d to
b7986c0
Compare
|
@saltyminty updated the PR, should be green now. |
b7986c0 to
54be68d
Compare
<!-- .github/pull_request_template.md --> ## 📌 Description Introduce separate output indices parameter for gated_delta_rule_decode_pretranspose. This addresses decoded part of feature request in flashinfer-ai#2873 ## 🔍 Related Issues flashinfer-ai#2873 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). I've checked only tests/gdn/test_decode_delta_rule.py on H200. I need help with running whole testsuite. ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Optional control to write updated recurrent state into caller-specified output slots (separate from read/input slots). Enabled only in pool (initial-state) mode and requires initial-state info; validates indices shape and integer dtype. Preserves existing behavior when not used. * **Tests** * Added tests covering separate read/write state indexing and the case where output indices equal input indices, validating outputs and pool mutations. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
<!-- .github/pull_request_template.md --> ## Summary Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code, adds split-pool support (#2905-compatible) to both surviving BF16 kernels, and ships the OOB fix mirroring upstream PR #3145 — for the BF16 kernels that survived the cleanup. **Supersedes #3118.** That PR's perf delta (T=1 per-call overhead + pool+padding for the ILP kernel) is the first commit on this branch (`8a6e9819`). ## What changes - **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers small/medium/large `B*HV` work-unit sizes uniformly. - **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches the production serving contract). Wrapper `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool callers internally — public API unchanged. - **Split-pool support** (PR #2905 contract): both surviving BF16 kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`) natively support `output_state_indices != initial_state_indices`, with bit-equivalent single-pool behavior selected at compile time via `Constexpr[bool] same_pool` for zero-overhead dispatch. - **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed by the per-call batch index `i_n` (not the pool-scoped `cache_idx`), so the buffer can be sized `[B, T, HV, V, K]` as production callers expect. Regression test catches the bug; pre-fix triggers `cudaErrorIllegalAddress` in <2 s. ## Removed (~1900 LOC of dead code) | Kernel | Why removed | |---|---| | `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch | | `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32 with B≥16 — not a Qwen3.5 shape; MTP path covers it | | `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After wide_vec extension to split-pool + tile_v=32, mtp_kernel was unreachable | End-state BF16 surface = **2 `@cute.kernel`s in one file**: - `gdn_wide_vec_kernel` — production hot path - `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback Both pool-only, both split-pool capable, both indexed batch-scoped. ## Speedup vs previous baseline Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured on this same branch by monkey-patching `_select_wide_vec_tile_v` to return `None` for every shape). Same harness, same hardware, same config — so the comparison isolates the kernel-level speedup that wide_vec + the cleanup deliver. Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50, T=1 invoked with `--update-state`, T≥2 invoked with `--cache-intermediate-states`. Kernel time in microseconds (CUPTI). ### Speedup (×, baseline / post-PR) | B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 | |-----|-------|-------|-------|-------|-------|-------|-------|-------| | 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× | | 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× | | 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× | | 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× | | 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× | | 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× | | 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× | | 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× | ### Time reduction (%) | B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 | |-----|-------|--------|-------|--------|--------|--------|--------|--------| | 1 | +2.8% | +3.5% | +3.2% | +3.3% | +0.1% | +1.7% | +1.8% | +0.7% | | 4 | −3.2% | +18.7% | +9.5% | +10.7% | +9.7% | +10.2% | +12.3% | +12.5% | | 8 | +7.8% | +10.1% | +10.0%| +10.7% | +11.8% | +12.7% | +12.8% | +12.4% | | 16 | +3.4% | +8.5% | +10.0%| +11.2% | +11.4% | +10.7% | +10.3% | +9.4% | | 32 | +6.0% | +10.4% | +9.4% | +9.6% | +8.8% | +8.4% | +8.4% | +7.8% | | 64 | +4.0% | +10.3% | +7.5% | +8.6% | +5.9% | +6.3% | +7.2% | +5.9% | | 128 | +4.2% | +9.8% | +6.5% | +8.4% | +6.0% | +5.9% | +6.3% | +6.6% | | 256 | +4.2% | +9.6% | +6.6% | +8.4% | +6.3% | +5.9% | +6.5% | +6.6% | ### Headline - **T=1 production decode (B≥16)**: 4–6 % time reduction across the full batch sweep — the Qwen3.5 hot path. - **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape. Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×). - **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was already efficient there). ### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200) | B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 | |-----|------|------|------|------|------|------|------|------| | 1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 | | 4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 | | 8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 | | 16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 | | 32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 | | 64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 | | 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 | | 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 | Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256 production decode shape). ### Split-pool With wide_vec now supporting split-pool natively, split-pool matches single-pool to within ±1 % at every measured shape. ## Tests > **513 passed, 0 failed in 18m18s** on B200. Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1, 4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, 12 wrapper-level split-pool tests. ## Files changed (4) - `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool → pool - `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE - `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression tests - `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added `--pool-mode` option to benchmark tool for configuring state pool allocation (`single` or `split` modes). * **Tests** * Expanded BF16 test coverage with regression tests for split-pool semantics and out-of-bounds scenarios; improved batch-dimension handling for intermediate-state comparisons. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
📌 Description
Introduce separate output indices parameter for gated_delta_rule_decode_pretranspose.
This addresses decoded part of feature request in #2873
🔍 Related Issues
#2873
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).I've checked only tests/gdn/test_decode_delta_rule.py on H200. I need help with running whole testsuite.
Reviewer Notes
Summary by CodeRabbit
New Features
Tests