[gdn] support non-contiguous state for decoding#2727
Conversation
Summary of Changes
This pull request introduces support for non-contiguous state tensors within the gated delta rule decode pretranspose operation. By adapting the Python wrapper and CUDA kernels, the decode path can now process state data with flexible memory layouts, such as page-strided pools, without requiring the state to be made contiguous first. K-contiguity is still required for performance, while the other dimensions may be non-contiguous.
📝 Walkthrough
This pull request adds pool-indexing support to the GDN pretranspose decode path. It introduces a K-contiguity assertion, relaxes pool-path state handling to preserve non-contiguous layouts, and extends kernel compilation caching with pool metadata parameters (pool_size, stride0-3). A new test validates non-contiguous pool handling.
Sequence Diagram(s)
sequenceDiagram
participant Host as Host (Python)
participant Cache as Kernel Cache
participant Compiler as CUDA Compiler
participant Kernel as Pretranspose Kernel
participant State as State Tensor
rect rgba(100, 150, 200, 0.5)
Note over Host: Prepare h0_source with pool indexing enabled
Host->>Host: Validate initial_state K-contiguity
Host->>Host: Set h0_source = initial_state<br/>(preserve non-contiguous layout)
end
rect rgba(150, 100, 200, 0.5)
Note over Host,Compiler: Kernel Compilation with Pool Metadata
Host->>Host: Extract pool_size, stride0-3 from h0_source
Host->>Host: Build cache_key with pool parameters
Host->>Cache: Check cache for kernel
alt Cache Hit
Cache-->>Kernel: Return compiled kernel
else Cache Miss
Host->>Compiler: Compile kernel with pool metadata
Compiler-->>Cache: Store compiled kernel
end
end
rect rgba(200, 150, 100, 0.5)
Note over Host,Kernel: Kernel Execution with Pool-based Indexing
Host->>Kernel: Launch with pool_idx, i_hv indexing
Kernel->>State: Read gSrc_batch [pool_size, HV, V, K]
Kernel->>Kernel: Process with pool-aware state slicing
Kernel->>State: Write gDst [pool_size, HV, V, K]
Kernel-->>Host: Return output & updated_state
end
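The compile-or-reuse flow in the diagram can be sketched in plain Python. This is a minimal illustration rather than FlashInfer's implementation: `build_cache_key`, `get_or_compile`, and the `_kernel_cache` dict are hypothetical names, and `compile_fn` stands in for the real compilation step.

```python
from typing import Callable, Dict, Hashable, Sequence, Tuple

# Hypothetical stand-in for the compiled-kernel cache (not FlashInfer's actual code).
_kernel_cache: Dict[Tuple[Hashable, ...], object] = {}


def build_cache_key(
    problem_dims: Tuple[Hashable, ...],
    use_pool_indexing: bool,
    shape: Sequence[int],
    strides: Sequence[int],
) -> Tuple[Hashable, ...]:
    """Append pool metadata to the key only when pool indexing is enabled."""
    if use_pool_indexing:
        pool_size = int(shape[0])
        pool_strides = tuple(int(s) for s in strides)
    else:
        # Non-pool path: zero out the metadata so all such calls share one key.
        pool_size = 0
        pool_strides = (0, 0, 0, 0)
    return (*problem_dims, use_pool_indexing, pool_size, *pool_strides)


def get_or_compile(key: Tuple[Hashable, ...], compile_fn: Callable[[], object]) -> object:
    """Return the cached kernel for `key`, compiling it on a cache miss."""
    kernel = _kernel_cache.get(key)
    if kernel is None:
        kernel = compile_fn()
        _kernel_cache[key] = kernel
    return kernel
```

Because the strides are part of the key, two pools with the same shape but different page strides map to different cache entries in this sketch.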
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Code Review
This pull request introduces support for non-contiguous state tensors in the Gated Delta Rule (GDN) decode functionality by updating the Python wrapper and modifying the CuTe DSL kernels to handle 4D non-contiguous memory layout. However, a high-severity security issue has been identified: the kernels lack validation for user-provided pool indices, which can lead to out-of-bounds memory access on the GPU in multi-tenant environments. It is recommended to add a check in the Python API layer to validate that all indices in initial_state_indices are within the valid range of the state pool. Additionally, a couple of instances of dead code were found that should be removed to improve code clarity and maintainability.
    gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]  # (V, K)
    gDst = cute.local_tile(
        h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0)
    )
The kernel uses pool_idx (derived from h0_indices) to index into h0_source without validating that the index is within the bounds of the first dimension of the tensor. An attacker providing malicious initial_state_indices could cause out-of-bounds reads and writes on GPU memory. It is recommended to add a validation check in the Python wrapper to ensure all indices are within the range [-1, pool_size).
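A host-side check of the kind recommended above might look like the following. This is a minimal sketch using plain Python integers: `validate_pool_indices` is a hypothetical helper, not part of the FlashInfer API, and the accepted range `[-1, pool_size)` is taken directly from the review comment.

```python
from typing import Iterable


def validate_pool_indices(indices: Iterable[int], pool_size: int) -> None:
    """Raise IndexError if any index falls outside [-1, pool_size).

    Hypothetical helper: the range follows the review comment, which admits -1
    in addition to the valid slot indices 0 .. pool_size - 1.
    """
    for position, idx in enumerate(indices):
        if not (-1 <= idx < pool_size):
            raise IndexError(
                f"initial_state_indices[{position}] = {idx} is outside "
                f"[-1, {pool_size})"
            )
```

On real tensors the same condition could be evaluated on the device in one reduction instead of a Python loop; the loop is used here only to keep the sketch dependency-free.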
    gSrc_batch = h0_source[(pool_idx, i_hv, None, None)]  # (V, K)
    gDst = cute.local_tile(
        h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0)
    )
    if cutlass.const_expr(use_pool_indexing):
        gDst_tile = cute.local_tile(
            gDst,
            (1, 1, 1, vec_size, 1),
            (0, 0, row + row_offset, lane_id, v_tiles),
        )
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/gdn/test_decode_pretranspose_noncontiguous_pool.py (1)
28-33: Use the repo's standard SM90+ gate here.
`cc[0] not in [9, 10, 11, 12]` will skip valid future architectures, and the `torch.cuda.is_available()` branch is inconsistent with the rest of `tests/`.
Based on learnings, tests in the repository assume CUDA is available and do not require `torch.cuda.is_available()` guards in pytest fixtures; as per coding guidelines, use `flashinfer.utils` functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures.
♻️ Suggested cleanup
     def _skip_if_not_sm90_or_later() -> None:
    -    if not torch.cuda.is_available():
    -        pytest.skip("CUDA is required")
         cc = get_compute_capability(torch.device("cuda"))
    -    if cc[0] not in [9, 10, 11, 12]:
    -        pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}")
    +    if cc[0] < 9:
    +        pytest.skip(f"GDN decode requires SM90+ or later, but got SM{cc[0]}{cc[1]}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_pretranspose_noncontiguous_pool.py` around lines 28 - 33, Replace the custom CUDA gating in _skip_if_not_sm90_or_later: remove the torch.cuda.is_available() check and the manual cc[0] membership test, and instead call the repo utilities is_sm90a_supported() and is_sm100a_supported() (and/or get_compute_capability() if you need details) from flashinfer.utils to decide skipping; specifically, in function _skip_if_not_sm90_or_later use if not (is_sm90a_supported() or is_sm100a_supported()): pytest.skip(...) so the test follows the repo standard GPU support check.
flashinfer/gdn_kernels/gdn_decode_pretranspose.py (1)
705-706: Remove the no-op size expressions.
`v_dim * k_dim * 4 / 1024 / 1024` reads like leftover debug code and just adds noise in the launcher path.
🧹 Suggested cleanup
     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
    -v_dim * k_dim * 4 / 1024 / 1024
Also applies to: 811-812
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 705 - 706, Remove the stray no-op size expressions that were left as debug noise (e.g., the standalone expression "v_dim * k_dim * 4 / 1024 / 1024") — locate the blocks around the computation of num_v_tiles (uses cute.ceil_div and TILE_V) and the similar occurrence further down (the second instance referenced around lines 811-812) and delete those unused arithmetic expressions so only meaningful assignments and operations (like num_v_tiles = cute.ceil_div(v_dim, TILE_V)) remain.
📒 Files selected for processing (3)
flashinfer/gdn_decode.py
flashinfer/gdn_kernels/gdn_decode_pretranspose.py
tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
    if use_pool_indexing:
        pool_size = int(h0_source.shape[0])
        stride0, stride1, stride2, stride3 = tuple(int(x) for x in h0_source.stride())
    else:
        pool_size = stride0 = stride1 = stride2 = stride3 = 0
    cache_key = (
        B,
        T,
        H,
        HV,
        K,
        V,
        q.dtype,
        scale,
        use_qk_l2norm,
        use_pool_indexing,
        pool_size,
        stride0,
        stride1,
        stride2,
        stride3,
    )
🧩 Analysis chain
🌐 Web query:
In CUDA/CUTLASS CuTe DSL, what alignment guarantees are required for `cpasync.CopyG2SOp(..., num_bits_per_copy=128)` together with `from_dlpack(..., assumed_align=16)`? Are tensor views with `stride(-1) == 1` but non-zero or misaligned `storage_offset()` safe to use with that contract?
💡 Result:
For cpasync.CopyG2SOp(..., num_bits_per_copy=128), the generated instruction is a 16-byte (128-bit) cp.async-style vector copy per lane, which relies on both the global-memory source address and the shared-memory destination address being at least 16-byte aligned for the accessed vector; the transfer size is effectively in 16-byte quanta. This matches CUDA's cp.async / memcpy_async alignment contract: for the copy to be implemented purely with 16-byte async copies, both pointers must be 16-byte aligned (and the moved size must be compatible with that granularity). [1]
from_dlpack(..., assumed_align=16) in CuTe is exactly such a promise: it tells CuTe/CUTLASS it may assume the runtime pointer is 16-byte aligned (and thus it may select 16B-vectorized paths like the 128-bit async-copy atom) without inserting runtime alignment checks. [2] If that assumption is false, you can hit alignment verification failures or miscompiled/undefined behavior, as reported in CUTLASS CuTe DSL issues when assumed_align=16 is used but the effective pointer isn’t actually 16B aligned. [3]
So: a tensor view with stride(-1) == 1 is not automatically safe under this contract. It is safe only if the effective base address that CuTe will use (i.e., data_ptr + storage_offset * sizeof(dtype)) is 16-byte aligned, and the per-thread vectorized addressing pattern preserves that 16B alignment. If storage_offset() is non-zero and makes the effective base misaligned (e.g., for fp16, an odd storage_offset() makes the byte offset 2 mod 16), then it violates assumed_align=16 and is not safe for num_bits_per_copy=128. [1][2][3]
Sources: CUDA async copy alignment requirements [1], CUTLASS CuTe from_dlpack(assumed_align=...) API doc [2], CUTLASS issue demonstrating failure when assumed_align=16 doesn’t hold for a CuTe async-copy path [3].
🏁 Script executed:
    # First, examine the file at the specific lines mentioned
    cat -n flashinfer/gdn_kernels/gdn_decode_pretranspose.py | sed -n '920,960p'
Repository: flashinfer-ai/flashinfer
Length of output: 1804
🏁 Script executed:
    # Also check lines 968-969 mentioned in "Also applies to"
    cat -n flashinfer/gdn_kernels/gdn_decode_pretranspose.py | sed -n '960,975p'
Repository: flashinfer-ai/flashinfer
Length of output: 911
🏁 Script executed:
    # Search for assumed_align usage in the file
    rg "assumed_align" flashinfer/gdn_kernels/gdn_decode_pretranspose.py -n
Repository: flashinfer-ai/flashinfer
Length of output: 760
🏁 Script executed:
    # Look for where use_pool_indexing is handled and any alignment checks
    rg "use_pool_indexing" flashinfer/gdn_kernels/gdn_decode_pretranspose.py -B 2 -A 5 -n
Repository: flashinfer-ai/flashinfer
Length of output: 6504
🏁 Script executed:
    # Check the test file to see what it covers
    cat -n flashinfer/tests/gdn/test_decode_pretranspose_noncontiguous_pool.py | sed -n '50,80p'
Repository: flashinfer-ai/flashinfer
Length of output: 167
🏁 Script executed:
    # Find the test file mentioned in the review
    fd "test_decode_pretranspose" flashinfer/tests
Repository: flashinfer-ai/flashinfer
Length of output: 171
🏁 Script executed:
    # Let me also check if there's any storage_offset being tracked in the cache_key
    grep -n "storage_offset\|offset_" flashinfer/gdn_kernels/gdn_decode_pretranspose.py
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
    # Check if there are any alignment checks before kernel compilation
    grep -n "align" flashinfer/gdn_kernels/gdn_decode_pretranspose.py -i
Repository: flashinfer-ai/flashinfer
Length of output: 760
🏁 Script executed:
    # Look at the actual kernel code where h0_source is used with pooled indexing
    cat -n flashinfer/gdn_kernels/gdn_decode_pretranspose.py | sed -n '140,150p'
Repository: flashinfer-ai/flashinfer
Length of output: 638
🏁 Script executed:
    # Search for CuTe async copy usage or cp.async patterns
    rg "cp\.async|CopyG2SOp|cpasync" flashinfer/
Repository: flashinfer-ai/flashinfer
Length of output: 23222
Cache key missing storage_offset allows misaligned pooled views to reuse wrong specialization.
The pooled path compiles with from_dlpack(..., assumed_align=16) and cpasync.CopyG2SOp() (line 969), which requires the effective base address (data_ptr + storage_offset * sizeof(dtype)) to be 16-byte aligned. However, the cache key (lines 929–950) includes shape and strides but not storage_offset().
Two tensor views can have identical shape and strides but different storage_offset()—for example, buf[..., :128] (offset 0) and buf[..., 1:129] (offset 1). If the first is page-aligned and cached, the second would incorrectly reuse that specialization despite being misaligned (e.g., byte offset 2 mod 16 for fp16), violating the assumed_align=16 contract.
Fix by either including storage_offset() in the cache key or rejecting misaligned pooled views before compilation.
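The second option, rejecting misaligned pooled views before compilation, reduces to simple address arithmetic. The sketch below works on plain integers so it stays self-contained: `is_view_aligned` and `check_pooled_view_alignment` are hypothetical helpers, and `storage_base_ptr` is assumed to be the address of the underlying storage, before the view's element offset is applied.

```python
def is_view_aligned(
    storage_base_ptr: int,
    storage_offset: int,
    itemsize: int,
    alignment: int = 16,
) -> bool:
    """True if the view's first element sits on an `alignment`-byte boundary.

    storage_base_ptr: assumed address of the underlying storage, in bytes.
    storage_offset:   the view's offset into that storage, in elements.
    itemsize:         bytes per element (2 for fp16/bf16, 4 for fp32).
    """
    return (storage_base_ptr + storage_offset * itemsize) % alignment == 0


def check_pooled_view_alignment(
    storage_base_ptr: int, storage_offset: int, itemsize: int
) -> None:
    """Hypothetical guard: raise before compilation if the 16-byte contract fails."""
    if not is_view_aligned(storage_base_ptr, storage_offset, itemsize):
        raise ValueError(
            "pooled state view is not 16-byte aligned: "
            f"storage_offset={storage_offset}, itemsize={itemsize}"
        )
```

With a 16-byte-aligned base and fp16 elements, an element offset of 1 lands 2 bytes past a boundary and is rejected, while an offset of 8 lands exactly 16 bytes in and is accepted.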
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 929 - 950,
The cache key construction for the pooled path (variable cache_key) omits the
tensor storage offset, so different views with identical shape/strides but
different h0_source.storage_offset() can incorrectly reuse a specialization that
assumed 16-byte alignment for from_dlpack(..., assumed_align=16) and
cpasync.CopyG2SOp(); fix by including h0_source.storage_offset() in the
cache_key tuple (alongside
B,T,H,HV,K,V,q.dtype,scale,use_qk_l2norm,use_pool_indexing,pool_size,stride0..3)
or alternatively add an explicit alignment check when use_pool_indexing is true
that rejects/raises on misaligned h0_source.storage_offset() before compilation.
Doesn't this impact performance?
Good question, I will test it later.
Does vLLM also call the MTP path with non-contiguous pool tensors? If yes, does MTP need the same fix?
…transpose Add non-contiguous pool test (page-strided tensors) to test_decode_delta_rule.py and review notes with correctness results (33/33 passed on B200) and benchmark data (no regression, 1.41x vs Triton). AI-assisted review with Claude Code. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
ssm_state is non-contiguous as a whole tensor, but each per-slot slice ssm_state[i, :, :, :] is still contiguous.
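This distinction can be checked from shape and strides alone. The sketch below uses a made-up page-strided layout (4 slots of shape [2, 8, 8] with a slot stride of 256 elements instead of the dense 128); `is_contiguous` is a hypothetical helper that mirrors the usual row-major contiguity rule, and the numbers are illustrative, not taken from vLLM.

```python
from typing import Sequence


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """Row-major contiguity: each stride equals the product of the inner sizes."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        # Dimensions of size 1 never constrain the stride.
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


# Illustrative page-strided pool: [pool_size, HV, V, K] in elements.
pool_shape = (4, 2, 8, 8)
pool_strides = (256, 64, 8, 1)  # dense layout would have a slot stride of 128

whole_pool_contiguous = is_contiguous(pool_shape, pool_strides)
per_slot_contiguous = is_contiguous(pool_shape[1:], pool_strides[1:])
k_contiguous = pool_strides[-1] == 1
```

Here only the outermost stride breaks contiguity, so indexing a single slot yields a dense [HV, V, K] block whose last dimension has unit stride.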
Could you please kindly review this PR? It addresses an important issue that was missed in the GDN decode path.
Is stride[0] always constant across all pool slots?
yes
@kahyunnam can you please help merge this PR? It seems critical for vLLM integration. cc: @vadiklyutiy
Yes, the current behavior really doesn't allow using FI's GDN decode.
Have verified, no regression
/bot run
[FAILED] Pipeline #46893983: 10/20 passed
failures seem unrelated to this PR? cc: @ZJY0516 @kahyunnam
I think it's an environment issue?
📌 Description
vLLM uses a non-contiguous state for GDN. This PR makes FlashInfer support it as well.
🔍 Related Issues
#2521
#2687
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
… `pre-commit` by running `pip install pre-commit` (or used your preferred method).
… `pre-commit install`.
… `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
… `unittest`, etc.).
Reviewer Notes