feat(gdn): add BF16 state kernel with MTP support beyond T>4 with intermediate caching. #2679

Merged
kahyunnam merged 13 commits into flashinfer-ai:main from ameynaik-hub:ameyn/gdn_bf16_improvements
Apr 2, 2026
@ameynaik-hub ameynaik-hub commented Mar 3, 2026

Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage. Provides both T=1 (single token) and MTP (multi-token prediction) variants using a cooperative row approach.

Key design:

  • Each warp processes one V-row at a time (4 warps = 4 V-rows/iter)
  • cp.async pipeline with TILE_V=8 x TILE_K=128 tiles
  • H state stored as BF16 in memory, FP32 in registers for compute
  • ILP-optimized variant for large batch sizes (BS>=32)
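
The "BF16 in memory, FP32 in registers" point can be illustrated outside the kernel. The following is a minimal pure-Python sketch, not the CuTe DSL code: it emulates BF16 storage as the upper 16 bits of an FP32 value with round-to-nearest-even, upcasts on load, computes in higher precision, and rounds again on store. The helper names are hypothetical, Python floats (FP64) stand in for FP32 registers, and NaN/Inf handling is deliberately left out.

```python
import struct


def fp32_to_bf16_bits(x: float) -> int:
    """Round a value to bfloat16 (nearest-even) and return its 16-bit pattern.

    Assumption: x is finite and within FP32 range; NaN/Inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Rounding bias: 0x7FFF plus the lowest bit of the half we keep.
    rounding = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding) >> 16) & 0xFFFF


def bf16_bits_to_fp32(b: int) -> float:
    """Upcast a bfloat16 bit pattern by zero-filling the low 16 mantissa bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def decay_row(row_bits: list[int], g: float) -> list[int]:
    """Load a BF16 state row, scale it in high precision, store it back as BF16."""
    return [fp32_to_bf16_bits(bf16_bits_to_fp32(b) * g) for b in row_bits]


row = [fp32_to_bf16_bits(v) for v in (1.0, 2.0, -0.5)]
decayed = [bf16_bits_to_fp32(b) for b in decay_row(row, 0.5)]
roundtrip_0p1 = bf16_bits_to_fp32(fp32_to_bf16_bits(0.1))
```

Rounding happens only at the store, so a chain of updates inside one kernel invocation accumulates error from a single BF16 conversion per element rather than one per arithmetic step.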

Consolidated from separate cooprow file into canonical gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel. Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1 and MTP (T>1) when state is BF16 and K=V=128.
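
The dispatch condition described above (BF16 state, K=V=128, then T=1 versus T>1) can be sketched as a small predicate. This is an illustrative sketch only: `DecodeProblem`, `select_gdn_decode_path`, and the returned labels are hypothetical names, and the real logic in gdn_decode.py may check additional conditions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecodeProblem:
    """Hypothetical description of one decode call."""
    seq_len: int       # T
    head_k: int        # K
    head_v: int        # V
    state_dtype: str   # "bfloat16" or "float32"


def select_gdn_decode_path(p: DecodeProblem) -> str:
    """Return a label for the kernel family that would handle this problem."""
    if p.seq_len < 1:
        raise ValueError("seq_len (T) must be >= 1")
    bf16_eligible = (
        p.state_dtype == "bfloat16" and p.head_k == 128 and p.head_v == 128
    )
    if not bf16_eligible:
        return "other"  # whatever path the library used before this PR
    return "bf16_state" if p.seq_len == 1 else "bf16_state_mtp"


single = select_gdn_decode_path(DecodeProblem(1, 128, 128, "bfloat16"))
multi = select_gdn_decode_path(DecodeProblem(6, 128, 128, "bfloat16"))
fp32 = select_gdn_decode_path(DecodeProblem(6, 128, 128, "float32"))
```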

Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):

  • BS=1-2: 1.09-1.35x speedup
  • BS=4-16: 1.24-2.21x speedup (biggest gains)
  • BS=32-512: 1.62-1.81x steady-state speedup
  • Peak: 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

Cooprow BF16 State vs Optimized FP32 MTP Benchmark

GPU: B200
Config: Qwen3-Next (q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON)
Mode: cache_intermediate_states=ON, disable_state_update=True

Kernels compared:

  • Cooprow BF16: gated_delta_rule_bf16state_cooprow_mtp — cooperative row BF16 state kernel
  • Optimized FP32 MTP: gated_delta_rule_mtp — FP32 state kernel with ILP rows (1/2/4/8) + SMEM V caching

1. Cooprow BF16 State Kernel Time (us)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 5.18 | 5.82 | 6.62 | 8.74 | 9.38 | 10.08 | 11.07 |
| 2 | 5.66 | 6.40 | 7.42 | 9.38 | 10.56 | 11.42 | 12.45 |
| 4 | 6.67 | 7.58 | 8.83 | 11.20 | 12.54 | 13.76 | 14.94 |
| 8 | 11.33 | 10.82 | 12.86 | 16.13 | 18.42 | 20.40 | 22.51 |
| 16 | 13.68 | 17.44 | 21.23 | 26.18 | 30.18 | 34.43 | 38.29 |
| 32 | 23.30 | 30.32 | 38.30 | 46.85 | 54.59 | 62.24 | 70.27 |
| 64 | 42.37 | 55.74 | 69.71 | 85.46 | 100.11 | 114.93 | 129.42 |
| 128 | 78.56 | 101.86 | 129.73 | 159.89 | 188.13 | 216.77 | 245.31 |
| 256 | 149.39 | 194.24 | 248.41 | 307.04 | 362.75 | 418.59 | 475.36 |
| 512 | 289.76 | 376.80 | 483.71 | 598.51 | 708.69 | 842.46 | 932.94 |

2. Optimized FP32 MTP Kernel Time (us)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 5.66 | 7.04 | 8.34 | 9.79 | 10.93 | 12.54 | 13.85 |
| 2 | 6.61 | 8.26 | 9.95 | 11.58 | 13.22 | 14.91 | 16.78 |
| 4 | 9.50 | 11.94 | 14.08 | 16.26 | 18.64 | 28.77 | 23.65 |
| 8 | 14.08 | 17.60 | 28.22 | 24.82 | 29.20 | 33.09 | 37.73 |
| 16 | 22.96 | 27.65 | 47.02 | 43.84 | 64.59 | 73.97 | 84.75 |
| 32 | 40.48 | 55.62 | 64.32 | 76.29 | 89.78 | 105.57 | 119.10 |
| 64 | 68.45 | 92.64 | 119.04 | 142.56 | 167.07 | 194.13 | 218.30 |
| 128 | 129.63 | 176.99 | 222.93 | 270.22 | 317.36 | 369.17 | 419.15 |
| 256 | 250.81 | 341.90 | 432.94 | 524.08 | 617.77 | 721.64 | 822.65 |
| 512 | 492.59 | 671.39 | 854.70 | 1039.02 | 1232.91 | 1431.24 | 1691.08 |

3. Speedup (FP32 time / BF16 time, >1.0 = BF16 wins)

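| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 1.09 | 1.21 | 1.26 | 1.12 | 1.17 | 1.24 | 1.25 |
| 2 | 1.17 | 1.29 | 1.34 | 1.24 | 1.25 | 1.31 | 1.35 |
| 4 | 1.42 | 1.57 | 1.59 | 1.45 | 1.49 | 2.09 | 1.58 |
| 8 | 1.24 | 1.63 | 2.19 | 1.54 | 1.59 | 1.62 | 1.68 |
| 16 | 1.68 | 1.59 | 2.21 | 1.67 | 2.14 | 2.15 | 2.21 |
| 32 | 1.74 | 1.83 | 1.68 | 1.63 | 1.64 | 1.70 | 1.69 |
| 64 | 1.62 | 1.66 | 1.71 | 1.67 | 1.67 | 1.69 | 1.69 |
| 128 | 1.65 | 1.74 | 1.72 | 1.69 | 1.69 | 1.70 | 1.71 |
| 256 | 1.68 | 1.76 | 1.74 | 1.71 | 1.70 | 1.72 | 1.73 |
| 512 | 1.70 | 1.78 | 1.77 | 1.74 | 1.74 | 1.70 | 1.81 |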
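
As a sanity check on how the speedup figures are derived, the BS=1 row can be recomputed from the two timing tables. The snippet below only uses numbers already listed in sections 1 and 2; the variable and function names are illustrative.

```python
# Kernel times in microseconds for BS=1, keyed by T (copied from the tables).
bf16_us = {2: 5.18, 3: 5.82, 4: 6.62, 5: 8.74, 6: 9.38, 7: 10.08, 8: 11.07}
fp32_us = {2: 5.66, 3: 7.04, 4: 8.34, 5: 9.79, 6: 10.93, 7: 12.54, 8: 13.85}


def speedup_row(fp32: dict[int, float], bf16: dict[int, float]) -> dict[int, float]:
    """Speedup = FP32 time / BF16 time, rounded to two decimals as in the table."""
    return {t: round(fp32[t] / bf16[t], 2) for t in sorted(bf16)}


bs1_speedup = speedup_row(fp32_us, bf16_us)
print(bs1_speedup)
```

The same ratio applied to the other rows reproduces the rest of the speedup table, for example 28.77 / 13.76 for the 2.09x entry at BS=4, T=7.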

Summary

  • BS=1-2: 1.09-1.35x — cooprow BF16 wins but margins are smaller
  • BS=4-16: 1.24-2.21x — biggest gains; >2x spikes at BS=4-16 likely indicate tile-size transitions in the FP32 kernel
  • BS=32-512: 1.62-1.81x — consistent ~1.70-1.78x steady-state speedup
  • Peak TFLOPS: Cooprow BF16 reaches 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

AI-assisted (Claude Code)

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Automatic BF16 State selection: single-step vs multi-token (MTP) chosen at runtime by sequence length.
    • Exposes additional BF16 State kernel variants for improved multi-token performance.
  • Refactor

    • Unified "BF16 State" naming across CLI, benchmarks, outputs, and help text.
    • Default state-update behavior for gated-delta operations changed.
  • Tests

    • Expanded coverage for single-step and MTP BF16 State paths.
  • Documentation

    • Updated CLI help, examples, benchmark legends, and run descriptions.

@coderabbitai

coderabbitai Bot commented Mar 3, 2026

Walkthrough

This PR splits and renames the BF16-state GDN decode backend into two variants (single-token BF16 state and BF16-state MTP for multi-token), updates kernel exports/availability flags, changes runtime dispatch to choose T==1 vs T>1 kernels, and adjusts benchmarks/tests/CLI to match the new APIs.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Benchmarks / CLI: benchmarks/bench_gdn_decode.py | Renamed klast_bf16 → bf16_state, relaxed seq-len to T>=1, runtime dispatch now calls single-token BF16 for T==1 or MTP for T>1, added MTP wrapper args and updated output/labels. |
| Kernel Exports: flashinfer/gdn_kernels/__init__.py | Removed GatedDeltaRuleKernel export; added gated_delta_rule_mtp, gated_delta_rule_bf16state_cooprow, gated_delta_rule_bf16state_cooprow_mtp; updated imports, ImportError fallbacks, and __all__. |
| Core Decode Implementation: flashinfer/gdn_decode.py | Renamed availability flag to _GDN_DECODE_BF16_STATE_AVAILABLE; introduced _gated_delta_rule_bf16_state and _gated_delta_rule_bf16_state_mtp; dispatch selects T==1 vs T>1/pool paths; changed gated_delta_rule_mtp default disable_state_update to False. |
| Tests: tests/gdn/test_decode_delta_rule.py | Replaced klast_bf16 test references with bf16_state; added explicit T=1 and MTP (T>1) test branches and helpers; updated imports, assertions, messages, and smoke-test flows to new kernel naming. |

Sequence Diagram(s)

sequenceDiagram
    participant Client as Decode API
    participant Dispatch as Kernel Dispatch
    participant BF16 as BF16 State (T=1)
    participant MTP as BF16 State MTP (T>1)

    Client->>Dispatch: decode(T, dtype=bf16, args...)
    activate Dispatch
    alt T == 1
        Dispatch->>BF16: call _gated_delta_rule_bf16_state(...)
        activate BF16
        BF16-->>Dispatch: result
        deactivate BF16
    else T > 1
        Dispatch->>MTP: call _gated_delta_rule_bf16_state_mtp(...)
        activate MTP
        MTP-->>Dispatch: result
        deactivate MTP
    end
    deactivate Dispatch
    Dispatch-->>Client: decoded output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested labels

model: qwen3-next, run-ci

Suggested reviewers

  • bkryu
  • cyx-6
  • nvmbreughe
  • jimmyzho
  • jiahanc
  • yzh119

Poem

🐇
I hopped from KLAST to BF16 at night,
One-step for T==1, MTP for flight.
Kernels renamed, dispatch set true—
A carrot-coded shuffle, through and through. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 72.22% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding a BF16 state kernel with MTP support. It accurately reflects the primary objective of the pull request. |
| Description check | ✅ Passed | The description is mostly complete with technical details and benchmarks, but the required template sections (Description rationale, Related Issues) lack substantive content; only placeholders remain. The checklist shows pre-commit and tests pending. |


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the GDN decode functionality by integrating a new, highly optimized BF16 state kernel. This kernel, built with CuTe DSL, provides substantial performance gains across various batch sizes and sequence lengths, particularly for multi-token prediction. The changes involve refactoring existing kernel implementations, updating benchmarking infrastructure, and expanding test coverage to ensure correctness and efficiency of the new BF16 state processing.

Highlights

  • New BF16 State Kernel: Introduced a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage, supporting both single-token (T=1) and multi-token prediction (MTP) variants.
  • Performance Improvements: Achieved significant speedups: 1.09-1.35x for batch sizes 1-2, 1.24-2.21x for batch sizes 4-16 (biggest gains), and a consistent 1.62-1.81x steady-state speedup for batch sizes 32-512. Peak performance reached 13.8 TFLOPS compared to FP32's 7.9 TFLOPS.
  • Kernel Design: The new kernel employs a cooperative row approach, where each warp processes one V-row, utilizes a cp.async pipeline with TILE_V=8 x TILE_K=128 tiles, stores H state as BF16 in memory (FP32 in registers for compute), and includes an ILP-optimized variant for large batch sizes (BS>=32).
  • Refactored GDN Decode: The previous gdn_decode_klast_bf16_state kernel has been consolidated and replaced by the new gdn_decode_bf16_state.py which now handles both T=1 and MTP (T>1) scenarios when the state is BF16 and K=V=128.
  • Benchmarking and Testing: Updated benchmarks and tests to reflect the new BF16 state kernel, including dedicated tests for T=1 and MTP (T>=2) scenarios, and expanded seq_len support from T=1,2,3,4 to T>=1.
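
The cooperative-row layout from the Kernel Design highlight (one V-row per warp, 4 warps, TILE_V=8) implies a simple iteration schedule, sketched below. The tile sizes and warp count are taken from the PR description; the exact order in which rows are assigned to warps inside a tile is an assumption, and `cooprow_schedule` is a hypothetical name, not a function in the kernel.

```python
from typing import Iterator, Tuple


def cooprow_schedule(
    head_v: int = 128, tile_v: int = 8, num_warps: int = 4
) -> Iterator[Tuple[int, int, int, int]]:
    """Yield (tile index, iteration within tile, warp id, V-row) for one head.

    Assumed mapping: within each TILE_V-row tile, warp w handles row
    iteration * num_warps + w, so a tile of 8 rows takes 2 iterations.
    """
    if head_v % tile_v or tile_v % num_warps:
        raise ValueError("head_v must divide by tile_v, and tile_v by num_warps")
    for tile in range(head_v // tile_v):
        for it in range(tile_v // num_warps):
            for warp in range(num_warps):
                yield tile, it, warp, tile * tile_v + it * num_warps + warp


steps = list(cooprow_schedule())
rows_for_warp0 = [row for _, _, warp, row in steps if warp == 0]
```

With V=128 this gives 16 tiles of 2 iterations each, and every V-row is visited exactly once.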

Changelog
  • benchmarks/bench_gdn_decode.py
    • Renamed references from gdn_decode_klast_bf16_state to bf16_state for clarity and consistency.
    • Updated kernel import and availability check to GDN_DECODE_BF16_STATE_AVAILABLE.
    • Modified the wrapper function gdn_decode_klast_bf16_state_wrapper to gdn_decode_bf16_state_wrapper and added dispatch logic for T=1 and MTP (T>1) kernels.
    • Expanded seq_len support in benchmarks from fixed values (1,2,3,4) to any T>=1.
    • Adjusted benchmark output table headers and summary statistics to reflect the new kernel name.
  • flashinfer/gdn_kernels/__init__.py
    • Updated documentation comments to distinguish between T=1 and MTP (T>=1) BF16 state kernels.
    • Imported gated_delta_rule_mtp and added backward compatibility aliases for gated_delta_rule_bf16state_cooprow and gated_delta_rule_bf16state_cooprow_mtp.
    • Removed the export of GatedDeltaRuleKernel.
  • results_bf16_optimizations/cooprow_bf16_vs_optimized_fp32_mtp.md
    • Added a new markdown file detailing benchmark results comparing the cooperative row BF16 state kernel against an optimized FP32 MTP kernel.
  • tests/gdn/test_decode_delta_rule.py
    • Renamed kernel references and availability flags from gdn_decode_klast_bf16_state to gdn_decode_bf16_state.
    • Imported gdn_decode_bf16_state_mtp for multi-token prediction testing.
    • Updated the core test function _test_gdn_decode_klast_bf16_state_kernel to _test_gdn_decode_bf16_state_kernel, which now dispatches to the appropriate T=1 or MTP kernel based on seq_len.
    • Removed the seq_len constraint of T=1,2,3,4, allowing any T>=1.
    • Introduced new dedicated test functions _test_gdn_decode_bf16_state_t1_kernel and _test_gdn_decode_bf16_state_mtp_kernel for thorough validation of single-token and multi-token BF16 state kernels.
    • Updated the API dispatch test test_pretranspose_api_uses_gdn_decode_klast_bf16_state to test_pretranspose_api_uses_gdn_decode_bf16_state to verify correct kernel selection.
    • Adjusted smoke test output messages to reflect the new kernel naming.
Activity
  • The pull request introduces a new feature with significant performance improvements, as detailed in the benchmark results.
  • New benchmark results have been added in a dedicated markdown file to showcase the performance gains.
  • Extensive unit tests have been added and updated to cover the new BF16 state kernel for both single-token and multi-token prediction scenarios, ensuring correctness and reliability.
@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a high-performance BF16 state kernel for GDN decode, supporting both single-token (T=1) and multi-token prediction (MTP), which demonstrates significant performance improvements. The changes are well-structured, including updates to benchmarks and the addition of comprehensive tests.

My review has identified a bug in one of the new tests that needs to be addressed to ensure correctness. Additionally, I've provided a couple of suggestions to refactor duplicated code in both the benchmark and test files, which will improve the overall maintainability of the codebase.

Comment on lines +1862 to +1891
    if T == 1:
        return gdn_decode_bf16_state(
            A_log=A_log,
            a=a,
            dt_bias=dt_bias,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
            q=q,
            k=k,
            v=v,
            b=b,
            initial_state_source=state,
            use_qk_l2norm_in_kernel=use_qk_l2norm,
            scale=scale,
        )
    else:
        return gdn_decode_bf16_state_mtp(
            A_log=A_log,
            a=a,
            dt_bias=dt_bias,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
            q=q,
            k=k,
            v=v,
            b=b,
            initial_state_source=state,
            use_qk_l2norm_in_kernel=use_qk_l2norm,
            scale=scale,
        )

medium

The calls to gdn_decode_bf16_state and gdn_decode_bf16_state_mtp share the same set of arguments. This code can be refactored to reduce duplication and improve maintainability by selecting the kernel function first and then calling it with a shared set of keyword arguments.

    T = q.shape[1]
    kernel_fn = gdn_decode_bf16_state if T == 1 else gdn_decode_bf16_state_mtp
    return kernel_fn(
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        initial_state_source=state,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        scale=scale,
    )

Comment on lines +929 to +940
def _test_gdn_decode_bf16_state_t1_kernel(
    dtype: str,
    batch_size: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    scale: float,
    alpha: bool,
    beta: bool,
    seed: int | None = None,
):

medium

There's significant code duplication for test data generation across _test_gdn_decode_bf16_state_kernel, _test_gdn_decode_bf16_state_t1_kernel, and _test_gdn_decode_bf16_state_mtp_kernel. To improve maintainability, consider refactoring the tensor creation logic into a shared helper function or a pytest fixture. This would make the tests cleaner and easier to manage.
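
One possible shape for such a shared helper is sketched below in plain Python, so it runs without a GPU. It only describes the inputs (name, shape, dtype); a real test helper would materialize them with torch. The function and class names are hypothetical, and the shapes are illustrative guesses: the only layout facts taken from this thread are that T is read from q.shape[1] and that dt_bias is [HV] float32.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorSpec:
    """Shape and dtype of one test input (a stand-in for an allocated tensor)."""
    shape: tuple[int, ...]
    dtype: str


def make_gdn_decode_specs(
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: str = "bfloat16",
) -> dict[str, TensorSpec]:
    """Single place that defines inputs for the T=1 and MTP test helpers.

    Assumed layouts (not confirmed by the PR text): q/k/v are
    [B, T, heads, head_size]; a/b are [B, T, HV]; state is [B, HV, V, K].
    """
    B, T, D, HV = batch_size, seq_len, head_size, num_v_heads
    return {
        "q": TensorSpec((B, T, num_q_heads, D), dtype),
        "k": TensorSpec((B, T, num_k_heads, D), dtype),
        "v": TensorSpec((B, T, HV, D), dtype),
        "a": TensorSpec((B, T, HV), dtype),
        "b": TensorSpec((B, T, HV), dtype),
        "A_log": TensorSpec((HV,), "float32"),
        "dt_bias": TensorSpec((HV,), "float32"),  # [HV] float32 per the review
        "state": TensorSpec((B, HV, D, D), "bfloat16"),
    }


specs_t1 = make_gdn_decode_specs(2, 1, 16, 16, 32, 128)
specs_mtp = make_gdn_decode_specs(2, 6, 16, 16, 32, 128)
```

Each of the three test functions could then call the same helper and differ only in which kernel it invokes and what it asserts.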

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
benchmarks/bench_gdn_decode.py (3)

2788-2799: ⚠️ Potential issue | 🟠 Major

Restore --compare routing for decode versions in main().

For non-MTP paths, args.compare is currently ignored, so single-layout comparison mode is no longer reachable from the CLI.

Suggested fix
-    if args.version == "mtp":
+    if args.version == "mtp":
         # MTP mode: use comparison or flashinfer-only
         if args.compare:
             run_comparison_benchmark(args, dtype, use_qk_l2norm)
         else:
             run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
     elif args.version == "bf16_state":
         # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
         run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
     else:
-        # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
-        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
+        # Decode mode: honor --compare flag
+        if args.compare:
+            run_comparison_benchmark(args, dtype, use_qk_l2norm)
+        else:
+            run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2788 - 2799, The non-MTP
branches ignore args.compare so the CLI --compare mode is unreachable; update
the bf16_state and else branches in main(): for the "bf16_state" branch, if
args.compare call run_comparison_benchmark(args, dtype, use_qk_l2norm) else call
run_gdn_decode_bf16_state_benchmark(...); for the final else branch, if
args.compare call run_comparison_benchmark(...) else call
run_all_layouts_benchmark(...). Use the same argument list (args, dtype,
use_qk_l2norm) when invoking the functions run_comparison_benchmark,
run_gdn_decode_bf16_state_benchmark, and run_all_layouts_benchmark so --compare
behaves consistently across versions.

1845-1891: ⚠️ Potential issue | 🟠 Major

Forward preallocated tensors in BF16 MTP wrapper to avoid benchmark allocations.

The wrapper accepts output and ignores it in both T=1 and T>1 paths. For T>1, the MTP kernel supports both output and initial_state_indices parameters, but the wrapper doesn't forward them. This causes allocations during benchmark timing, skewing measurements.

Add initial_state_indices parameter to wrapper signature and forward both parameters to gdn_decode_bf16_state_mtp(). Update bench_gdn_decode_bf16_state() to create and pass initial_state_indices when T>1.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1845 - 1891, The wrapper around
gdn_decode_bf16_state currently ignores the preallocated output and doesn't
accept or forward initial_state_indices for the multi-timestep path, causing
benchmark allocations; update the wrapper signature to add an
initial_state_indices: torch.Tensor (or optional) parameter, and when T>1
forward both output and initial_state_indices into gdn_decode_bf16_state_mtp by
passing the kernel's output=output and
initial_state_indices=initial_state_indices args (keep the T==1 call unchanged),
and update bench_gdn_decode_bf16_state() to allocate/create
initial_state_indices for the T>1 case and pass it through to the wrapper so no
new allocations occur inside the kernel call.

2049-2068: ⚠️ Potential issue | 🟠 Major

Use float32 dt_bias for BF16-state benchmark calls.

The gdn_decode_bf16_state kernel specification requires dt_bias as [HV] float32, as confirmed by kernel docstrings and test code comments. These benchmark paths currently create it with dtype (typically BF16), which diverges from the intended interface and may benchmark a different numerical path than intended.

Suggested fix
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")

Also applies to: 2256-2260

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2049 - 2068, The benchmark is
passing dt_bias with the wrong dtype for the BF16-state kernel; update the
dt_bias tensor construction used in the gdn_decode_bf16_state benchmark calls so
dt_bias is created as float32 (torch.float32) with shape [num_sab_heads *
head_size] or [HV] as required, then pass that float32 dt_bias into
gdn_decode_bf16_state_wrapper (and the second BF16-state call around the other
occurrence). Locate the dt_bias variable used in the
gdn_decode_bf16_state_wrapper invocations and change its dtype to torch.float32
while keeping device and shape identical.
tests/gdn/test_decode_delta_rule.py (1)

824-831: ⚠️ Potential issue | 🟠 Major

Fix kernel dispatch for seq_len > 1 in pretranspose API test.

The test parametrizes seq_len with [1, 2, 3, 4] (line 824), but the direct kernel call at line 902 always invokes gdn_decode_bf16_state, which is the single-token (T=1) kernel. For seq_len > 1, the test should dispatch to gdn_decode_bf16_state_mtp and pass the initial_state_indices parameter. Currently this breaks verification of the multi-token path.

Suggested fix
-    # Direct improved kernel
-    out_direct = gdn_decode_bf16_state(
+    # Direct kernel: T=1 uses single-token path, T>1 uses MTP path
+    direct_kernel = (
+        gdn_decode_bf16_state if seq_len == 1 else gdn_decode_bf16_state_mtp
+    )
+    direct_kwargs = dict(
         A_log=A_log,
         a=a,
         dt_bias=dt_bias,
         softplus_beta=1.0,
         softplus_threshold=20.0,
         q=q,
         k=k,
         v=v,
         b=b_tensor,
         initial_state_source=state_direct,
         use_qk_l2norm_in_kernel=True,
         scale=scale,
-    )
+    )
+    if seq_len > 1:
+        direct_kwargs["initial_state_indices"] = torch.arange(
+            batch_size, dtype=torch.int32, device=device
+        )
+    out_direct = direct_kernel(**direct_kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 824 - 831, The test
test_pretranspose_api_uses_gdn_decode_bf16_state incorrectly always calls the
single-token kernel gdn_decode_bf16_state; update the dispatch so when seq_len >
1 it calls gdn_decode_bf16_state_mtp and supplies the initial_state_indices
argument (preserving the existing call for seq_len == 1). Locate the direct
kernel invocation around the test body (the call to gdn_decode_bf16_state) and
add a conditional: if seq_len == 1 keep the existing call, else call
gdn_decode_bf16_state_mtp with the same parameters plus initial_state_indices so
the multi-token verification path is exercised.
🧹 Nitpick comments (2)
tests/gdn/test_decode_delta_rule.py (2)

1148-1179: Validate cached intermediate states in the BF16 MTP test path.

When cache_intermediate_states=True, the test currently doesn’t verify buffer contents. Adding a reference comparison would protect the new intermediate-caching path from silent regressions.

Also applies to: 1218-1225
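
A toy reference for what such a check could compare against is sketched below in pure Python. It uses the gated delta-rule update in its commonly written form, S' = g·S + β·(v − g·S·k)·kᵀ, with a scalar gate per token; the kernel's actual gating (A_log, dt_bias, softplus) and tensor layouts are not reproduced, and all function names are hypothetical.

```python
Matrix = list[list[float]]


def gdn_step(S: Matrix, k: list[float], v: list[float], g: float, beta: float) -> Matrix:
    """One update of a [V, K] state: S' = g*S + beta * (v - g*S@k) k^T."""
    new_S = []
    for row, v_i in zip(S, v):
        decayed = [g * s for s in row]
        pred = sum(d * k_j for d, k_j in zip(decayed, k))
        delta = beta * (v_i - pred)
        new_S.append([d + delta * k_j for d, k_j in zip(decayed, k)])
    return new_S


def gdn_reference(S0, ks, vs, gs, betas, cache_intermediate_states=False):
    """Run T tokens; optionally keep a copy of the state after every token."""
    S = [row[:] for row in S0]
    cache = []
    for k, v, g, beta in zip(ks, vs, gs, betas):
        S = gdn_step(S, k, v, g, beta)
        if cache_intermediate_states:
            cache.append([row[:] for row in S])
    return S, cache


def max_abs_diff(A: Matrix, B: Matrix) -> float:
    """Largest element-wise difference, for a tolerance-based comparison."""
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))


S0 = [[0.0, 0.0], [0.0, 0.0]]
ks = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
vs = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]
gs = [0.9, 0.8, 0.95]
betas = [0.5, 0.7, 0.3]

final, cache = gdn_reference(S0, ks, vs, gs, betas, cache_intermediate_states=True)
# The state cached after token t must equal a fresh run over tokens 0..t.
prefix_ok = all(
    max_abs_diff(cache[t], gdn_reference(S0, ks[: t + 1], vs[: t + 1],
                                         gs[: t + 1], betas[: t + 1])[0]) == 0.0
    for t in range(len(ks))
)
```

In a real test the comparison against the kernel's intermediate_states_buffer would use a BF16-appropriate tolerance rather than exact equality.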

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1148 - 1179, Add an
explicit verification of the cached intermediate states when
cache_intermediate_states=True: after calling gdn_decode_bf16_state_mtp with
intermediate_states_buffer provided, run a reference call that produces the
expected intermediate states (e.g., call gdn_decode_bf16_state_mtp or the
float32 equivalent with caching disabled/producing a reference buffer) and
compare intermediate_states_buffer to that reference using an appropriate
numeric tolerance and dtype/device conversion (use torch.testing.assert_allclose
or equivalent) so the BF16 MTP test path actually validates buffer contents;
apply the same check in the other test location that mirrors this logic (the
block around the second call referenced in the comment).

634-637: Remove or use the unused alpha helper parameter.

alpha is currently unused in the BF16 helper tests, which makes the test surface a bit misleading.

Also applies to: 937-940

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 634 - 637, The helper
function in tests/gdn/test_decode_delta_rule.py declares an unused parameter
alpha alongside beta and seed; either remove alpha from the helper signature and
all its call sites in that file (including the BF16 helper usages) or modify the
BF16 helper tests to actually use the alpha parameter, ensuring you update the
function signature and every invocation consistently; reference the parameter
names alpha, beta, seed when making the change so you catch all occurrences (the
same unused-alpha issue also appears later in the file around the BF16 helper
usages).

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e08e8f3 and d470d1dd9ff67f16a1898e1699fddd22cdf3ae1e.

📒 Files selected for processing (6)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • results_bf16_optimizations/cooprow_bf16_vs_optimized_fp32_mtp.md
  • tests/gdn/test_decode_delta_rule.py

Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden
state storage. Provides both T=1 (single token) and MTP (multi-token
prediction) variants using a cooperative row approach.

Key design:
- Each warp processes one V-row at a time (4 warps = 4 V-rows/iter)
- cp.async pipeline with TILE_V=8 x TILE_K=128 tiles
- H state stored as BF16 in memory, FP32 in registers for compute
- ILP-optimized variant for large batch sizes (BS>=32)

Consolidated from separate cooprow file into canonical
gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel.
Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1
and MTP (T>1) when state is BF16 and K=V=128.

Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):
- BS=1-2:   1.09-1.35x speedup
- BS=4-16:  1.24-2.21x speedup (biggest gains)
- BS=32-512: 1.62-1.81x steady-state speedup
- Peak: 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub force-pushed the ameyn/gdn_bf16_improvements branch from d470d1d to a906f21 Compare March 3, 2026 20:39

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)

1928-1931: ⚠️ Potential issue | 🟠 Major

Use float32 dt_bias in BF16-state benchmark paths.

BF16-state tests in this PR consistently feed dt_bias as float32, but these benchmark paths generate dt_bias with dtype (typically bf16/fp16). That can trigger dtype assertions/casts and distort or fail BF16-state benchmarking.

Suggested patch
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
@@
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")

Also applies to: 2257-2260

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1928 - 1931, The benchmark
creates dt_bias with the generic dtype (bf16/fp16) causing dtype mismatches in
BF16-state paths; change the dt_bias creation to explicitly use torch.float32
(e.g., dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda"))
wherever dt_bias is constructed alongside A_log, a, b (the dt_bias variable in
the block with A_log, a, b) and apply the same fix to the other identical
occurrence later in the file.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Line 2420: The function gated_delta_rule_mtp now sets the parameter
disable_state_update default to False which silently changes behavior; update it
to preserve prior behavior by restoring disable_state_update: bool = True in the
gated_delta_rule_mtp signature (or if the change is intentional, update the
function's docstring and any callers to document and adopt
disable_state_update=False) and ensure the docstring text for
gated_delta_rule_mtp describing the default matches the new default to avoid
mismatch.

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 623-635: The helper function _test_gdn_decode_bf16_state_kernel
declares an unused parameter alpha which triggers ARG001; rename alpha to _alpha
(or remove it) to mark it as intentionally unused and silence the linter, and
apply the same rename to the other BF16-state helper(s) in this file that also
declare an unused alpha parameter so all occurrences are fixed.

---

Outside diff comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1928-1931: The benchmark creates dt_bias with the generic dtype
(bf16/fp16) causing dtype mismatches in BF16-state paths; change the dt_bias
creation to explicitly use torch.float32 (e.g., dt_bias =
torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")) wherever dt_bias
is constructed alongside A_log, a, b (the dt_bias variable in the block with
A_log, a, b) and apply the same fix to the other identical occurrence later in
the file.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d470d1dd9ff67f16a1898e1699fddd22cdf3ae1e and a906f21.

📒 Files selected for processing (5)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py

Comment thread benchmarks/bench_gdn_decode.py
Comment thread flashinfer/gdn_decode.py Outdated
Comment on lines +623 to 635
 def _test_gdn_decode_bf16_state_kernel(
     dtype: str,
     batch_size: int,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    seq_len: int,  # T=1,2,3,4
+    seq_len: int,
     scale: float,
     alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):

⚠️ Potential issue | 🟡 Minor

Resolve unused alpha parameters in BF16-state helper tests.

alpha is not read in these helper bodies, and Ruff is already flagging it (ARG001). Rename to _alpha (or remove) to avoid lint noise/failures.

Suggested patch
 def _test_gdn_decode_bf16_state_kernel(
@@
-    alpha: bool,
+    _alpha: bool,
@@
 def _test_gdn_decode_bf16_state_t1_kernel(
@@
-    alpha: bool,
+    _alpha: bool,

Also applies to: 928-939

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 632-632: Unused function argument: alpha

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 623 - 635, The helper
function _test_gdn_decode_bf16_state_kernel declares an unused parameter alpha
which triggers ARG001; rename alpha to _alpha (or remove it) to mark it as
intentionally unused and silence the linter, and apply the same rename to the
other BF16-state helper(s) in this file that also declare an unused alpha
parameter so all occurrences are fixed.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
benchmarks/bench_gdn_decode.py (1)

1878-1891: ⚠️ Potential issue | 🟠 Major

Benchmark the MTP kernel with the preallocated output tensor.

This timed branch still omits output=output, so each iteration benchmarks extra allocation/copy work instead of just the kernel path.

Suggested patch
     else:
         return gdn_decode_bf16_state_mtp(
             A_log=A_log,
             a=a,
             dt_bias=dt_bias,
             softplus_beta=softplus_beta,
             softplus_threshold=softplus_threshold,
             q=q,
             k=k,
             v=v,
             b=b,
             initial_state_source=state,
             use_qk_l2norm_in_kernel=use_qk_l2norm,
             scale=scale,
+            output=output,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1878 - 1891, The benchmark call
to gdn_decode_bf16_state_mtp is missing the preallocated output tensor, causing
allocations each iteration; update the call site where
gdn_decode_bf16_state_mtp(...) is invoked (the function call shown) to pass
output=output so the preallocated tensor is used, keeping all other named params
(A_log, a, dt_bias, softplus_beta, softplus_threshold, q, k, v, b,
initial_state_source=state, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale)
unchanged.
flashinfer/gdn_decode.py (1)

509-546: ⚠️ Potential issue | 🟠 Major

Preserve the previous read-only default for disable_state_update.

Changing the default to False means existing callers that omit this argument now mutate initial_state. That is a silent public-API behavior change.

Suggested patch
-    disable_state_update: bool = False,
+    disable_state_update: bool = True,
@@
         disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``False``.
+            If True, the initial state is not updated. Default: ``True``.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 509 - 546, The default for the
disable_state_update parameter was changed and must be restored to preserve
read-only behavior; in the function gdn_decode (the Gated Delta Rule MTP Kernel
that accepts disable_state_update: bool), revert the default value of
disable_state_update back to True, update any related docs/parameter description
in the function docstring to match (i.e., indicate default True and that
omitting the argument prevents mutation of initial_state), and ensure any
downstream callers/tests relying on the previous default continue to behave the
same.
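
To make the concern concrete, here is a toy sketch of how flipping a default changes what an unchanged call site does. The functions `decode_step` and `decode_step_new_default` are hypothetical stand-ins, not FlashInfer APIs; only the flag name `disable_state_update` comes from the review.

```python
def decode_step(state, delta, disable_state_update=True):
    """Toy stand-in for a decode call; returns the new value of state[0].

    With disable_state_update=True the caller's state is left untouched.
    """
    new_value = state[0] + delta
    if not disable_state_update:
        state[0] = new_value  # in-place mutation of the caller's buffer
    return new_value


def decode_step_new_default(state, delta, disable_state_update=False):
    """Same toy function with the default flipped to False."""
    return decode_step(state, delta, disable_state_update)


state = [1.0]
decode_step(state, 0.5)              # caller omits the flag
old_default_state = list(state)      # state unchanged

decode_step_new_default(state, 0.5)  # same call shape, flipped default
new_default_state = list(state)      # state now mutated
```

Callers that never passed the flag see no error after the flip, only different state afterwards, which is why the review treats it as a silent API change.
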
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)

1625-1702: Actually assert the cached BF16 intermediate states.

When cache_intermediate_states=True, this helper never checks intermediate_states_buffer. A broken beyond-T>4 cache-write path would still pass, which leaves the new feature effectively unverified.

Suggested test extension
     # Reference: step through tokens with bf16 state
     ref_state = input_state_ref_bf16.clone()
     ref_outputs = []
+    ref_intermediate_states = []

     for t in range(seq_len):
         ref_o_t, ref_state = decode_delta_rule(
             q[:, t].float(),
             k[:, t].float(),
@@
             use_l2_norm=True,
             state_dtype=torch.bfloat16,
         )
         ref_outputs.append(ref_o_t)
+        if cache_intermediate_states:
+            ref_intermediate_states.append(
+                ref_state.transpose(-2, -1).contiguous().clone()
+            )

     ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch)
@@
     torch.testing.assert_close(
         our_o.float(),
         ref_o.float(),
         atol=atol_o,
         rtol=rtol_o,
         msg=f"Output mismatch for MTP BF16 state kernel (B={batch_size}, T={seq_len})",
     )
+
+    if cache_intermediate_states and intermediate_states_buffer is not None:
+        ref_intermediate = torch.stack(ref_intermediate_states, dim=1)
+        torch.testing.assert_close(
+            intermediate_states_buffer.float(),
+            ref_intermediate.float(),
+            atol=0.02,
+            rtol=0.01,
+            msg=(
+                f"Intermediate-state cache mismatch for MTP BF16 state kernel "
+                f"(B={batch_size}, T={seq_len})"
+            ),
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1625 - 1702, The test never
asserts intermediate_states_buffer when cache_intermediate_states=True, so add
assertions that the buffer returned/modified by gdn_decode_bf16_state_mtp
(intermediate_states_buffer) matches the per-step BF16 intermediate states
computed during the reference loop using decode_delta_rule (collect the
per-token intermediate state values while building ref_outputs), comparing
shapes/dtypes and using tight atol/rtol similar to output checks; reference
intermediate_states_buffer, gdn_decode_bf16_state_mtp, decode_delta_rule,
our_state and input_state_kernel to locate where to capture and compare the
cached states and ensure the cache-write path is actually verified.
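
The shape of such a check can be sketched with a toy scalar recurrence. Everything below is an assumption for illustration: `reference_step` is a made-up recurrence (not the gated delta rule), `toy_kernel_with_cache` is a hypothetical stand-in for the kernel, and the tolerances simply mirror the `atol=0.02`, `rtol=0.01` values from the suggested test.

```python
import math


def reference_step(state, k_t, v_t, decay=0.9):
    """Toy recurrence standing in for one reference decode step."""
    return decay * state + k_t * v_t


def toy_kernel_with_cache(state, ks, vs, intermediate_states_buffer, decay=0.9):
    """Hypothetical kernel stand-in: writes the state after each token into the buffer."""
    for t, (k_t, v_t) in enumerate(zip(ks, vs)):
        state = decay * state + k_t * v_t
        intermediate_states_buffer[t] = state
    return state


def check_intermediate_states(buffer, ks, vs, init_state, rel_tol=1e-2, abs_tol=2e-2):
    """Step the reference token by token and compare every cached state."""
    ref_state = init_state
    for t, (k_t, v_t) in enumerate(zip(ks, vs)):
        ref_state = reference_step(ref_state, k_t, v_t)
        if not math.isclose(buffer[t], ref_state, rel_tol=rel_tol, abs_tol=abs_tol):
            return False
    return True


ks = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]   # T = 6, i.e. more than 4 tokens
vs = [1.0, -1.0, 0.5, 2.0, -0.5, 1.5]
buffer = [0.0] * len(ks)
toy_kernel_with_cache(0.25, ks, vs, buffer)
cache_ok = check_intermediate_states(buffer, ks, vs, init_state=0.25)

# A cache-write bug (one slot never written) is caught by the check.
broken = list(buffer)
broken[4] = 0.0
broken_ok = check_intermediate_states(broken, ks, vs, init_state=0.25)
```

The second half shows why comparing only the final output is not enough: a slot that is never written leaves the output path intact but fails the per-step comparison.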
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 206-212: The BF16 fast-path selection (use_bf16_state) must be
guarded against padded/negative pool indices: extend the predicate that sets
use_bf16_state (which currently uses _GDN_DECODE_BF16_STATE_AVAILABLE,
state_dtype, K and V) to also verify that initial_state_indices either is None
or contains no negative entries (e.g. check initial_state_indices is not
provided OR torch.all(initial_state_indices >= 0) is true) before enabling the
BF16 backend, because the BF16 pooled-state backend does not implement
negative-index semantics and will mis-handle padding slots.
- Around line 235-250: The BF16 MTP fast path currently ignores the caller's
preallocated output and allocates a new tensor then copies back; modify the call
to _gated_delta_rule_bf16_state_mtp so it writes directly into the caller's
output buffer (pass the caller's output as an explicit output argument and
ensure shape/dtype/stride compatibility), remove the downstream copy-from-kernel
back into `output`, and preserve existing flags (use_pool,
initial_state/initial_state_indices, use_qk_l2norm, scale_val) when forwarding
to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place into the
provided `output` buffer.

---

Duplicate comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1878-1891: The benchmark call to gdn_decode_bf16_state_mtp is
missing the preallocated output tensor, causing allocations each iteration;
update the call site where gdn_decode_bf16_state_mtp(...) is invoked (the
function call shown) to pass output=output so the preallocated tensor is used,
keeping all other named params (A_log, a, dt_bias, softplus_beta,
softplus_threshold, q, k, v, b, initial_state_source=state,
use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale) unchanged.

In `@flashinfer/gdn_decode.py`:
- Around line 509-546: The default for the disable_state_update parameter was
changed and must be restored to preserve read-only behavior; in the function
gdn_decode (the Gated Delta Rule MTP Kernel that accepts disable_state_update:
bool), revert the default value of disable_state_update back to True, update any
related docs/parameter description in the function docstring to match (i.e.,
indicate default True and that omitting the argument prevents mutation of
initial_state), and ensure any downstream callers/tests relying on the previous
default continue to behave the same.

---

Nitpick comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1625-1702: The test never asserts intermediate_states_buffer when
cache_intermediate_states=True, so add assertions that the buffer
returned/modified by gdn_decode_bf16_state_mtp (intermediate_states_buffer)
matches the per-step BF16 intermediate states computed during the reference loop
using decode_delta_rule (collect the per-token intermediate state values while
building ref_outputs), comparing shapes/dtypes and using tight atol/rtol similar
to output checks; reference intermediate_states_buffer,
gdn_decode_bf16_state_mtp, decode_delta_rule, our_state and input_state_kernel
to locate where to capture and compare the cached states and ensure the
cache-write path is actually verified.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cb21f102-081e-40a1-854c-620a96b8d9d3

📥 Commits

Reviewing files that changed from the base of the PR and between a906f21 and 2ba157bb1f1ed6e0e90fd950a100d311f4d5c2f7.

📒 Files selected for processing (4)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • tests/gdn/test_decode_delta_rule.py

Comment thread flashinfer/gdn_decode.py
Comment on lines +206 to +212
     use_bf16_state = (
         _GDN_DECODE_BF16_STATE_AVAILABLE
         and state_dtype == torch.bfloat16
-        and T in (1, 2, 3, 4)
         and K == 128
         and V == 128
     )
-    if use_gdn_decode_klast_bf16_state:
+    if use_bf16_state:

⚠️ Potential issue | 🟠 Major

Don’t route negative pool indices into the BF16 fast path.

This predicate still selects the BF16 backend for pooled bf16 state even when initial_state_indices contains padding slots (-1). That backend does not implement the negative-index semantics, so these calls can read/write the wrong pool row instead of honoring padding.

Suggested guard
     use_bf16_state = (
         _GDN_DECODE_BF16_STATE_AVAILABLE
         and state_dtype == torch.bfloat16
         and K == 128
         and V == 128
     )
+    if use_bf16_state and use_pool and (initial_state_indices < 0).any().item():
+        raise ValueError(
+            "Negative initial_state_indices are only supported with float32 state; "
+            "the BF16 fast path does not support padding slots."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 206 - 212, The BF16 fast-path
selection (use_bf16_state) must be guarded against padded/negative pool indices:
extend the predicate that sets use_bf16_state (which currently uses
_GDN_DECODE_BF16_STATE_AVAILABLE, state_dtype, K and V) to also verify that
initial_state_indices either is None or contains no negative entries (e.g. check
initial_state_indices is not provided OR torch.all(initial_state_indices >= 0)
is true) before enabling the BF16 backend, because the BF16 pooled-state backend
does not implement negative-index semantics and will mis-handle padding slots.
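
As an analogy only, the hazard and the guard can be sketched in plain Python. Here `gather_state_rows` and `check_no_padding_indices` are hypothetical helpers; Python lists wrap a negative index to the end, whereas what an unguarded GPU kernel does with `-1` depends on its own indexing code.

```python
def gather_state_rows(pool, initial_state_indices):
    """Naive gather: a padding index of -1 wraps around to the last pool row."""
    return [pool[i] for i in initial_state_indices]


def check_no_padding_indices(initial_state_indices):
    """Hypothetical guard: reject negative indices before a path without padding support."""
    if any(i < 0 for i in initial_state_indices):
        raise ValueError("negative initial_state_indices are not supported on this path")


pool = ["row0", "row1", "row2"]
indices = [1, -1]                           # -1 marks a padding slot
wrapped = gather_state_rows(pool, indices)  # silently picks up the last row

try:
    check_no_padding_indices(indices)
    guard_raised = False
except ValueError:
    guard_raised = True
```

Raising early, as the suggested guard does, turns a silent wrong-row access into an explicit error; falling back to a path that does handle padding would be the alternative design.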

Comment thread flashinfer/gdn_decode.py
Comment on lines +235 to +250
            # MTP kernel supports T>=1 and pool+indices
            out = _gated_delta_rule_bf16_state_mtp(
                A_log=A_log,
                a=a,
                dt_bias=dt_bias,
                softplus_beta=1.0,
                softplus_threshold=20.0,
                q=q,
                k=k,
                v=v,
                b=b,
                initial_state_source=initial_state if use_pool else state,
                initial_state_indices=initial_state_indices,
                use_qk_l2norm_in_kernel=use_qk_l2norm,
                scale=scale_val,
            )

⚠️ Potential issue | 🟠 Major

Pass the caller’s output buffer into the BF16 MTP backend.

The new T>1 / pool fast path always allocates a fresh output tensor and then copies it back into output. That defeats preallocation and adds avoidable device traffic in the hot decode loop.

Suggested patch
         else:
             # MTP kernel supports T>=1 and pool+indices
             out = _gated_delta_rule_bf16_state_mtp(
                 A_log=A_log,
                 a=a,
                 dt_bias=dt_bias,
                 softplus_beta=1.0,
                 softplus_threshold=20.0,
                 q=q,
                 k=k,
                 v=v,
                 b=b,
                 initial_state_source=initial_state if use_pool else state,
                 initial_state_indices=initial_state_indices,
                 use_qk_l2norm_in_kernel=use_qk_l2norm,
                 scale=scale_val,
+                output=output,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 235 - 250, The BF16 MTP fast path
currently ignores the caller's preallocated output and allocates a new tensor
then copies back; modify the call to _gated_delta_rule_bf16_state_mtp so it
writes directly into the caller's output buffer (pass the caller's output as an
explicit output argument and ensure shape/dtype/stride compatibility), remove
the downstream copy-from-kernel back into `output`, and preserve existing flags
(use_pool, initial_state/initial_state_indices, use_qk_l2norm, scale_val) when
forwarding to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place
into the provided `output` buffer.
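
The optional-output pattern behind this suggestion can be sketched with a toy function. `scale_rows` is a hypothetical stand-in for a kernel wrapper, not the FlashInfer signature; only the idea of an `output` argument comes from the review.

```python
def scale_rows(values, factor, output=None):
    """Toy stand-in for a kernel wrapper with an optional preallocated output."""
    if output is None:
        output = [0.0] * len(values)   # fresh allocation on every call
    elif len(output) != len(values):
        raise ValueError("output has the wrong length")
    for i, x in enumerate(values):
        output[i] = x * factor         # write in place
    return output


out = [0.0, 0.0, 0.0]                  # allocated once, outside the hot loop
result = scale_rows([1.0, 2.0, 3.0], 2.0, output=out)
reused_buffer = result is out          # no second buffer, no copy-back

fresh = scale_rows([1.0, 2.0, 3.0], 2.0)
allocated_new = fresh is not out
```

When the caller's buffer is forwarded, the returned object is the caller's buffer itself, so the extra allocation and the copy back into `output` both disappear.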

Resolve 4 conflicting files after main's major refactor of gdn_decode.py
from a 2643-line monolith into a 645-line API layer with kernel code
extracted into flashinfer/gdn_kernels/ submodules.

Conflict resolutions:
- gdn_decode.py: Accept main's refactored API layer, port feature branch's
  BF16 state import renames, updated dispatch logic (T=1 no-pool → bf16_state,
  else → bf16_state_mtp), and disable_state_update default docstring fix.
- gdn_kernels/__init__.py: Merge both - keep main's expanded exports and add
  feature branch's MTP exports and backward-compat aliases.
- gdn_kernels/gdn_decode_bf16_state.py: Accept feature branch's complete
  rewrite with coop-row kernel approach + MTP kernel.
- tests/gdn/test_decode_delta_rule.py: Start from main's version (pool+indices
  tests, negative indices tests), apply feature branch's renames, dispatch
  split, new T=1 and MTP test functions, remove CI skip marker.

AI-assisted merge resolution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub

/bot run

@flashinfer-bot

GitLab MR !450 has been updated with latest changes, and the CI pipeline #47460078 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[CANCELING] Pipeline #47460078: canceled

@ameynaik-hub

/bot run

@flashinfer-bot

GitLab MR !450 has been created, and the CI pipeline #47462794 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #47462794: 10/20 passed

@kahyunnam kahyunnam enabled auto-merge (squash) April 1, 2026 23:33
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
auto-merge was automatically disabled April 2, 2026 03:01

Head branch was pushed to by a user without write access

@ameynaik-hub

/bot run

@flashinfer-bot

GitLab MR !450 has been updated with latest changes, and the CI pipeline #47555068 is currently running. I'll report back once the pipeline job completes.

@kahyunnam kahyunnam merged commit 7514bf2 into flashinfer-ai:main Apr 2, 2026
29 of 30 checks passed
YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request Apr 7, 2026
Enable FlashInfer GDN BF16 state MTP kernel on SM100+ by:
- Importing and dispatching to gdn_decode_bf16_state.gated_delta_rule_mtp
  for MTP verify on SM100+ (PR flashinfer-ai/flashinfer#2679)
- Removing speculative_algorithm guard in server_args.py so FlashInfer GDN
  is auto-selected for MTP scenarios on SM100+ with bf16 state
- Keeping SM90 (Hopper) FP32 state MTP path unchanged

Tested on 4xGB200 TP4 with Qwen3.5-397B-A17B-FP8:
- MTP: FlashInfer 5-20% lower TPOT vs Triton at conc>=32
- Non-MTP: no regression vs FlashInfer v0.6.7 stock
YAMY1234 added a commit to YAMY1234/flashinfer that referenced this pull request Apr 7, 2026
YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request Apr 20, 2026
YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request Apr 20, 2026
@coderabbitai coderabbitai Bot mentioned this pull request Jun 12, 2026
5 tasks