feat(gdn): add BF16 state kernel with MTP support beyond T>4 with intermediate caching. #2679
📝 Walkthrough
This PR splits and renames the BF16-state GDN decode backend into two variants (single-token BF16 state and BF16-state MTP for multi-token), updates kernel exports/availability flags, changes runtime dispatch to choose T==1 vs T>1 kernels, and adjusts benchmarks/tests/CLI to match the new APIs.
Sequence Diagram(s)
sequenceDiagram
participant Client as Decode API
participant Dispatch as Kernel Dispatch
participant BF16 as BF16 State (T=1)
participant MTP as BF16 State MTP (T>1)
Client->>Dispatch: decode(T, dtype=bf16, args...)
activate Dispatch
alt T == 1
Dispatch->>BF16: call _gated_delta_rule_bf16_state(...)
activate BF16
BF16-->>Dispatch: result
deactivate BF16
else T > 1
Dispatch->>MTP: call _gated_delta_rule_bf16_state_mtp(...)
activate MTP
MTP-->>Dispatch: result
deactivate MTP
end
deactivate Dispatch
Dispatch-->>Client: decoded output
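The routing in the diagram can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the FlashInfer implementation: `bf16_state_t1`, `bf16_state_mtp`, and `dispatch_decode` are hypothetical stand-ins, and nested lists replace the real tensors so the logic runs without a GPU.

```python
from typing import Callable, List

Batch = List[List[float]]  # stand-in for a [batch, T] tensor


def bf16_state_t1(q: Batch) -> str:
    """Hypothetical stand-in for the single-token (T == 1) kernel."""
    return "bf16_state(T=1)"


def bf16_state_mtp(q: Batch) -> str:
    """Hypothetical stand-in for the multi-token (T > 1) kernel."""
    return f"bf16_state_mtp(T={len(q[0])})"


def dispatch_decode(q: Batch) -> str:
    """Route on sequence length T, taken here as the size of dimension 1."""
    T = len(q[0])
    if T < 1:
        raise ValueError("T must be at least 1")
    kernel: Callable[[Batch], str] = bf16_state_t1 if T == 1 else bf16_state_mtp
    return kernel(q)


single = dispatch_decode([[0.0]])            # batch of 1, T == 1
multi = dispatch_decode([[0.0, 0.0, 0.0]])   # batch of 1, T == 3
```

Selecting the callable first and calling it once keeps the argument list in a single place, so the two paths cannot drift apart.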
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes
This pull request significantly enhances the GDN decode functionality by integrating a new, highly optimized BF16 state kernel. This kernel, built with CuTe DSL, provides substantial performance gains across various batch sizes and sequence lengths, particularly for multi-token prediction. The changes involve refactoring existing kernel implementations, updating benchmarking infrastructure, and expanding test coverage to ensure correctness and efficiency of the new BF16 state processing.
Code Review
This pull request introduces a high-performance BF16 state kernel for GDN decode, supporting both single-token (T=1) and multi-token prediction (MTP), which demonstrates significant performance improvements. The changes are well-structured, including updates to benchmarks and the addition of comprehensive tests.
My review has identified a bug in one of the new tests that needs to be addressed to ensure correctness. Additionally, I've provided a couple of suggestions to refactor duplicated code in both the benchmark and test files, which will improve the overall maintainability of the codebase.
    if T == 1:
        return gdn_decode_bf16_state(
            A_log=A_log,
            a=a,
            dt_bias=dt_bias,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
            q=q,
            k=k,
            v=v,
            b=b,
            initial_state_source=state,
            use_qk_l2norm_in_kernel=use_qk_l2norm,
            scale=scale,
        )
    else:
        return gdn_decode_bf16_state_mtp(
            A_log=A_log,
            a=a,
            dt_bias=dt_bias,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
            q=q,
            k=k,
            v=v,
            b=b,
            initial_state_source=state,
            use_qk_l2norm_in_kernel=use_qk_l2norm,
            scale=scale,
        )
The calls to gdn_decode_bf16_state and gdn_decode_bf16_state_mtp share the same set of arguments. This code can be refactored to reduce duplication and improve maintainability by selecting the kernel function first and then calling it with a shared set of keyword arguments.
    T = q.shape[1]
    kernel_fn = gdn_decode_bf16_state if T == 1 else gdn_decode_bf16_state_mtp
    return kernel_fn(
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        initial_state_source=state,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        scale=scale,
    )

def _test_gdn_decode_bf16_state_t1_kernel(
    dtype: str,
    batch_size: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    scale: float,
    alpha: bool,
    beta: bool,
    seed: int | None = None,
):
There's significant code duplication for test data generation across _test_gdn_decode_bf16_state_kernel, _test_gdn_decode_bf16_state_t1_kernel, and _test_gdn_decode_bf16_state_mtp_kernel. To improve maintainability, consider refactoring the tensor creation logic into a shared helper function or a pytest fixture. This would make the tests cleaner and easier to manage.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
benchmarks/bench_gdn_decode.py (3)
2788-2799: ⚠️ Potential issue | 🟠 Major
Restore `--compare` routing for decode versions in `main()`.
For non-MTP paths, `args.compare` is currently ignored, so single-layout comparison mode is no longer reachable from the CLI.
Suggested fix
-    if args.version == "mtp":
+    if args.version == "mtp":
         # MTP mode: use comparison or flashinfer-only
         if args.compare:
             run_comparison_benchmark(args, dtype, use_qk_l2norm)
         else:
             run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
     elif args.version == "bf16_state":
         # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
         run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
     else:
-        # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
-        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
+        # Decode mode: honor --compare flag
+        if args.compare:
+            run_comparison_benchmark(args, dtype, use_qk_l2norm)
+        else:
+            run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 2788 - 2799, The non-MTP branches ignore args.compare so the CLI --compare mode is unreachable; update the bf16_state and else branches in main(): for the "bf16_state" branch, if args.compare call run_comparison_benchmark(args, dtype, use_qk_l2norm) else call run_gdn_decode_bf16_state_benchmark(...); for the final else branch, if args.compare call run_comparison_benchmark(...) else call run_all_layouts_benchmark(...). Use the same argument list (args, dtype, use_qk_l2norm) when invoking the functions run_comparison_benchmark, run_gdn_decode_bf16_state_benchmark, and run_all_layouts_benchmark so --compare behaves consistently across versions.
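The routing described in this prompt can be sketched with `argparse` alone. This is an illustration under assumptions: `route` is a hypothetical helper that returns the name of the benchmark entry point instead of calling it, and `"decode"` is an assumed name for the default version, since the comment only names `"mtp"` and `"bf16_state"`.

```python
import argparse
from typing import List


def route(argv: List[str]) -> str:
    """Return the name of the benchmark entry point the CLI flags select."""
    parser = argparse.ArgumentParser()
    # "decode" is an assumed default; only "mtp" and "bf16_state" are
    # named in the review comment.
    parser.add_argument("--version", default="decode")
    parser.add_argument("--compare", action="store_true")
    args = parser.parse_args(argv)

    # Every branch checks --compare first, so the flag is never ignored.
    if args.compare:
        return "run_comparison_benchmark"
    if args.version == "mtp":
        return "run_flashinfer_only_benchmark"
    if args.version == "bf16_state":
        return "run_gdn_decode_bf16_state_benchmark"
    return "run_all_layouts_benchmark"


default_route = route([])
compare_route = route(["--compare"])
bf16_route = route(["--version", "bf16_state"])
bf16_compare_route = route(["--version", "bf16_state", "--compare"])
```

Note that the suggested diff above only changes the final `else` branch, while the prompt also asks the `bf16_state` branch to honor `--compare`; the sketch follows the prompt.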
1845-1891: ⚠️ Potential issue | 🟠 Major
Forward preallocated tensors in BF16 MTP wrapper to avoid benchmark allocations.
The wrapper accepts `output` and ignores it in both T=1 and T>1 paths. For T>1, the MTP kernel supports both `output` and `initial_state_indices` parameters, but the wrapper doesn't forward them. This causes allocations during benchmark timing, skewing measurements.
Add `initial_state_indices` parameter to wrapper signature and forward both parameters to `gdn_decode_bf16_state_mtp()`. Update `bench_gdn_decode_bf16_state()` to create and pass `initial_state_indices` when T>1.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1845 - 1891, The wrapper around gdn_decode_bf16_state currently ignores the preallocated output and doesn't accept or forward initial_state_indices for the multi-timestep path, causing benchmark allocations; update the wrapper signature to add an initial_state_indices: torch.Tensor (or optional) parameter, and when T>1 forward both output and initial_state_indices into gdn_decode_bf16_state_mtp by passing the kernel's output=output and initial_state_indices=initial_state_indices args (keep the T==1 call unchanged), and update bench_gdn_decode_bf16_state() to allocate/create initial_state_indices for the T>1 case and pass it through to the wrapper so no new allocations occur inside the kernel call.
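Why forwarding `output` matters for timing can be shown with a toy kernel that counts its own allocations. Everything here is a hypothetical stand-in (`fake_mtp_kernel`, the two wrappers, plain lists instead of tensors); it only illustrates the pattern, not the real wrapper.

```python
from typing import Dict, List, Optional

stats: Dict[str, int] = {"allocations": 0}  # buffers created inside the kernel


def fake_mtp_kernel(q: List[float], output: Optional[List[float]] = None) -> List[float]:
    """Hypothetical stand-in: writes into `output` if given, else allocates."""
    if output is None:
        stats["allocations"] += 1
        output = [0.0] * len(q)
    for i, x in enumerate(q):
        output[i] = 2.0 * x
    return output


def wrapper_ignoring_output(q: List[float], output: List[float]) -> List[float]:
    return fake_mtp_kernel(q)                 # `output` is accepted but dropped


def wrapper_forwarding_output(q: List[float], output: List[float]) -> List[float]:
    return fake_mtp_kernel(q, output=output)  # reuses the caller's buffer


q = [1.0, 2.0, 3.0]
buf = [0.0] * len(q)

for _ in range(5):                            # stands in for the timed loop
    wrapper_ignoring_output(q, buf)
allocs_ignoring = stats["allocations"]

stats["allocations"] = 0
for _ in range(5):
    result = wrapper_forwarding_output(q, buf)
allocs_forwarding = stats["allocations"]
```

With the forwarding wrapper the timed loop performs no allocation at all, which is the property the review comment asks the benchmark to have.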
2049-2068: ⚠️ Potential issue | 🟠 Major
Use `float32` `dt_bias` for BF16-state benchmark calls.
The `gdn_decode_bf16_state` kernel specification requires `dt_bias` as `[HV] float32`, as confirmed by kernel docstrings and test code comments. These benchmark paths currently create it with `dtype` (typically BF16), which diverges from the intended interface and may benchmark a different numerical path than intended.
Suggested fix
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
Also applies to: 2256-2260
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 2049 - 2068, The benchmark is passing dt_bias with the wrong dtype for the BF16-state kernel; update the dt_bias tensor construction used in the gdn_decode_bf16_state benchmark calls so dt_bias is created as float32 (torch.float32) with shape [num_sab_heads * head_size] or [HV] as required, then pass that float32 dt_bias into gdn_decode_bf16_state_wrapper (and the second BF16-state call around the other occurrence). Locate the dt_bias variable used in the gdn_decode_bf16_state_wrapper invocations and change its dtype to torch.float32 while keeping device and shape identical.
tests/gdn/test_decode_delta_rule.py (1)
824-831: ⚠️ Potential issue | 🟠 Major
Fix kernel dispatch for `seq_len > 1` in pretranspose API test.
The test parametrizes `seq_len` with [1, 2, 3, 4] (line 824), but the direct kernel call at line 902 always invokes `gdn_decode_bf16_state`, which is the single-token (T=1) kernel. For `seq_len > 1`, the test should dispatch to `gdn_decode_bf16_state_mtp` and pass the `initial_state_indices` parameter. Currently this breaks verification of the multi-token path.
Suggested fix
-    # Direct improved kernel
-    out_direct = gdn_decode_bf16_state(
+    # Direct kernel: T=1 uses single-token path, T>1 uses MTP path
+    direct_kernel = (
+        gdn_decode_bf16_state if seq_len == 1 else gdn_decode_bf16_state_mtp
+    )
+    direct_kwargs = dict(
         A_log=A_log,
         a=a,
         dt_bias=dt_bias,
         softplus_beta=1.0,
         softplus_threshold=20.0,
         q=q,
         k=k,
         v=v,
         b=b_tensor,
         initial_state_source=state_direct,
         use_qk_l2norm_in_kernel=True,
         scale=scale,
-    )
+    )
+    if seq_len > 1:
+        direct_kwargs["initial_state_indices"] = torch.arange(
+            batch_size, dtype=torch.int32, device=device
+        )
+    out_direct = direct_kernel(**direct_kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 824 - 831, The test test_pretranspose_api_uses_gdn_decode_bf16_state incorrectly always calls the single-token kernel gdn_decode_bf16_state; update the dispatch so when seq_len > 1 it calls gdn_decode_bf16_state_mtp and supplies the initial_state_indices argument (preserving the existing call for seq_len == 1). Locate the direct kernel invocation around the test body (the call to gdn_decode_bf16_state) and add a conditional: if seq_len == 1 keep the existing call, else call gdn_decode_bf16_state_mtp with the same parameters plus initial_state_indices so the multi-token verification path is exercised.
🧹 Nitpick comments (2)
tests/gdn/test_decode_delta_rule.py (2)
1148-1179: Validate cached intermediate states in the BF16 MTP test path.
When `cache_intermediate_states=True`, the test currently doesn’t verify buffer contents. Adding a reference comparison would protect the new intermediate-caching path from silent regressions.
Also applies to: 1218-1225
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 1148 - 1179, Add an explicit verification of the cached intermediate states when cache_intermediate_states=True: after calling gdn_decode_bf16_state_mtp with intermediate_states_buffer provided, run a reference call that produces the expected intermediate states (e.g., call gdn_decode_bf16_state_mtp or the float32 equivalent with caching disabled/producing a reference buffer) and compare intermediate_states_buffer to that reference using an appropriate numeric tolerance and dtype/device conversion (use torch.testing.assert_allclose or equivalent) so the BF16 MTP test path actually validates buffer contents; apply the same check in the other test location that mirrors this logic (the block around the second call referenced in the comment).
634-637: Remove or use the unused `alpha` helper parameter.
`alpha` is currently unused in the BF16 helper tests, which makes the test surface a bit misleading.
Also applies to: 937-940
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 634 - 637, The helper function in tests/gdn/test_decode_delta_rule.py declares an unused parameter alpha alongside beta and seed; either remove alpha from the helper signature and all its call sites in that file (including the BF16 helper usages) or modify the BF16 helper tests to actually use the alpha parameter, ensuring you update the function signature and every invocation consistently; reference the parameter names alpha, beta, seed when making the change so you catch all occurrences (the same unused-alpha issue also appears later in the file around the BF16 helper usages).
📥 Commits
Reviewing files that changed from the base of the PR and between e08e8f3 and d470d1dd9ff67f16a1898e1699fddd22cdf3ae1e.
📒 Files selected for processing (6)
benchmarks/bench_gdn_decode.py
flashinfer/gdn_decode.py
flashinfer/gdn_kernels/__init__.py
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
results_bf16_optimizations/cooprow_bf16_vs_optimized_fp32_mtp.md
tests/gdn/test_decode_delta_rule.py
Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage. Provides both T=1 (single token) and MTP (multi-token prediction) variants using a cooperative row approach.

Key design:
- Each warp processes one V-row at a time (4 warps = 4 V-rows/iter)
- cp.async pipeline with TILE_V=8 x TILE_K=128 tiles
- H state stored as BF16 in memory, FP32 in registers for compute
- ILP-optimized variant for large batch sizes (BS>=32)

Consolidated from separate cooprow file into canonical gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel. Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1 and MTP (T>1) when state is BF16 and K=V=128.

Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):
- BS=1-2: 1.09-1.35x speedup
- BS=4-16: 1.24-2.21x speedup (biggest gains)
- BS=32-512: 1.62-1.81x steady-state speedup
- Peak: 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
d470d1d to a906f21
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)
1928-1931: ⚠️ Potential issue | 🟠 Major
Use float32 `dt_bias` in BF16-state benchmark paths.
BF16-state tests in this PR consistently feed `dt_bias` as float32, but these benchmark paths generate `dt_bias` with `dtype` (typically bf16/fp16). That can trigger dtype assertions/casts and distort or fail BF16-state benchmarking.
Suggested patch
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")

-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
Also applies to: 2257-2260
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1928 - 1931, The benchmark creates dt_bias with the generic dtype (bf16/fp16) causing dtype mismatches in BF16-state paths; change the dt_bias creation to explicitly use torch.float32 (e.g., dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")) wherever dt_bias is constructed alongside A_log, a, b (the dt_bias variable in the block with A_log, a, b) and apply the same fix to the other identical occurrence later in the file.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Line 2420: The function gated_delta_rule_mtp now sets the parameter
disable_state_update default to False which silently changes behavior; update it
to preserve prior behavior by restoring disable_state_update: bool = True in the
gated_delta_rule_mtp signature (or if the change is intentional, update the
function's docstring and any callers to document and adopt
disable_state_update=False) and ensure the docstring text for
gated_delta_rule_mtp describing the default matches the new default to avoid
mismatch.
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 623-635: The helper function _test_gdn_decode_bf16_state_kernel
declares an unused parameter alpha which triggers ARG001; rename alpha to _alpha
(or remove it) to mark it as intentionally unused and silence the linter, and
apply the same rename to the other BF16-state helper(s) in this file that also
declare an unused alpha parameter so all occurrences are fixed.
📥 Commits
Reviewing files that changed from the base of the PR and between d470d1dd9ff67f16a1898e1699fddd22cdf3ae1e and a906f21.
📒 Files selected for processing (5)
benchmarks/bench_gdn_decode.py
flashinfer/gdn_decode.py
flashinfer/gdn_kernels/__init__.py
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
tests/gdn/test_decode_delta_rule.py
 def _test_gdn_decode_bf16_state_kernel(
     dtype: str,
     batch_size: int,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    seq_len: int,  # T=1,2,3,4
+    seq_len: int,
     scale: float,
     alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):
Resolve unused alpha parameters in BF16-state helper tests.
alpha is not read in these helper bodies, and Ruff is already flagging it (ARG001). Rename to _alpha (or remove) to avoid lint noise/failures.
Suggested patch
 def _test_gdn_decode_bf16_state_kernel(
@@
-    alpha: bool,
+    _alpha: bool,
@@
 def _test_gdn_decode_bf16_state_t1_kernel(
@@
-    alpha: bool,
+    _alpha: bool,
Also applies to: 928-939
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 632-632: Unused function argument: alpha
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gdn/test_decode_delta_rule.py` around lines 623 - 635, The helper
function _test_gdn_decode_bf16_state_kernel declares an unused parameter alpha
which triggers ARG001; rename alpha to _alpha (or remove it) to mark it as
intentionally unused and silence the linter, and apply the same rename to the
other BF16-state helper(s) in this file that also declare an unused alpha
parameter so all occurrences are fixed.
Actionable comments posted: 2
♻️ Duplicate comments (2)
benchmarks/bench_gdn_decode.py (1)
1878-1891: ⚠️ Potential issue | 🟠 Major
Benchmark the MTP kernel with the preallocated `output` tensor.
This timed branch still omits `output=output`, so each iteration benchmarks extra allocation/copy work instead of just the kernel path.
Suggested patch
     else:
         return gdn_decode_bf16_state_mtp(
             A_log=A_log,
             a=a,
             dt_bias=dt_bias,
             softplus_beta=softplus_beta,
             softplus_threshold=softplus_threshold,
             q=q,
             k=k,
             v=v,
             b=b,
             initial_state_source=state,
             use_qk_l2norm_in_kernel=use_qk_l2norm,
             scale=scale,
+            output=output,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1878 - 1891, The benchmark call to gdn_decode_bf16_state_mtp is missing the preallocated output tensor, causing allocations each iteration; update the call site where gdn_decode_bf16_state_mtp(...) is invoked (the function call shown) to pass output=output so the preallocated tensor is used, keeping all other named params (A_log, a, dt_bias, softplus_beta, softplus_threshold, q, k, v, b, initial_state_source=state, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale) unchanged.
flashinfer/gdn_decode.py (1)
509-546: ⚠️ Potential issue | 🟠 Major
Preserve the previous read-only default for `disable_state_update`.
Changing the default to `False` means existing callers that omit this argument now mutate `initial_state`. That is a silent public-API behavior change.
Suggested patch
-    disable_state_update: bool = False,
+    disable_state_update: bool = True,
@@
     disable_state_update (bool):
-        If True, the initial state is not updated. Default: ``False``.
+        If True, the initial state is not updated. Default: ``True``.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 509 - 546, The default for the disable_state_update parameter was changed and must be restored to preserve read-only behavior; in the function gdn_decode (the Gated Delta Rule MTP Kernel that accepts disable_state_update: bool), revert the default value of disable_state_update back to True, update any related docs/parameter description in the function docstring to match (i.e., indicate default True and that omitting the argument prevents mutation of initial_state), and ensure any downstream callers/tests relying on the previous default continue to behave the same.
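The effect of flipping a default on callers that omit the argument can be seen in a toy example. `make_mtp` is a hypothetical factory, and a one-element list stands in for the state tensor; only the parameter name `disable_state_update` comes from the review.

```python
from typing import Callable, List


def make_mtp(default_disable: bool) -> Callable[..., float]:
    """Build a toy decode step whose `disable_state_update` default is fixed."""

    def mtp(state: List[float], delta: float,
            disable_state_update: bool = default_disable) -> float:
        new_value = state[0] + delta
        if not disable_state_update:
            state[0] = new_value  # in-place mutation of the caller's state
        return new_value

    return mtp


old_api = make_mtp(True)    # previous default: state is read-only
new_api = make_mtp(False)   # changed default: state is written back

state_a = [1.0]
state_b = [1.0]
out_a = old_api(state_a, 0.5)   # caller omits the flag
out_b = new_api(state_b, 0.5)   # same call, different side effect
```

Both calls return the same value, so a test that only checks outputs would not notice the change; the difference shows up only in the caller's state afterwards.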
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)
1625-1702: Actually assert the cached BF16 intermediate states.
When `cache_intermediate_states=True`, this helper never checks `intermediate_states_buffer`. A broken beyond-T>4 cache-write path would still pass, which leaves the new feature effectively unverified.
Suggested test extension
     # Reference: step through tokens with bf16 state
     ref_state = input_state_ref_bf16.clone()
     ref_outputs = []
+    ref_intermediate_states = []
     for t in range(seq_len):
         ref_o_t, ref_state = decode_delta_rule(
             q[:, t].float(),
             k[:, t].float(),
@@
             use_l2_norm=True,
             state_dtype=torch.bfloat16,
         )
         ref_outputs.append(ref_o_t)
+        if cache_intermediate_states:
+            ref_intermediate_states.append(
+                ref_state.transpose(-2, -1).contiguous().clone()
+            )
     ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch)
@@
     torch.testing.assert_close(
         our_o.float(),
         ref_o.float(),
         atol=atol_o,
         rtol=rtol_o,
         msg=f"Output mismatch for MTP BF16 state kernel (B={batch_size}, T={seq_len})",
     )
+
+    if cache_intermediate_states and intermediate_states_buffer is not None:
+        ref_intermediate = torch.stack(ref_intermediate_states, dim=1)
+        torch.testing.assert_close(
+            intermediate_states_buffer.float(),
+            ref_intermediate.float(),
+            atol=0.02,
+            rtol=0.01,
+            msg=(
+                f"Intermediate-state cache mismatch for MTP BF16 state kernel "
+                f"(B={batch_size}, T={seq_len})"
+            ),
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 1625 - 1702, The test never asserts intermediate_states_buffer when cache_intermediate_states=True, so add assertions that the buffer returned/modified by gdn_decode_bf16_state_mtp (intermediate_states_buffer) matches the per-step BF16 intermediate states computed during the reference loop using decode_delta_rule (collect the per-token intermediate state values while building ref_outputs), comparing shapes/dtypes and using tight atol/rtol similar to output checks; reference intermediate_states_buffer, gdn_decode_bf16_state_mtp, decode_delta_rule, our_state and input_state_kernel to locate where to capture and compare the cached states and ensure the cache-write path is actually verified.
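The shape of such a check can be sketched without a GPU. This is a toy model under loud assumptions: `step` is a made-up scalar recurrence, not the gated delta rule, and `toy_mtp` is a hypothetical stand-in for a kernel that optionally writes per-token states into a caller-provided buffer.

```python
import math
from typing import List, Optional


def step(state: float, x: float) -> float:
    """Toy scalar recurrence standing in for one state update."""
    return 0.5 * state + x


def toy_mtp(xs: List[float], state: float,
            cache: Optional[List[float]] = None) -> float:
    """Process T tokens; optionally record the state after each token."""
    for t, x in enumerate(xs):
        state = step(state, x)
        if cache is not None:
            cache[t] = state
    return state


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]       # T = 6, beyond T > 4
cache = [float("nan")] * len(xs)          # NaN marks slots never written
final_state = toy_mtp(xs, 0.0, cache)

# Reference: step through the tokens one at a time and collect every state.
ref_states: List[float] = []
ref = 0.0
for x in xs:
    ref = step(ref, x)
    ref_states.append(ref)

cache_matches = all(
    math.isclose(c, r, rel_tol=1e-6, abs_tol=1e-6)
    for c, r in zip(cache, ref_states)
)
```

Pre-filling the buffer with NaN means a slot the kernel skipped fails the comparison rather than passing by accident.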
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 206-212: The BF16 fast-path selection (use_bf16_state) must be
guarded against padded/negative pool indices: extend the predicate that sets
use_bf16_state (which currently uses _GDN_DECODE_BF16_STATE_AVAILABLE,
state_dtype, K and V) to also verify that initial_state_indices either is None
or contains no negative entries (e.g. check initial_state_indices is not
provided OR torch.all(initial_state_indices >= 0) is true) before enabling the
BF16 backend, because the BF16 pooled-state backend does not implement
negative-index semantics and will mis-handle padding slots.
- Around line 235-250: The BF16 MTP fast path currently ignores the caller's
preallocated output and allocates a new tensor then copies back; modify the call
to _gated_delta_rule_bf16_state_mtp so it writes directly into the caller's
output buffer (pass the caller's output as an explicit output argument and
ensure shape/dtype/stride compatibility), remove the downstream copy-from-kernel
back into `output`, and preserve existing flags (use_pool,
initial_state/initial_state_indices, use_qk_l2norm, scale_val) when forwarding
to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place into the
provided `output` buffer.
---
Duplicate comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1878-1891: The benchmark call to gdn_decode_bf16_state_mtp is
missing the preallocated output tensor, causing allocations each iteration;
update the call site where gdn_decode_bf16_state_mtp(...) is invoked (the
function call shown) to pass output=output so the preallocated tensor is used,
keeping all other named params (A_log, a, dt_bias, softplus_beta,
softplus_threshold, q, k, v, b, initial_state_source=state,
use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale) unchanged.
In `@flashinfer/gdn_decode.py`:
- Around line 509-546: The default for the disable_state_update parameter was
changed and must be restored to preserve read-only behavior; in the function
gdn_decode (the Gated Delta Rule MTP Kernel that accepts disable_state_update:
bool), revert the default value of disable_state_update back to True, update any
related docs/parameter description in the function docstring to match (i.e.,
indicate default True and that omitting the argument prevents mutation of
initial_state), and ensure any downstream callers/tests relying on the previous
default continue to behave the same.
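A minimal sketch of why the default matters, using a toy function rather than the real `gdn_decode` API. The name `toy_decode` and its list-based state are hypothetical; only the `disable_state_update` flag name comes from the review comment. With the default `True`, a caller who omits the argument gets a read-only call.

```python
def toy_decode(initial_state, updates, disable_state_update=True):
    """Toy stand-in (hypothetical): applies updates to a copy of the state.

    The caller's `initial_state` is only mutated when
    disable_state_update is explicitly set to False.
    """
    state = list(initial_state)  # work on a private copy
    for u in updates:
        state = [s + u for s in state]
    if not disable_state_update:
        initial_state[:] = state  # write the final state back in place
    return state


caller_state = [1.0, 2.0]
final_state = toy_decode(caller_state, [0.5, 0.5])
# Default call: caller_state is untouched, final_state holds the result.
```

A test that pins this behavior would call the function without the flag and assert that the input state is unchanged afterwards.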
---
Nitpick comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1625-1702: The test never asserts intermediate_states_buffer when
cache_intermediate_states=True, so add assertions that the buffer
returned/modified by gdn_decode_bf16_state_mtp (intermediate_states_buffer)
matches the per-step BF16 intermediate states computed during the reference loop
using decode_delta_rule (collect the per-token intermediate state values while
building ref_outputs), comparing shapes/dtypes and using tight atol/rtol similar
to output checks; reference intermediate_states_buffer,
gdn_decode_bf16_state_mtp, decode_delta_rule, our_state and input_state_kernel
to locate where to capture and compare the cached states and ensure the
cache-write path is actually verified.
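The check requested above can be sketched with toy scalar stand-ins. Everything here is hypothetical (`toy_step`, `toy_mtp`, the scalar recurrence); it is not the `gdn_decode_bf16_state_mtp` or `decode_delta_rule` signature. It only shows the pattern: record the state after each reference step, then compare the cache buffer slot by slot.

```python
import math


def toy_step(state, k, v, decay=0.5):
    """One toy recurrent update (scalar stand-in for a per-token state update)."""
    new_state = decay * state + k * v
    return new_state * k, new_state


def toy_mtp(state, ks, vs, intermediate_states_buffer=None):
    """Toy multi-token call that optionally caches the state after every token."""
    outputs = []
    for t, (k, v) in enumerate(zip(ks, vs)):
        out, state = toy_step(state, k, v)
        outputs.append(out)
        if intermediate_states_buffer is not None:
            intermediate_states_buffer[t] = state
    return outputs, state


def check_intermediate_states(initial_state, ks, vs, rel_tol=1e-6):
    # Reference loop: single-step updates, recording the state after each token.
    ref_outputs, ref_states, state = [], [], initial_state
    for k, v in zip(ks, vs):
        out, state = toy_step(state, k, v)
        ref_outputs.append(out)
        ref_states.append(state)

    # Call under test: multi-token path with a preallocated cache buffer.
    buffer = [None] * len(ks)
    outputs, _ = toy_mtp(initial_state, ks, vs, intermediate_states_buffer=buffer)

    assert len(buffer) == len(ref_states)
    for got, want in zip(buffer, ref_states):
        assert got is not None, "cache slot was never written"
        assert math.isclose(got, want, rel_tol=rel_tol)
    for got, want in zip(outputs, ref_outputs):
        assert math.isclose(got, want, rel_tol=rel_tol)
    return buffer
```

In the real test the comparison would presumably use tensor utilities with BF16-appropriate `atol`/`rtol`, as the comment suggests, rather than `math.isclose`.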
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cb21f102-081e-40a1-854c-620a96b8d9d3
📥 Commits
Reviewing files that changed from the base of the PR and between a906f21 and 2ba157bb1f1ed6e0e90fd950a100d311f4d5c2f7.
📒 Files selected for processing (4)
- benchmarks/bench_gdn_decode.py
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/__init__.py
- tests/gdn/test_decode_delta_rule.py
use_bf16_state = (
    _GDN_DECODE_BF16_STATE_AVAILABLE
    and state_dtype == torch.bfloat16
    and T in (1, 2, 3, 4)
    and K == 128
    and V == 128
)
if use_gdn_decode_klast_bf16_state:
if use_bf16_state:
Don’t route negative pool indices into the BF16 fast path.
This predicate still selects the BF16 backend for pooled bf16 state even when initial_state_indices contains padding slots (-1). That backend does not implement the negative-index semantics, so these calls can read/write the wrong pool row instead of honoring padding.
Suggested guard
  use_bf16_state = (
      _GDN_DECODE_BF16_STATE_AVAILABLE
      and state_dtype == torch.bfloat16
      and K == 128
      and V == 128
  )
+ if use_bf16_state and use_pool and (initial_state_indices < 0).any().item():
+     raise ValueError(
+         "Negative initial_state_indices are only supported with float32 state; "
+         "the BF16 fast path does not support padding slots."
+     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_decode.py` around lines 206 - 212, The BF16 fast-path
selection (use_bf16_state) must be guarded against padded/negative pool indices:
extend the predicate that sets use_bf16_state (which currently uses
_GDN_DECODE_BF16_STATE_AVAILABLE, state_dtype, K and V) to also verify that
initial_state_indices either is None or contains no negative entries (e.g. check
initial_state_indices is not provided OR torch.all(initial_state_indices >= 0)
is true) before enabling the BF16 backend, because the BF16 pooled-state backend
does not implement negative-index semantics and will mis-handle padding slots.
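The alternative described in this prompt, falling back instead of raising, could look roughly like the sketch below. The function name and its plain-Python arguments are hypothetical; the real predicate works on tensors and module-level flags. The point is only that the fast path is enabled when indices are absent or all non-negative.

```python
def can_use_bf16_state(available, state_is_bf16, K, V, initial_state_indices=None):
    """Hypothetical predicate: allow the BF16 fast path only without padding slots."""
    if not (available and state_is_bf16 and K == 128 and V == 128):
        return False
    if initial_state_indices is None:
        return True
    # Padding slots are marked with negative indices (for example -1).
    return all(i >= 0 for i in initial_state_indices)
```

With tensor indices, the equivalent check (such as the `(initial_state_indices < 0).any().item()` form in the suggested guard) calls `.item()`, which forces a host synchronization on a CUDA tensor; that cost may matter in a hot decode loop.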
# MTP kernel supports T>=1 and pool+indices
out = _gated_delta_rule_bf16_state_mtp(
    A_log=A_log,
    a=a,
    dt_bias=dt_bias,
    softplus_beta=1.0,
    softplus_threshold=20.0,
    q=q,
    k=k,
    v=v,
    b=b,
    initial_state_source=initial_state if use_pool else state,
    initial_state_indices=initial_state_indices,
    use_qk_l2norm_in_kernel=use_qk_l2norm,
    scale=scale_val,
)
Pass the caller’s output buffer into the BF16 MTP backend.
The new T>1 / pool fast path always allocates a fresh output tensor and then copies it back into output. That defeats preallocation and adds avoidable device traffic in the hot decode loop.
Suggested patch
  else:
      # MTP kernel supports T>=1 and pool+indices
      out = _gated_delta_rule_bf16_state_mtp(
          A_log=A_log,
          a=a,
          dt_bias=dt_bias,
          softplus_beta=1.0,
          softplus_threshold=20.0,
          q=q,
          k=k,
          v=v,
          b=b,
          initial_state_source=initial_state if use_pool else state,
          initial_state_indices=initial_state_indices,
          use_qk_l2norm_in_kernel=use_qk_l2norm,
          scale=scale_val,
+         output=output,
      )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_decode.py` around lines 235 - 250, The BF16 MTP fast path
currently ignores the caller's preallocated output and allocates a new tensor
then copies back; modify the call to _gated_delta_rule_bf16_state_mtp so it
writes directly into the caller's output buffer (pass the caller's output as an
explicit output argument and ensure shape/dtype/stride compatibility), remove
the downstream copy-from-kernel back into `output`, and preserve existing flags
(use_pool, initial_state/initial_state_indices, use_qk_l2norm, scale_val) when
forwarding to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place
into the provided `output` buffer.
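The out-parameter pattern this comment asks for can be illustrated with a toy wrapper. `scale_rows` and its nested-list buffers are hypothetical stand-ins, not the kernel's signature. The wrapper allocates only when no buffer is given, validates the shape otherwise, and writes results in place so no copy-back is needed.

```python
def scale_rows(rows, scale, output=None):
    """Toy stand-in for a kernel wrapper with an optional preallocated output."""
    if output is None:
        # Allocate only when the caller did not supply a buffer.
        output = [[0.0] * len(r) for r in rows]
    elif len(output) != len(rows) or any(
        len(o) != len(r) for o, r in zip(output, rows)
    ):
        raise ValueError("output buffer shape does not match input shape")
    for o, r in zip(output, rows):
        for j, x in enumerate(r):
            o[j] = x * scale  # write in place: no temporary, no copy-back
    return output


buf = [[0.0, 0.0], [0.0, 0.0]]  # preallocated once, reused across calls
result = scale_rows([[1.0, 2.0], [3.0, 4.0]], 2.0, output=buf)
```

Returning the same object that was passed in lets callers verify with an identity check that no new allocation happened.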
2ba157b to 896ebb3
896ebb3 to 46306ff
Resolve 4 conflicting files after main's major refactor of gdn_decode.py from a 2643-line monolith into a 645-line API layer with kernel code extracted into flashinfer/gdn_kernels/ submodules.
Conflict resolutions:
- gdn_decode.py: Accept main's refactored API layer, port feature branch's BF16 state import renames, updated dispatch logic (T=1 no-pool → bf16_state, else → bf16_state_mtp), and disable_state_update default docstring fix.
- gdn_kernels/__init__.py: Merge both: keep main's expanded exports and add feature branch's MTP exports and backward-compat aliases.
- gdn_kernels/gdn_decode_bf16_state.py: Accept feature branch's complete rewrite with coop-row kernel approach + MTP kernel.
- tests/gdn/test_decode_delta_rule.py: Start from main's version (pool+indices tests, negative indices tests), apply feature branch's renames, dispatch split, new T=1 and MTP test functions, remove CI skip marker.
AI-assisted merge resolution.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
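The dispatch rule stated in this commit message (T=1 without a pool goes to bf16_state, everything else to bf16_state_mtp) can be written as a small selector. The function below is a hypothetical illustration that returns backend names as strings; the actual dispatch in gdn_decode.py calls the kernels directly and also checks dtype and head dimensions.

```python
def select_bf16_backend(T, use_pool):
    """Return the BF16-state backend name for a given call shape (illustrative)."""
    if T < 1:
        raise ValueError("T must be >= 1")
    if T == 1 and not use_pool:
        return "bf16_state"      # single-token kernel, no pool indirection
    return "bf16_state_mtp"      # multi-token kernel, also handles pool+indices
```

Keeping the rule in one pure function makes it easy to cover every (T, use_pool) combination in a unit test without a GPU.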
/bot run
[CANCELING] Pipeline #47460078: canceled
/bot run
[FAILED] Pipeline #47462794: 10/20 passed
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Head branch was pushed to by a user without write access
/bot run
Enable FlashInfer GDN BF16 state MTP kernel on SM100+ by:
- Importing and dispatching to gdn_decode_bf16_state.gated_delta_rule_mtp for MTP verify on SM100+ (PR flashinfer-ai/flashinfer#2679)
- Removing speculative_algorithm guard in server_args.py so FlashInfer GDN is auto-selected for MTP scenarios on SM100+ with bf16 state
- Keeping SM90 (Hopper) FP32 state MTP path unchanged
Tested on 4xGB200 TP4 with Qwen3.5-397B-A17B-FP8:
- MTP: FlashInfer 5-20% lower TPOT vs Triton at conc>=32
- Non-MTP: no regression vs FlashInfer v0.6.7 stock
Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage. Provides both T=1 (single token) and MTP (multi-token prediction) variants using a cooperative row approach.
Key design:
Consolidated from separate cooprow file into canonical gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel. Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1 and MTP (T>1) when state is BF16 and K=V=128.
Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):
Cooprow BF16 State vs Optimized FP32 MTP Benchmark
GPU: B200
Config: Qwen3-Next (q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON)
Mode: cache_intermediate_states=ON, disable_state_update=True
Kernels compared:
- gated_delta_rule_bf16state_cooprow_mtp: cooperative row BF16 state kernel
- gated_delta_rule_mtp: FP32 state kernel with ILP rows (1/2/4/8) + SMEM V caching
1. Cooprow BF16 State Kernel Time (us)
2. Optimized FP32 MTP Kernel Time (us)
3. Speedup (FP32 time / BF16 time, >1.0 = BF16 wins)
Summary
AI-assisted (Claude Code)
Summary by CodeRabbit
New Features
Refactor
Tests
Documentation