[gdn] support non-contiguous state for decoding #2727

Merged
kahyunnam merged 1 commit into flashinfer-ai:main from ZJY0516:gdn_decode_non_con
Mar 25, 2026
Conversation

@ZJY0516

@ZJY0516 ZJY0516 commented Mar 9, 2026

Contributor

📌 Description

vLLM uses a non-contiguous state tensor for GDN. This PR makes FlashInfer support it as well.

🔍 Related Issues

#2521
#2687

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@gemini-code-assist

Contributor

Summary of Changes

Hello, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for non-contiguous state tensors within the gated delta rule decode pretranspose operation. By adapting the Python wrapper and CUDA kernels, the system can now efficiently process state data with flexible memory layouts, such as page-strided pools, without requiring explicit contiguity, thereby improving memory management and potentially enabling new use cases. The changes ensure that K-contiguity is maintained for performance while allowing non-contiguity in other dimensions.

Highlights

  • Non-Contiguous State Support: The core gated_delta_rule_decode_pretranspose function and its underlying CUDA kernels have been updated to correctly handle non-contiguous initial_state tensors, specifically for pooled states. This allows for more flexible memory layouts, such as page-strided pools.
  • Kernel Logic Adaptation: The CUDA kernels (gdn_decode_kernel_small_batch_pretranspose and gdn_decode_kernel_big_batch_pretranspose) were modified to correctly index and update 4D [pool_size, HV, V, K] state tensors, removing the previous requirement for these tensors to be contiguous.
  • K-Contiguity Enforcement: A new assertion was added in gated_delta_rule_decode_pretranspose to ensure that the initial_state remains K-contiguous (stride[-1] == 1), which is crucial for the kernel's pointer arithmetic and performance.
  • Kernel Caching Updates: The kernel caching mechanism (_get_compiled_decode_kernel) was enhanced to include pool_size and the strides of the h0_source tensor in its cache key, ensuring that different non-contiguous layouts compile and cache correctly.
  • New Test Case: A new test file (test_decode_pretranspose_noncontiguous_pool.py) was added to thoroughly validate the functionality of the gated delta rule decode with non-contiguous pooled states, comparing results against a gathered contiguous reference.
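
To make the layout described in these highlights concrete, here is a minimal pure-Python sketch of the addressing involved. The helper names (`pool_strides`, `element_offset`, `is_k_contiguous`), the sizes, and the `page_gap` padding are hypothetical illustrations, not FlashInfer APIs.

```python
def pool_strides(hv, v, k, page_gap=0):
    """Element strides for a [pool_size, HV, V, K] state pool.

    page_gap (hypothetical) is the number of extra elements between
    consecutive pool slots; page_gap > 0 makes the pool non-contiguous.
    """
    slot = hv * v * k
    return (slot + page_gap, v * k, k, 1)


def element_offset(index, strides):
    """Linear element offset of a 4D index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


def is_k_contiguous(strides):
    """The property the new assertion checks: stride[-1] == 1."""
    return strides[-1] == 1


HV, V, K = 2, 4, 8                            # assumed toy sizes
dense = pool_strides(HV, V, K)                # contiguous pool
paged = pool_strides(HV, V, K, page_gap=16)   # page-strided pool

# Both layouts are K-contiguous, so both pass the new check.
assert is_k_contiguous(dense) and is_k_contiguous(paged)

# Inside slot 0 the layouts agree; for later slots only the slot base moves.
same_slot = (element_offset((0, 1, 2, 3), dense), element_offset((0, 1, 2, 3), paged))
later_slot = (element_offset((3, 1, 2, 3), dense), element_offset((3, 1, 2, 3), paged))
```

Only stride 0 differs between the two layouts, which is why a kernel that honors the real strides can serve both without a copy.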

Changelog
  • flashinfer/gdn_decode.py
    • Removed is_contiguous() assertion for initial_state when use_pool is true.
    • Added an assertion to ensure initial_state is K-contiguous (stride[-1] == 1).
    • Modified h0_source construction to preserve the original [pool_size, HV, V, K] view for pooled states.
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
    • Updated kernel logic to correctly handle 4D [pool_size, HV, V, K] h0_source layout for pooled states.
    • Adjusted gSrc_batch and gDst indexing for pooled states to access the 4D tensor correctly.
    • Modified _get_compiled_decode_kernel to include pool_size and strides in the kernel cache key.
    • Removed symbolic pool dimension creation for h0_source_tensor during compilation, using the actual tensor view instead.
    • Corrected a debug line calculation from v_dim * k_dim * batch_size * 4 / 1024 / 1024 to v_dim * k_dim * 4 / 1024 / 1024.
  • tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
    • Added a new test file to validate gated_delta_rule_decode_pretranspose with non-contiguous pooled states.
    • Implemented a test case that creates a non-contiguous initial_state and verifies output and state updates against a reference.
Activity
  • No human activity has been recorded on this pull request yet.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@ZJY0516 ZJY0516 changed the title from "[gdn] support non-contiguous state" to "[gdn] support non-contiguous state for decoding" Mar 9, 2026
@coderabbitai

coderabbitai Bot commented Mar 9, 2026

Contributor
📝 Walkthrough

Walkthrough

This pull request adds pool-indexing support to the GDN pretranspose decode path. It introduces K-contiguity assertion, relaxes pool-path state handling to preserve non-contiguous layouts, and extends kernel compilation caching with pool metadata parameters (pool_size, stride0-3). A new test validates noncontiguous pool handling.
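
As a rough illustration of why the cache key carries pool metadata, the sketch below uses a stand-in compile step. `fake_compile` and `get_kernel` are hypothetical names, and the key is simplified compared with the one in the PR.

```python
_cache = {}
compile_count = 0


def fake_compile(key):
    """Stand-in for an expensive kernel compilation (hypothetical)."""
    global compile_count
    compile_count += 1
    return f"kernel#{compile_count}"


def get_kernel(shape, strides, use_pool_indexing):
    """Return a cached kernel specialized on pool size and strides."""
    if use_pool_indexing:
        pool_size = shape[0]
        s0, s1, s2, s3 = strides
    else:
        # Mirrors the PR: pool metadata is zeroed on the non-pool path.
        pool_size = s0 = s1 = s2 = s3 = 0
    hv, v, k = shape[1:]
    key = (hv, v, k, use_pool_indexing, pool_size, s0, s1, s2, s3)
    if key not in _cache:
        _cache[key] = fake_compile(key)
    return _cache[key]


shape = (4, 2, 4, 8)                          # assumed [pool_size, HV, V, K]
a = get_kernel(shape, (64, 32, 8, 1), True)   # contiguous pool
b = get_kernel(shape, (80, 32, 8, 1), True)   # page-strided pool, new entry
c = get_kernel(shape, (64, 32, 8, 1), True)   # same layout as a, cache hit
```

Without the strides in the key, `b` would reuse the kernel built for `a` even though the slot spacing differs.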

Changes

Cohort / File(s) | Summary
Core Decode Logic (flashinfer/gdn_decode.py) | Added K-contiguity assertion in pretranspose path with descriptive error message; relaxed pool-path h0_source handling to preserve original tensor layout instead of reshaping, enabling support for non-contiguous, page-strided pools.
Pretranspose Kernel and Compilation (flashinfer/gdn_kernels/gdn_decode_pretranspose.py) | Introduced pool-indexing aware state slicing throughout kernels (pool_idx, i_hv based indexing vs batch_idx); added 4D state layout handling [pool_size, HV, V, K] for pool path; extended _get_compiled_decode_kernel signature with pool_size and 4 stride parameters; updated cache keys to include pool metadata; replaced symbolic tensor creation with real tensor views via from_dlpack for accurate stride preservation.
Pool Noncontiguity Test (tests/gdn/test_decode_pretranspose_noncontiguous_pool.py) | New comprehensive test module validating noncontiguous pool tensor handling in pretranspose decode, including SM90+ capability check, page-strided pool construction, parametrized page_gap testing, and cross-validation against direct-state reference path with strict tolerances.

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host (Python)
    participant Cache as Kernel Cache
    participant Compiler as CUDA Compiler
    participant Kernel as Pretranspose Kernel
    participant State as State Tensor

    rect rgba(100, 150, 200, 0.5)
    Note over Host: Prepare h0_source with pool indexing enabled
    Host->>Host: Validate initial_state K-contiguity
    Host->>Host: Set h0_source = initial_state<br/>(preserve non-contiguous layout)
    end

    rect rgba(150, 100, 200, 0.5)
    Note over Host,Compiler: Kernel Compilation with Pool Metadata
    Host->>Host: Extract pool_size, stride0-3 from h0_source
    Host->>Host: Build cache_key with pool parameters
    Host->>Cache: Check cache for kernel
    alt Cache Hit
        Cache-->>Kernel: Return compiled kernel
    else Cache Miss
        Host->>Compiler: Compile kernel with pool metadata
        Compiler-->>Cache: Store compiled kernel
    end
    end

    rect rgba(200, 150, 100, 0.5)
    Note over Host,Kernel: Kernel Execution with Pool-based Indexing
    Host->>Kernel: Launch with pool_idx, i_hv indexing
    Kernel->>State: Read gSrc_batch [pool_size, HV, V, K]
    Kernel->>Kernel: Process with pool-aware state slicing
    Kernel->>State: Write gDst [pool_size, HV, V, K]
    Kernel-->>Host: Return output & updated_state
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested labels

model: qwen3-next

Suggested reviewers

  • kaixih
  • bkryu
  • kahyunnam
  • yzh119
  • cyx-6
  • nvmbreughe
  • jimmyzho

Poem

🐰 A rabbit hops through stride and pool,
Non-contiguous paths? No problem, cool!
Kernels now know their page-wise way,
Cache keys bloom with strides in play!
Pool-indexing hops—efficient and spry! 🌿

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check | ❓ Inconclusive | The PR description is minimal and lacks detail about the implementation and design decisions. | Expand the description with specific details about how non-contiguous state support was implemented, what changes were made to handle pool-indexing and stride information, and any trade-offs or design decisions.

✅ Passed checks (1 passed)

Check name | Status | Explanation
Title check | ✅ Passed | The title clearly and concisely summarizes the main objective: adding support for non-contiguous state in GDN decoding, which aligns with the code changes made across three files.


@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces support for non-contiguous state tensors in the Gated Delta Rule (GDN) decode functionality by updating the Python wrapper and modifying the CuTe DSL kernels to handle 4D non-contiguous memory layout. However, a high-severity security issue has been identified: the kernels lack validation for user-provided pool indices, which can lead to out-of-bounds memory access on the GPU in multi-tenant environments. It is recommended to add a check in the Python API layer to validate that all indices in initial_state_indices are within the valid range of the state pool. Additionally, a couple of instances of dead code were found that should be removed to improve code clarity and maintainability.

Comment on lines +144 to +147
gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]  # (V, K)
gDst = cute.local_tile(
    h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0)
)

Contributor

security-high high

The kernel uses pool_idx (derived from h0_indices) to index into h0_source without validating that the index is within the bounds of the first dimension of the tensor. An attacker providing malicious initial_state_indices could cause out-of-bounds reads and writes on GPU memory. It is recommended to add a validation check in the Python wrapper to ensure all indices are within the range [-1, pool_size).
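
A minimal sketch of the kind of wrapper-side check suggested here, written against plain Python integers. `validate_pool_indices` is a hypothetical helper, and treating -1 as an allowed value simply follows the range [-1, pool_size) given in the review; what -1 means to the kernel is not stated on this page.

```python
def validate_pool_indices(indices, pool_size):
    """Raise ValueError unless every index lies in [-1, pool_size)."""
    bad = [i for i in indices if not (-1 <= i < pool_size)]
    if bad:
        raise ValueError(
            f"pool indices out of range [-1, {pool_size}): {bad}"
        )


# In-range indices pass silently.
validate_pool_indices([0, 3, -1], pool_size=4)

# An index equal to pool_size is one past the last slot and is rejected.
try:
    validate_pool_indices([4], pool_size=4)
    rejected = False
except ValueError:
    rejected = True
```

On a GPU index tensor, an equivalent check would need a device-to-host synchronization before Python can raise, which is a cost to weigh against the safety benefit.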

Comment on lines +446 to +449
gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]  # (V, K)
gDst = cute.local_tile(
    h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0)
)

Contributor

security-high high

Similar to the small batch kernel, the big batch kernel also lacks validation for pool_idx, leading to potential out-of-bounds memory access on the GPU.

Comment on lines +311 to +316
if cutlass.const_expr(use_pool_indexing):
    gDst_tile = cute.local_tile(
        gDst,
        (1, 1, 1, vec_size, 1),
        (0, 0, row + row_offset, lane_id, v_tiles),
    )

Contributor

security-high high

The writeback logic also uses pool_idx (via gDst) without validation, which can lead to out-of-bounds writes on GPU memory.

Comment thread flashinfer/gdn_kernels/gdn_decode_pretranspose.py
Comment thread flashinfer/gdn_kernels/gdn_decode_pretranspose.py

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/gdn/test_decode_pretranspose_noncontiguous_pool.py (1)

28-33: Use the repo’s standard SM90+ gate here.

cc[0] not in [9, 10, 11, 12] will skip valid future architectures, and the torch.cuda.is_available() branch is inconsistent with the rest of tests/.

♻️ Suggested cleanup
 def _skip_if_not_sm90_or_later() -> None:
-    if not torch.cuda.is_available():
-        pytest.skip("CUDA is required")
     cc = get_compute_capability(torch.device("cuda"))
-    if cc[0] not in [9, 10, 11, 12]:
-        pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}")
+    if cc[0] < 9:
+        pytest.skip(f"GDN decode requires SM90+ or later, but got SM{cc[0]}{cc[1]}")
Based on learnings, tests in the repository assume CUDA is available and do not require `torch.cuda.is_available()` guards in pytest fixtures; as per coding guidelines, use `flashinfer.utils` functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures.
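
The difference between the two gates can be seen with a small pure-Python comparison. The function names are hypothetical, and major version 13 is an invented future architecture used only for illustration.

```python
def skipped_by_allowlist(major):
    """Original gate: skip unless the major version is explicitly listed."""
    return major not in [9, 10, 11, 12]


def skipped_by_lower_bound(major):
    """Suggested gate: skip only architectures older than SM90."""
    return major < 9


# (major, skipped by allowlist, skipped by lower bound)
rows = [(m, skipped_by_allowlist(m), skipped_by_lower_bound(m)) for m in (8, 9, 12, 13)]
for row in rows:
    print(row)
```

Both gates agree on SM8x through SM12x; they diverge only for majors above 12, where the allowlist would silently skip the test.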
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_pretranspose_noncontiguous_pool.py` around lines 28 -
33, Replace the custom CUDA gating in _skip_if_not_sm90_or_later: remove the
torch.cuda.is_available() check and the manual cc[0] membership test, and
instead call the repo utilities is_sm90a_supported() and is_sm100a_supported()
(and/or get_compute_capability() if you need details) from flashinfer.utils to
decide skipping; specifically, in function _skip_if_not_sm90_or_later use if not
(is_sm90a_supported() or is_sm100a_supported()): pytest.skip(...) so the test
follows the repo standard GPU support check.
flashinfer/gdn_kernels/gdn_decode_pretranspose.py (1)

705-706: Remove the no-op size expressions.

v_dim * k_dim * 4 / 1024 / 1024 reads like leftover debug code and just adds noise in the launcher path.

🧹 Suggested cleanup
     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * 4 / 1024 / 1024

Also applies to: 811-812

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 705 - 706,
Remove the stray no-op size expressions that were left as debug noise (e.g., the
standalone expression "v_dim * k_dim * 4 / 1024 / 1024") — locate the blocks
around the computation of num_v_tiles (uses cute.ceil_div and TILE_V) and the
similar occurrence further down (the second instance referenced around lines
811-812) and delete those unused arithmetic expressions so only meaningful
assignments and operations (like num_v_tiles = cute.ceil_div(v_dim, TILE_V))
remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py`:
- Around line 929-950: The cache key construction for the pooled path (variable
cache_key) omits the tensor storage offset, so different views with identical
shape/strides but different h0_source.storage_offset() can incorrectly reuse a
specialization that assumed 16-byte alignment for from_dlpack(...,
assumed_align=16) and cpasync.CopyG2SOp(); fix by including
h0_source.storage_offset() in the cache_key tuple (alongside
B,T,H,HV,K,V,q.dtype,scale,use_qk_l2norm,use_pool_indexing,pool_size,stride0..3)
or alternatively add an explicit alignment check when use_pool_indexing is true
that rejects/raises on misaligned h0_source.storage_offset() before compilation.

📥 Commits

Reviewing files that changed from the base of the PR and between bcdf8d8 and 573cea6.

📒 Files selected for processing (3)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_pretranspose_noncontiguous_pool.py

Comment on lines +929 to +950
if use_pool_indexing:
    pool_size = int(h0_source.shape[0])
    stride0, stride1, stride2, stride3 = tuple(int(x) for x in h0_source.stride())
else:
    pool_size = stride0 = stride1 = stride2 = stride3 = 0
cache_key = (
    B,
    T,
    H,
    HV,
    K,
    V,
    q.dtype,
    scale,
    use_qk_l2norm,
    use_pool_indexing,
    pool_size,
    stride0,
    stride1,
    stride2,
    stride3,
)

Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In CUDA/CUTLASS CuTe DSL, what alignment guarantees are required for cpasync.CopyG2SOp(..., num_bits_per_copy=128) together with from_dlpack(..., assumed_align=16)? Are tensor views with stride(-1) == 1 but non-zero or misaligned storage_offset() safe to use with that contract?

💡 Result:

For cpasync.CopyG2SOp(..., num_bits_per_copy=128), the generated instruction is a 16-byte (128-bit) cp.async-style vector copy per lane, which relies on both the global-memory source address and the shared-memory destination address being at least 16-byte aligned for the accessed vector; and the transfer size is effectively in 16-byte quanta. This matches CUDA’s cp.async / memcpy_async alignment contract: if you want it implemented purely with async copies of size 16, both pointers must be 16-byte aligned (and the moved size must be compatible with that granularity). [1]

from_dlpack(..., assumed_align=16) in CuTe is exactly such a promise: it tells CuTe/CUTLASS it may assume the runtime pointer is 16-byte aligned (and thus it may select 16B-vectorized paths like the 128-bit async-copy atom) without inserting runtime alignment checks. [2] If that assumption is false, you can hit alignment verification failures or miscompiled/undefined behavior, as reported in CUTLASS CuTe DSL issues when assumed_align=16 is used but the effective pointer isn’t actually 16B aligned. [3]

So: a tensor view with stride(-1) == 1 is not automatically safe under this contract. It is safe only if the effective base address that CuTe will use (i.e., data_ptr + storage_offset * sizeof(dtype)) is 16-byte aligned, and the per-thread vectorized addressing pattern preserves that 16B alignment. If storage_offset() is non-zero and makes the effective base misaligned (e.g., for fp16, an odd storage_offset() makes the byte offset 2 mod 16), then it violates assumed_align=16 and is not safe for num_bits_per_copy=128. [1][2][3]

Sources: CUDA async copy alignment requirements [1], CUTLASS CuTe from_dlpack(assumed_align=...) API doc [2], CUTLASS issue demonstrating failure when assumed_align=16 doesn’t hold for a CuTe async-copy path [3].


🏁 Script executed:

# First, examine the file at the specific lines mentioned
cat -n flashinfer/gdn_kernels/gdn_decode_pretranspose.py | sed -n '920,960p'

Repository: flashinfer-ai/flashinfer

Length of output: 1804


🏁 Script executed:

# Also check lines 968-969 mentioned in "Also applies to"
cat -n flashinfer/gdn_kernels/gdn_decode_pretranspose.py | sed -n '960,975p'

Repository: flashinfer-ai/flashinfer

Length of output: 911


🏁 Script executed:

# Search for assumed_align usage in the file
rg "assumed_align" flashinfer/gdn_kernels/gdn_decode_pretranspose.py -n

Repository: flashinfer-ai/flashinfer

Length of output: 760


🏁 Script executed:

# Look for where use_pool_indexing is handled and any alignment checks
rg "use_pool_indexing" flashinfer/gdn_kernels/gdn_decode_pretranspose.py -B 2 -A 5 -n

Repository: flashinfer-ai/flashinfer

Length of output: 6504


🏁 Script executed:

# Check the test file to see what it covers
cat -n flashinfer/tests/gdn/test_decode_pretranspose_noncontiguous_pool.py | sed -n '50,80p'

Repository: flashinfer-ai/flashinfer

Length of output: 167


🏁 Script executed:

# Find the test file mentioned in the review
fd "test_decode_pretranspose" flashinfer/tests

Repository: flashinfer-ai/flashinfer

Length of output: 171


🏁 Script executed:

# Let me also check if there's any storage_offset being tracked in the cache_key
grep -n "storage_offset\|offset_" flashinfer/gdn_kernels/gdn_decode_pretranspose.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are any alignment checks before kernel compilation
grep -n "align" flashinfer/gdn_kernels/gdn_decode_pretranspose.py -i

Repository: flashinfer-ai/flashinfer

Length of output: 760


🏁 Script executed:

# Look at the actual kernel code where h0_source is used with pooled indexing
cat -n flashinfer/gdn_kernels/gdn_decode_pretranspose.py | sed -n '140,150p'

Repository: flashinfer-ai/flashinfer

Length of output: 638


🏁 Script executed:

# Search for CuTe async copy usage or cp.async patterns
rg "cp\.async|CopyG2SOp|cpasync" flashinfer/

Repository: flashinfer-ai/flashinfer

Length of output: 23222


Cache key missing storage_offset allows misaligned pooled views to reuse wrong specialization.

The pooled path compiles with from_dlpack(..., assumed_align=16) and cpasync.CopyG2SOp() (line 969), which requires the effective base address (data_ptr + storage_offset * sizeof(dtype)) to be 16-byte aligned. However, the cache key (lines 929–950) includes shape and strides but not storage_offset().

Two tensor views can have identical shape and strides but different storage_offset()—for example, buf[..., :128] (offset 0) and buf[..., 1:129] (offset 1). If the first is page-aligned and cached, the second would incorrectly reuse that specialization despite being misaligned (e.g., byte offset 2 mod 16 for fp16), violating the assumed_align=16 contract.

Fix by either including storage_offset() in the cache key or rejecting misaligned pooled views before compilation.
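
The second option, rejecting misaligned views, could look roughly like the sketch below. It works on plain integers rather than real tensors; `view_is_aligned`, the `ITEMSIZE` table, and the base address are hypothetical.

```python
# Bytes per element for a few dtypes (illustrative subset).
ITEMSIZE = {"float16": 2, "bfloat16": 2, "float32": 4}


def view_is_aligned(base_ptr, storage_offset, dtype, align=16):
    """True if the first element of the view sits on an `align`-byte boundary.

    base_ptr is the (hypothetical) address of the underlying storage and
    storage_offset is counted in elements, as in the review comment.
    """
    return (base_ptr + storage_offset * ITEMSIZE[dtype]) % align == 0


BASE = 0x7F0000000000  # assumed 16-byte-aligned allocation

aligned_view = view_is_aligned(BASE, 0, "float16")   # like buf[..., :128]
shifted_view = view_is_aligned(BASE, 1, "float16")   # like buf[..., 1:129]
realigned = view_is_aligned(BASE, 8, "float16")      # 8 fp16 elements = 16 bytes
```

In PyTorch, `Tensor.data_ptr()` already returns the address of the view's first element, so a real check could likely test that value modulo 16 directly instead of recomputing it from the storage offset.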

@vadiklyutiy

Contributor

Doesn't it impact performance?

@ZJY0516

ZJY0516 commented Mar 11, 2026

Contributor Author

Doesn't it impact performance?

Good question, will test it later

@vadiklyutiy

Contributor

@ameynaik-hub

@ameynaik-hub

Contributor

Does vLLM also call the MTP path with non-contiguous pool tensors? If so, does MTP need the same fix?

ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 11, 2026
…transpose

Add non-contiguous pool test (page-strided tensors) to
test_decode_delta_rule.py and review notes with correctness
results (33/33 passed on B200) and benchmark data (no regression,
1.41x vs Triton).

AI-assisted review with Claude Code.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ZJY0516

ZJY0516 commented Mar 12, 2026

Contributor Author

Does vLLM also call the MTP path with non-contiguous pool tensors? If so, does MTP need the same fix?

ssm_state is non-contiguous as a whole tensor, but each per-slot slice ssm_state[i, :, :, :] is still contiguous.
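
The distinction can be checked from shapes and strides alone. The following is a small sketch with assumed sizes and an assumed per-slot gap, not vLLM's actual layout; `contiguous_strides` and `is_contiguous` are hypothetical helpers.

```python
def contiguous_strides(shape):
    """Row-major element strides for a dense tensor of this shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Dense row-major check; strides of size-1 dims are ignored."""
    expected = contiguous_strides(shape)
    return all(d == 1 or s == e for d, s, e in zip(shape, strides, expected))


# Assumed pool: 4 slots of [HV=2, V=4, K=8] with 16 padding elements per slot.
pool_shape, pool_stride = (4, 2, 4, 8), (80, 32, 8, 1)

whole = is_contiguous(pool_shape, pool_stride)            # ssm_state as a whole
per_slot = is_contiguous(pool_shape[1:], pool_stride[1:])  # ssm_state[i, :, :, :]
```

Indexing a single slot drops the first dimension together with its padded stride, so the remaining [HV, V, K] view is dense even though the pool is not.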

@vadiklyutiy

Contributor

Could you please kindly review this PR? It addresses an important issue that was missed in the GDN decode path.

@ameynaik-hub

Contributor

stride[0] always constant across all pool slots?

@ZJY0516

ZJY0516 commented Mar 17, 2026

Contributor Author

stride[0] always constant across all pool slots?

yes

@ameynaik-hub

Contributor

@kahyunnam can you please help merge this PR; seems like it is critical for vllm integration. cc: @vadiklyutiy

@vadiklyutiy

Contributor

@kahyunnam can you please help merge this PR; seems like it is critical for vllm integration. cc: @vadiklyutiy

Yes, the current behavior really does not allow using FI's GDN decode.

@bkryu bkryu added the run-ci label Mar 19, 2026
@ZJY0516

ZJY0516 commented Mar 24, 2026

Contributor Author

Doesn't it impact performance?

I have verified this; there is no regression.

case | batch | main (us) | branch (us) | difference
direct_state | 8 | 22.66 | 23.22 | +2.47%
direct_state | 32 | 48.35 | 48.98 | +1.29%
direct_state | 128 | 148.72 | 148.93 | +0.14%
pool_contig | 8 | 24.13 | 24.10 | -0.13%
pool_contig | 32 | 49.73 | 49.70 | -0.06%
pool_contig | 128 | 151.87 | 151.90 | +0.02%

@kahyunnam

Member

/bot run

@flashinfer-bot

Collaborator

GitLab MR !456 has been created, and the CI pipeline #46893983 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

Collaborator

[FAILED] Pipeline #46893983: 10/20 passed

@ameynaik-hub

Contributor

failures seem unrelated to this PR? cc: @ZJY0516 @kahyunnam

@ZJY0516

ZJY0516 commented Mar 25, 2026

Contributor Author

failures seem unrelated to this PR? cc: @ZJY0516 @kahyunnam

I think it's an environment issue?

failed to extract layer (application/vnd.docker.image.rootfs.diff.tar.gzip sha256:bd247c667ef3aae56ea1a803ea2455ab7519ca65a547927be24508309f2ce4d0) to overlayfs as "extract-786880983-Xb5r sha256:b81ebc12800474b10f90d1a02b04ef08b3c01d479a75c1fcfdf504c850693fd9": Canceled: grpc: the client connection is closing: context canceled
error during connect: Get "http://%2Fvar%2Frun%2Fdocker.sock/_ping": read unix @->/run/docker.sock: read: connection reset by peer

@kahyunnam kahyunnam merged commit d505e4e into flashinfer-ai:main Mar 25, 2026
31 of 45 checks passed
@ZJY0516 ZJY0516 deleted the gdn_decode_non_con branch March 25, 2026 16:24
@coderabbitai coderabbitai Bot mentioned this pull request May 8, 2026
5 tasks
7 participants