Fix OOB crash in intermediate_states indexing for GDN decode MTP kernel#3145

Closed
wenscarl wants to merge 3 commits into
flashinfer-ai:mainfrom
wenscarl:shuw/gdn_mtp_fix

Conversation

@wenscarl

@wenscarl wenscarl commented Apr 22, 2026

Collaborator

co-authored by @YAMY1234

Problem

gdn_decode_bf16state_mtp_kernel crashes with an out-of-bounds GPU memory write when
intermediate_states_buffer is provided and pool_size > B with initial_state_indices
pointing to upper pool slots (the normal serving scenario).

Affected path: flashinfer/gdn_kernels/gdn_decode_bf16_state.py:1873

```python
# Before (wrong): indexed by pool slot
flat_idx = cache_idx * T * HV + i_t * HV + i_hv

# After (correct): indexed by batch position
flat_idx = i_n * T * HV + i_t * HV + i_hv
```

Root Cause

When intermediate_states support was added to the BF16 state kernel, the author reused
the cache_idx-based addressing pattern from the persistent state pool access:

```python
flat_state_idx = cache_idx * HV + i_hv # correct: h0_source is pool-scoped
```

...and extended it by analogy to add a T dimension for intermediate_states. This is
wrong because the two buffers have different ownership semantics:

  • h0_source (the state pool) is pool-scoped — persists across decode steps, one slot
    per concurrent request in the system → correctly indexed by cache_idx (pool slot)
  • intermediate_states is batch-scoped — a per-forward-pass output capturing states at
    each of the T steps → should be indexed by i_n (batch position)

The float32 counterpart gdn_decode_mtp.py has always used i_n correctly. The BF16
kernel diverged when it was written.

The bug was invisible in existing tests because they always set pool_size = batch_size
with initial_state_indices = arange(B), making cache_idx == i_n identically. The
buggy and correct indexing produce the same result in that configuration. The docstring
describing the buffer shape as [pool_size, T, HV, V, K] further reinforced the incorrect
mental model.

The crash only manifests in the realistic serving scenario where pool_size >> B and
initial_state_indices contains values ≥ B, causing cache_idx-based writes to go
beyond the end of a batch-sized buffer.
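
As a rough illustration of the paragraph above, the following plain-Python sketch evaluates both index formulas against a batch-sized buffer. The sizes and the pool slots in `initial_state_indices` are made-up values, and `last_row` is a hypothetical helper, not part of the kernel.

```python
# Hypothetical sizes, chosen only to keep the arithmetic small.
B, T, HV = 4, 2, 8
initial_state_indices = [9, 3, 14, 6]  # assumed pool slots, one per batch entry
buffer_rows = B * T * HV               # rows of a [B * T * HV, V, K] buffer


def last_row(lead: int) -> int:
    """Largest flat row touched for a given leading index (i_t = T-1, i_hv = HV-1)."""
    return lead * T * HV + (T - 1) * HV + (HV - 1)


# Buggy form: the leading index is the pool slot (cache_idx).
overflowing = [
    i_n
    for i_n, cache_idx in enumerate(initial_state_indices)
    if last_row(cache_idx) >= buffer_rows
]

# Fixed form: the leading index is the batch position (i_n).
fixed_in_bounds = all(last_row(i_n) < buffer_rows for i_n in range(B))

print("batch entries that write out of bounds with cache_idx:", overflowing)
print("all writes in bounds with i_n:", fixed_in_bounds)
```

In this toy setup the entry mapped to pool slot 3 stays inside the buffer but lands in rows that belong to a different batch entry, so the buggy form can also corrupt results without crashing.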

Fix

Change cache_idx to i_n at the intermediate_states indexing site, and update the
docstring to reflect the correct buffer shape [B, T, HV, V, K] instead of
[pool_size, T, HV, V, K].

Test Changes

The existing test_gdn_decode_bf16_state_mtp_kernel always used pool_size = batch_size,
which masked the bug entirely. Two changes are made:

  1. pool_size_multiplier parameter added to the helper
    _test_gdn_decode_bf16_state_mtp_kernel. When > 1, it sets
    pool_size = batch_size * pool_size_multiplier and assigns each batch entry to an upper
    pool slot (initial_state_indices = arange(B) + pool_size - B), so cache_idx >= B
    for every entry. The intermediate_states_buffer is allocated with batch_size as its
    first dimension — the semantically correct size — which is smaller than pool_size.
    With the buggy cache_idx indexing the kernel writes out of bounds and crashes; after
    the fix it produces results matching the reference.

  2. @pytest.mark.parametrize("pool_size_multiplier", [1, 4]) added to the public
    test, doubling the existing matrix with a realistic pool-vs-batch configuration. The
    pool_size_multiplier=4, cache_intermediate_states=True cases are the ones that
    directly catch this bug.
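
A minimal sketch of the index remapping described in item 1, written with plain lists instead of tensors. `upper_pool_indices` is a hypothetical helper name used only for this illustration; the real test builds the indices with `arange`.

```python
def upper_pool_indices(batch_size: int, pool_size_multiplier: int) -> list:
    """Mirror of arange(B) + pool_size - B from the description, using a list."""
    pool_size = batch_size * pool_size_multiplier
    return [i + pool_size - batch_size for i in range(batch_size)]


batch_size = 4
masked = upper_pool_indices(batch_size, 1)   # cache_idx == i_n, bug stays hidden
exposed = upper_pool_indices(batch_size, 4)  # every cache_idx >= batch_size

print(masked)
print(exposed)
```

With a multiplier of 1 the indices coincide with the batch positions, which is why the original test matrix could not distinguish the two indexing schemes.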

Summary by CodeRabbit

  • Bug Fixes

    • Fixed intermediate-state indexing so cached states are written and read per batch slice consistently.
  • Tests

    • Expanded tests to cover pool sizes larger than the batch (parameterized), including remapped indices and updated assertions for those scenarios.
  • Documentation

    • Updated API/docs and launcher expectations to reflect the new intermediate-states buffer shape.

@coderabbitai

coderabbitai Bot commented Apr 22, 2026

Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 400f27e5-6161-4bbe-b8f4-2c701e809f88

📥 Commits

Reviewing files that changed from the base of the PR and between d4b9012 and ec69283.

📒 Files selected for processing (2)
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py

Walkthrough

Adjusts the MTP BF16 decode intermediate-state layout and write indexing to use batch-addressed slices (B * T * HV) instead of pool-addressed slots; tests are extended to cover pool sizes larger than the batch to validate the new indexing.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Kernel layout & indexing<br>flashinfer/gdn_kernels/gdn_decode_bf16_state.py | Change documented intermediate_states logical layout from pool-addressed to batch-addressed; when cache_intermediate_states is enabled, write indexing in gdn_decode_bf16state_mtp_kernel uses the launch batch index (i_n) for flat_idx instead of the pooled cache_idx. |
| Launcher / API docs<br>flashinfer/gdn_kernels/gdn_decode_bf16_state.py | Update the public launcher API signatures and docs to expect intermediate_states_buffer shaped [B, T, HV, V, K]. |
| Tests: pool-size coverage<br>tests/gdn/test_decode_delta_rule.py | Parametrize MTP BF16-state test with pool_size_multiplier (e.g., 1, 4); enlarge pool when multiplier > 1, remap initial_state_indices into upper pool region, allocate intermediate_states_buffer with leading dim [batch_size], and update reference selection and assertion messages to include pool_multiplier context. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • kahyunnam
  • yongwww

Poem

🐰 I hopped through memory, counted every lane,

Swapped a pooled address for an i_n name,
I nudged the tests to make pools grow wide,
So each token's state now finds its proper side,
A tiny hop, but indices aligned.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main fix: correcting an out-of-bounds crash caused by incorrect intermediate_states indexing in the GDN decode MTP kernel. |
| Description check | ✅ Passed | The description is comprehensive, addressing problem statement, root cause analysis, the specific fix, and test changes. However, it does not follow the provided template structure with sections like '📌 Description', '🔍 Related Issues', and '🚀 Pull Request Checklist'. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request fixes a bug in the gdn_decode_bf16state_mtp_kernel where intermediate states were incorrectly indexed using the pool slot index (cache_idx) instead of the batch index (i_n). To prevent regressions, the test suite has been updated to include scenarios where the pool size is larger than the batch size, specifically by introducing a pool_size_multiplier. Feedback was provided to optimize the test code by removing unnecessary .cpu() and .clone() calls when indexing GPU tensors.

ref_state = input_state_ref_bf16.clone()
# Reference: step through tokens with bf16 state.
# Select only the batch entries' initial states from the pool.
ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()

Contributor

medium

The call to .cpu() on initial_state_indices is unnecessary because both the indices and the tensor being indexed (input_state_ref_bf16) are already on the GPU. Moving indices to the CPU just to index a GPU tensor is inefficient as it may trigger unnecessary host-device synchronization. Additionally, indexing with a tensor in PyTorch always creates a copy, so the .clone() call is redundant.

Suggested change
ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
ref_state = input_state_ref_bf16[initial_state_indices]

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (2)

2568-2576: ⚠️ Potential issue | 🟠 Major

Validate the dimensions required by the new flat index.

Line 1880 assumes a [B, T, HV, V, K] layout flattened with T * HV stride. Today buffer_size < B can still OOB, and cache_steps > T is accepted even though the kernel will write with the wrong batch stride.

🛡️ Proposed validation fix
         buffer_size = intermediate_states_buffer.shape[0]
         cache_steps = intermediate_states_buffer.shape[1]
-        assert cache_steps >= T, (
-            f"intermediate_states_buffer dim 1 ({cache_steps}) must be >= T={T}"
+        assert buffer_size >= B, (
+            f"intermediate_states_buffer dim 0 ({buffer_size}) must be >= B={B}"
+        )
+        assert cache_steps == T, (
+            f"intermediate_states_buffer dim 1 ({cache_steps}) must equal T={T}"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2568 - 2576,
The code flattens intermediate_states_buffer assuming a [B, T, HV, V, K] layout
and a batch stride of T*HV, but it doesn't validate that buffer_size and
cache_steps match the expected B and T (allowing OOB when buffer_size < B or
silent misuse when cache_steps != T). Add explicit validations before reshaping:
assert intermediate_states_buffer.dim() == 5, assert cache_steps == T, assert
buffer_size == B (or equivalently buffer_size * cache_steps == B * T if B is
known), and assert intermediate_states_buffer.shape[2:5] == (HV, V, K); keep the
dtype check for torch.bfloat16 and only then perform the reshape into
intermediate_states to ensure safe indexing in the kernel.

1097-1099: ⚠️ Potential issue | 🟡 Minor

Update stale intermediate-state shape docs.

The implementation now treats intermediate_states as batch-scoped, but these comments still advertise pool-scoped storage. That can mislead callers into allocating/reading the wrong shape.

📝 Proposed doc fix
-    intermediate_states: cute.Tensor,  # [pool_size * T * HV, V, K] as BF16 (or dummy)
+    intermediate_states: cute.Tensor,  # [B * T * HV, V, K] as BF16 (or dummy)
-    intermediate_states: cute.Tensor,  # [pool_size * T * HV, V, K] BF16 (or dummy)
+    intermediate_states: cute.Tensor,  # [B * T * HV, V, K] BF16 (or dummy)
-        intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16
+        intermediate_states_buffer: Optional [B, T, HV, V, K] bf16

Also applies to: 2024-2028, 2523-2528

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 1097 - 1099,
The doc comment for the intermediate_states parameter is stale: update the shape
description to reflect that intermediate_states is batch-scoped (not
pool-scoped). Replace the current "[pool_size * T * HV, V, K] as BF16 (or
dummy)" wording with a batch-scoped shape like "[batch_size * T * HV, V, K] as
BF16 (or dummy)" in the parameter docs for intermediate_states in
gdn_decode_bf16_state (and make the identical change in the other occurrences
around the sections noted), so callers allocate/read using batch_size rather
than pool_size.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1815-1816: Run the project's pre-commit formatters
(black/ruff/pre-commit) on the changed test blocks to remove trailing whitespace
and apply ruff-format rewrites; specifically reformat the new assignment and
surrounding code that uses pool_size, batch_size, and pool_size_multiplier and
the other affected blocks referenced around the same test (the blocks near the
pool_size assignment and the later test sections), ensuring no trailing spaces
remain and ruff/black rules are satisfied.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c9765bd7-41f8-4853-982a-08bc53382b31

📥 Commits

Reviewing files that changed from the base of the PR and between fb3bb44 and bcdcdc4.

📒 Files selected for processing (2)
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py

Comment thread tests/gdn/test_decode_delta_rule.py Outdated
@kahyunnam

Member

/bot run

@flashinfer-bot

Collaborator

GitLab MR !590 has been created, and the CI pipeline #49302059 is currently running. I'll report back once the pipeline job completes.

@kahyunnam kahyunnam left a comment

Member

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)

1883-1883: Minor: .cpu() round-trip is unnecessary for GPU tensor indexing.

input_state_ref_bf16 lives on CUDA, and initial_state_indices is a CUDA int32 tensor — PyTorch can index directly without moving indices to CPU (the gather then forces a D2H sync). Not a correctness issue (dtype int32 is accepted here), just a small cleanup.

♻️ Proposed simplification
-    ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
+    ref_state = input_state_ref_bf16[initial_state_indices.long()].clone()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` at line 1883, The code does an
unnecessary device round-trip by calling .cpu() on initial_state_indices when
indexing a CUDA tensor; update the indexing of input_state_ref_bf16 by removing
the .cpu() call so ref_state =
input_state_ref_bf16[initial_state_indices].clone() (i.e., locate the expression
constructing ref_state and drop the .cpu() to allow direct CUDA-to-CUDA indexing
with initial_state_indices).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8945e231-c4fc-4839-add7-028b5835aede

📥 Commits

Reviewing files that changed from the base of the PR and between bcdcdc4 and d4b9012.

📒 Files selected for processing (1)
  • tests/gdn/test_decode_delta_rule.py

@wenscarl wenscarl requested a review from kahyunnam April 23, 2026 20:42
Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated

@ameynaik-hub ameynaik-hub left a comment

Contributor

Thanks for the fix.
The code fix looks correct to me. One small cleanup before merge: a few comments/docstrings still describe intermediate_states as pool-scoped, but this PR makes it batch-scoped.

Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated
Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated
@wenscarl wenscarl requested a review from ameynaik-hub April 27, 2026 15:09
wenscarl and others added 3 commits April 27, 2026 10:09
Fix comment/docstring descriptions of `intermediate_states` to reflect
the batch-scoped shape [B * T * HV, V, K] / [B, T, HV, V, K] instead
of the outdated pool-scoped shape, as noted in PR review by ameynaik-hub.

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Apr 28, 2026
… PR flashinfer-ai#3145)

The ``intermediate_states_buffer`` is BATCH-scoped — shape ``[B, T, HV, V, K]``
— but both the wide_vec kernel (line 1115) and the mtp_ilp4 kernel
(line 621) were indexing it with ``cache_idx * T * HV + i_t * HV + i_hv``
where ``cache_idx`` is the POOL slot from ``initial_state_indices[i_n]``.

When ``pool_size > B`` (every realistic serving config) and
``initial_state_indices`` points at slots ``>= B`` (e.g. middle of a
1024-slot pool while servicing a B=32 batch), ``cache_idx * T * HV``
exceeds the buffer's ``B * T * HV`` extent and the
``cute.local_tile`` write goes off the end of the cache buffer ->
``cudaErrorIllegalAddress`` or silent memory corruption.

This is the same bug upstream PR flashinfer-ai#3145 fixed in the now-removed
``gdn_decode_bf16state_mtp_kernel``; both surviving BF16 kernels
inherited the incorrect pattern.

Fix:
- ``gdn_decode_bf16state_mtp_ilp4_kernel``: ``flat_idx = i_n * T * HV + ...``
  (was ``cache_idx * T * HV + ...``).
- ``gdn_wide_vec_kernel``: same.
- Dispatcher (both ``gated_delta_rule_mtp`` and
  ``gated_delta_rule_mtp_wide_vec``): assert
  ``intermediate_states_buffer.shape[0] == B`` and reshape using ``B``
  rather than ``buffer_size``. Also updates the comment / docstring to
  call out batch-scoped semantics explicitly.

Adds ``test_gdn_decode_bf16_state_mtp_pool_larger_than_batch`` (12
cases) which parametrizes ``pool_size_multiplier in {1, 4}`` and
``batch_size in {1, 8, 32}`` and ``seq_len in {2, 4}`` so both the
ilp4 path (B=1) and the wide_vec path (B=8/32) are exercised with
pool indices pointing at the upper half of a 4*B-slot pool. Verified
the test catches the bug: re-introducing the
``cache_idx * T * HV`` form makes the test fail with
``cudaErrorIllegalAddress``; reverting the line makes it pass again.
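
The 12-case matrix mentioned above can be reproduced with a short sketch. The parameter values come from the commit message; `kernel_path` is a hypothetical labelling function that only encodes the statement that B=1 exercises the ilp4 path and B=8/32 the wide_vec path, not the real dispatcher logic.

```python
from itertools import product

pool_size_multipliers = [1, 4]
batch_sizes = [1, 8, 32]
seq_lens = [2, 4]

# Cartesian product of the three parameter lists.
cases = list(product(pool_size_multipliers, batch_sizes, seq_lens))


def kernel_path(batch_size: int) -> str:
    # Assumed mapping taken from the commit message, not from the dispatcher.
    return "mtp_ilp4" if batch_size == 1 else "wide_vec"


paths = {kernel_path(b) for _, b, _ in cases}
# Only cases with pool_size > B can place cache_idx outside the batch range.
oob_capable = [c for c in cases if c[0] > 1]

print(len(cases), sorted(paths), len(oob_capable))
```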

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub mentioned this pull request Apr 28, 2026
5 tasks
@nvpohanh

nvpohanh commented May 4, 2026

Contributor

/bot run

@flashinfer-bot

Collaborator

GitLab MR !590 has been updated with latest changes, and the CI pipeline #50199162 is currently running. I'll report back once the pipeline job completes.

kahyunnam pushed a commit that referenced this pull request May 6, 2026

  ## Summary

  Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the
  `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new
  **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a
  small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code,
  adds split-pool support (#2905-compatible) to both surviving BF16
  kernels, and ships the OOB fix mirroring upstream PR #3145 — for the
  BF16 kernels that survived the cleanup.

  **Supersedes #3118.** That PR's perf delta (T=1 per-call overhead +
  pool+padding for the ILP kernel) is the first commit on this branch
  (`8a6e9819`). 
  
  ## What changes

  - **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups
    × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per
    thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers
    small/medium/large `B*HV` work-unit sizes uniformly.
  - **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches
    the production serving contract). Wrapper
    `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool
    callers internally — public API unchanged.
  - **Split-pool support** (PR #2905 contract): both surviving BF16
kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`)
    natively support `output_state_indices != initial_state_indices`,
    with bit-equivalent single-pool behavior selected at compile time
    via `Constexpr[bool] same_pool` for zero-overhead dispatch.
  - **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed
    by the per-call batch index `i_n` (not the pool-scoped `cache_idx`),
    so the buffer can be sized `[B, T, HV, V, K]` as production callers
    expect. Regression test catches the bug; pre-fix triggers
    `cudaErrorIllegalAddress` in <2 s.

  ## Removed (~1900 LOC of dead code)

  | Kernel | Why removed |
  |---|---|
  | `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch |
  | `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32 with B≥16, which is not a Qwen3.5 shape; MTP path covers it |
  | `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After wide_vec extension to split-pool + tile_v=32, mtp_kernel was unreachable |

  End-state BF16 surface = **2 `@cute.kernel`s in one file**:
  - `gdn_wide_vec_kernel` — production hot path
  - `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback

  Both pool-only, both split-pool capable, both indexed batch-scoped.

  ## Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured
  on this same branch by monkey-patching `_select_wide_vec_tile_v` to
  return `None` for every shape). Same harness, same hardware, same
  config — so the comparison isolates the kernel-level speedup that
  wide_vec + the cleanup deliver.

  Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50,
  T=1 invoked with `--update-state`, T≥2 invoked with
  `--cache-intermediate-states`. Kernel time in microseconds (CUPTI).

  ### Speedup (×, baseline / post-PR)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× |
| 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× |
| 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× |
| 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× |
| 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× |
| 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× |
| 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× |

  ### Time reduction (%)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|--------|--------|--------|--------|--------|--------|--------|
| 1 | +2.8% | +3.5% | +3.2% | +3.3% | +0.1% | +1.7% | +1.8% | +0.7% |
| 4 | −3.2% | +18.7% | +9.5% | +10.7% | +9.7% | +10.2% | +12.3% | +12.5% |
| 8 | +7.8% | +10.1% | +10.0% | +10.7% | +11.8% | +12.7% | +12.8% | +12.4% |
| 16 | +3.4% | +8.5% | +10.0% | +11.2% | +11.4% | +10.7% | +10.3% | +9.4% |
| 32 | +6.0% | +10.4% | +9.4% | +9.6% | +8.8% | +8.4% | +8.4% | +7.8% |
| 64 | +4.0% | +10.3% | +7.5% | +8.6% | +5.9% | +6.3% | +7.2% | +5.9% |
| 128 | +4.2% | +9.8% | +6.5% | +8.4% | +6.0% | +5.9% | +6.3% | +6.6% |
| 256 | +4.2% | +9.6% | +6.6% | +8.4% | +6.3% | +5.9% | +6.5% | +6.6% |
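
Since the speedup is defined as baseline / post-PR, the time reduction should follow as 1 - 1/speedup. The sketch below applies that relation to one rounded table entry; `time_reduction_pct` is a hypothetical helper, and because the published speedups are rounded to two decimals, other cells may differ from the table in the last digit.

```python
def time_reduction_pct(speedup: float) -> float:
    """Percent of baseline time saved, given speedup = baseline / post."""
    return (1.0 - 1.0 / speedup) * 100.0


# B=4, T=2 is listed as a 1.23x speedup and a +18.7% time reduction.
print(round(time_reduction_pct(1.23), 1))
# A speedup below 1 (e.g. the 0.97x entry) gives a negative reduction, i.e. a slowdown.
print(time_reduction_pct(0.97) < 0)
```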

  ### Headline

- **T=1 production decode (B≥16)**: 4–6 % time reduction across the full
batch sweep — the Qwen3.5 hot path.
- **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape.
Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
- **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't
DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was
already efficient there).

  ### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

  |  B  | T=1  | T=2  | T=3  | T=4  | T=5  | T=6  | T=7  | T=8  |
  |-----|------|------|------|------|------|------|------|------|
  |   1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 |
  |   4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 |
  |   8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 |
  |  16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 |
  |  32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 |
  |  64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 |
  | 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 |
  | 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 |

Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256
production decode shape).

  ### Split-pool

  With wide_vec now supporting split-pool natively, split-pool
  matches single-pool to within ±1 % at every measured shape.

  ## Tests

  > **513 passed, 0 failed in 18m18s** on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP
tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1,
4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, 12 wrapper-level split-pool
tests.

  ## Files changed (4)

- `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool →
pool
- `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined;
dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
- `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression
tests
  - `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added `--pool-mode` option to benchmark tool for configuring state
pool allocation (`single` or `split` modes).

* **Tests**
* Expanded BF16 test coverage with regression tests for split-pool
semantics and out-of-bounds scenarios; improved batch-dimension handling
for intermediate-state comparisons.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@nvpohanh

Contributor

@wenscarl Could you rebase this?

ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Jun 4, 2026
…ty with bf16)

This PR fixes the two issues vLLM hit with the fp32 GDN MTP decode path:

  Correctness: the wrapper's `.reshape(pool*HV, V, K)` silently densifies a
               non-contiguous (page-strided) pool. The kernel then writes
               that throwaway copy, dropping updates for vLLM-style pools.
  Perf:        the densification copy runs every call, regardless of whether
               state actually changed.

The fix is in two layers:

1. Native 4D-pool support in both fp32 MTP kernels (gdn_decode_mtp.py):

   - `gdn_verify_kernel_mtp` (warp-spec, B*HV > 128) and
     `gdn_verify_kernel_mtp_inline` (small batch) each gain a
     `use_pool_indexing: cutlass.Constexpr[bool]` switch.
   - Once per CTA the kernel builds a 2D (V, K) view onto the pool slot.
     The constexpr branch is the only site that knows the actual layout:
       * True : 4D `[pool, HV, V, K]` — slice with (cache_idx, i_hv, :, :);
                works for non-contiguous strided pools (vLLM).
       * False: 3D `[pool*HV, V, K]` — slice with (flat_state_idx, :, :);
                free reshape view of a contiguous pool (existing fast path).
   - All ~43 `cute.local_tile(h0_source, (1, 1, vec_size), (flat_*_idx, X,
     lane))` call sites are replaced with the view-based form
     `cute.local_tile(h_*_view, (1, vec_size), (X, lane))`. Same memory
     accesses, same instruction stream for the contiguous fast path.
   - `flat_write_idx` and `write_cache_idx` are pre-declared / clamped to
     satisfy CuTe DSL's "no variable out of control flow" rule. The
     original-sign signal `write_cache_idx_raw` drives the per-site
     write-skip gates so negative output indices still suppress the
     writeback (preserving fp32 padding-skip semantics).
   - Launchers extract `v_dim` / `k_dim` from the correct layout axes
     depending on `use_pool_indexing`.
   - `run_mtp_decode` cache key includes `use_pool_indexing` plus
     `tuple(h0_source.stride())` (only when use_pool_indexing=True) so
     different page-stride patterns each get their own compile and don't
     alias to a stale binary.
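As a rough illustration of the two addressing modes, the sketch below computes element offsets in plain Python. It assumes row-major element strides and a made-up page-strided layout; the helper names are hypothetical and this is not the CuTe DSL code.

```python
def contiguous_strides(shape):
    """Row-major element strides for a dense tensor of the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def offset_pool_indexing(strides, cache_idx, i_hv, v, k):
    """use_pool_indexing=True: 4D [pool, HV, V, K] addressed via its real strides."""
    s_pool, s_hv, s_v, s_k = strides
    return cache_idx * s_pool + i_hv * s_hv + v * s_v + k * s_k


def offset_flat_indexing(HV, V, K, cache_idx, i_hv, v, k):
    """use_pool_indexing=False: 3D [pool*HV, V, K] view, valid only for a dense pool."""
    flat_state_idx = cache_idx * HV + i_hv
    return flat_state_idx * V * K + v * K + k


pool, HV, V, K = 4, 2, 3, 5
coord = dict(cache_idx=3, i_hv=1, v=2, k=4)

dense = contiguous_strides((pool, HV, V, K))
# Hypothetical page-strided pool: each pool slot is padded to twice its dense size.
paged = (2 * HV * V * K, V * K, K, 1)

flat = offset_flat_indexing(HV, V, K, **coord)
same_on_dense = offset_pool_indexing(dense, **coord) == flat   # True
same_on_paged = offset_pool_indexing(paged, **coord) == flat   # False
```

The two forms agree on a dense pool, which is why the 3D fast path can stay. They disagree once the pool stride is padded, which is why the strided case needs the 4D slice.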

2. Wrapper parity with the bf16 MTP path (gdn_decode.py:gated_delta_rule_mtp):

   - Add `output_state_indices` parameter (mirrors the bf16 wrapper).
     Defaults to `initial_state_indices`. Negative write indices skip the
     writeback for that batch slot.
   - Drop redundant `.to(torch.float32)` casts (state was already asserted
     fp32). Validate `intermediate_states_buffer.dtype == float32`.
   - Make `.contiguous()` on the intermediate buffer conditional, matching
     the bf16 wrapper.
   - Dispatch: when `initial_state.is_contiguous()`, take the existing 3D
     fast path (free reshape view, `use_pool_indexing=False`). Else, pass
     the 4D tensor through unchanged with `use_pool_indexing=True` — the
     kernel writes the strided pool in place, no densification, no
     scatter step.
   - `intermediate_states_buffer` is still flat-indexed by batch (i_n), so
     a non-contiguous buffer still triggers a staging copy + scatter back.
     Native 4D for the intermediate buffer is a separate follow-up.
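The wrapper-side decisions listed above can be sketched in a few lines. This is a minimal model under stated assumptions: shapes and strides are passed as tuples rather than tensors, and `plan_mtp_call` is a hypothetical helper, not a function in `gdn_decode.py`.

```python
def is_contiguous(shape, strides):
    """True if strides are the dense row-major strides for shape."""
    expected = 1
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True


def plan_mtp_call(state_shape, state_strides, initial_state_indices,
                  output_state_indices=None):
    """Return (use_pool_indexing, write slot per batch row)."""
    # Write indices default to the read indices.
    if output_state_indices is None:
        output_state_indices = initial_state_indices
    # Dense pool: 3D fast path. Strided pool: pass the 4D tensor through.
    use_pool_indexing = not is_contiguous(state_shape, state_strides)
    # A negative write index means "skip the writeback for this batch row".
    write_slots = [idx if idx >= 0 else None for idx in output_state_indices]
    return use_pool_indexing, write_slots


dense_plan = plan_mtp_call((4, 2, 3, 5), (30, 15, 5, 1), [3, 1])
paged_plan = plan_mtp_call((4, 2, 3, 5), (60, 15, 5, 1), [3, 1], [2, -1])
```

Here `None` marks a suppressed writeback; the actual kernel gates the write on the sign of the raw index instead.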

   Additionally, `gated_delta_rule_decode_pretranspose` now routes
   fp32 + T>1 (pool mode) through `gated_delta_rule_mtp` so the dispatcher
   has a single entry point.

Tests (test_decode_delta_rule.py):

  - `test_mtp_fp32_state_pool` (24 parametric variants): non-trivial
    indices, optional separate output_state_indices, optional intermediate
    caching. Verifies gather→direct reference parity, write destination,
    and that non-targeted pool slots are bit-exactly unchanged.
  - `test_mtp_fp32_state_pool_non_contiguous` (8 parametric variants:
    B in {1, 4} x T in {2, 4} x stride_multiplier in {2, 3}). Allocates an
    oversized HV-stride backing tensor and slices every Nth head-slot to
    produce a strided 4D pool. Verifies output parity with a contiguous
    reference, that the strided pool itself receives the updates (the
    exact regression guard), and that interleaved non-selected backing
    slots are bit-exactly unchanged (proves no densification copy).

Validation:

  - Correctness: 149/149 pass across the new non-contig sweep, existing
    contiguous fp32 MTP sweep (B 1..512 x T 2..8), bf16 verify,
    pretranspose pool, negative_indices, and all_padding regressions.
  - Perf (HV=64, B200, --update-state --cache-intermediate-states, 100
    iters / 20 warmup, contiguous-pool path): median delta = 0%, mean
    delta = +0.05% across 72 (BS, T) cells vs the pre-edit baseline. The
    constexpr branch + view indirection compile away on the contiguous
    fast path. Cell-level deltas within +/-5%, within run-to-run noise.
  - Perf (contig vs strided pool, same B200): kernel-level delta is in
    the noise (+/-2% for B >= 16). The strided path's win is in *not*
    doing the per-call densification copy the old code would have done.

Review feedback addressed (CodeRabbit + Gemini):

  - Device+dtype alignment of write indices: use
    `output_state_indices.to(initial_state_indices)` (tensor target,
    not just dtype) so a CPU-side output_state_indices is realigned to
    the kernel's device automatically. No-op if already aligned.

  - Strict intermediate-buffer validation in the wrapper:
      * assert buffer_size >= B (the kernel indexes by batch i_n in [0, B);
        a smaller buffer caused silent OOB writes. The bf16 wrapper
        already had this assert via PR flashinfer-ai#3145, but the fp32
        wrapper was missing the equivalent check).
      * assert intermediate_states_buffer.is_contiguous() (caller
        contract, removes the silent staging-copy fallback).
      * Replace .reshape() with .view() so the no-copy contract is
        enforced at runtime — raises if the layout ever doesn't support
        a view.
      * Removed the now-unused post-kernel scatter-back block.

  - Test coverage: test_mtp_fp32_state_pool now passes a parallel
    intermediate buffer to the reference path and verifies the cached
    intermediate states match cell-for-cell when
    cache_intermediate_states=True.
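The intermediate-buffer checks above might look roughly like the following. It is a sketch only: the `(buffer_size, T, HV, V, K)` layout and both function names are assumptions for illustration, and the flat index reuses the `i_n`-based formula from this PR's description.

```python
def validate_intermediate_buffer(buffer_shape, buffer_strides, batch_size):
    """Assumed layout: buffer_shape = (buffer_size, T, HV, V, K)."""
    buffer_size = buffer_shape[0]
    # The kernel indexes by batch position i_n in [0, B), so B rows must exist.
    assert buffer_size >= batch_size, "intermediate buffer smaller than batch"
    # Caller contract: dense row-major layout, so no staging copy is needed.
    expected = 1
    for dim, stride in zip(reversed(buffer_shape), reversed(buffer_strides)):
        assert dim == 1 or stride == expected, "intermediate buffer must be contiguous"
        expected *= dim


def intermediate_flat_idx(i_n, i_t, i_hv, T, HV):
    """Batch-scoped flat index into the intermediate states."""
    return i_n * T * HV + i_t * HV + i_hv


B, T, HV = 4, 2, 2
validate_intermediate_buffer((4, T, HV, 3, 5), (60, 30, 15, 5, 1), batch_size=B)
# Largest index the kernel can produce: B * T * HV - 1, inside a buffer of B rows.
last = intermediate_flat_idx(B - 1, T - 1, HV - 1, T, HV)
```

With `buffer_size < B` the first assert fires up front, where the unchecked version would have written past the end of the buffer.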

Re-verified: 34/34 spot-check tests pass; HV=64 BS×T perf sweep shows
mean Δ = +0.10%, median 0% vs pre-review-fix (kernel code byte-identical;
review fixes are wrapper-only).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>

wenscarl commented Jun 8, 2026

Collaborator Author

Adopted by another merged PR.

@wenscarl wenscarl closed this Jun 8, 2026
ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Jun 10, 2026
…ty with bf16)
