
Perf: Optimize GDN decode pretranspose kernel for all batch sizes #2588

Merged
yzh119 merged 1 commit into flashinfer-ai:main from ameynaik-hub:ameyn/hfp32_perf_boost
Feb 26, 2026
Conversation

@ameynaik-hub
Contributor

@ameynaik-hub ameynaik-hub commented Feb 19, 2026

Two optimizations:

  1. Early gate reads: Move gate GMEM reads (r_A_log, r_a, r_dt_bias, r_b) to immediately before the barrier, hiding memory latency during sync (see the sketch after this list).

  2. Always use 8-CTA architecture: Remove batch size conditional that switched to big_batch (1 CTA) for B > 32. The 8-CTA small_batch kernel performs better for ALL batch sizes.
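A rough sketch of the first change, in the Python CuTe-DSL style of flashinfer/gdn_decode.py; apart from the register names (r_A_log, r_a, r_dt_bias, r_b) and cute.arch.barrier(), every identifier below is illustrative:

```python
# Fragment from inside the decode kernel (sketch only, not the actual source).
# The tensors (A_log, a, dt_bias, b) and the indexing are placeholders; the
# point is the ordering: the four gate loads from GMEM are issued immediately
# before the barrier, so they are in flight while the CTA waits at the sync.
r_A_log   = A_log[head_idx]             # hypothetical indexing
r_a       = a[batch_idx, head_idx]      # hypothetical indexing
r_dt_bias = dt_bias[head_idx]           # hypothetical indexing
r_b       = b[batch_idx, head_idx]      # hypothetical indexing
cute.arch.barrier()                     # the loads overlap with this wait
# ... gate values are consumed only after the barrier ...
```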

Benchmark results (1000 iters, warmup=10, q_heads=16, k_heads=16, v_heads=32, head_size=128, bfloat16, qk_l2norm=ON):

| Batch | Before (us) | After (us) | Improvement |
|-------|-------------|------------|-------------|
| 1 | 3.74 | 3.74 | ~same |
| 2 | 4.22 | 4.22 | ~same |
| 4 | 5.38 | 5.38 | ~same |
| 8 | 7.68 | 7.65 | ~same |
| 16 | 12.90 | 12.90 | ~same |
| 32 | 23.04 | 23.04 | ~same |
| 64 | 51.57 | 42.56 | ~17% |
| 128 | 92.13 | 81.31 | ~12% |
| 256 | 170.18 | 158.98 | ~7% |
| 512 | 334.21 | 314.05 | ~6% |

Correctness verified on all batch sizes (1, 2, 4, 8, 16, 32, 64, 128, 256, 512).
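The timing setup above can be reproduced with a small CUDA-event harness along these lines; the argument list of gated_delta_rule_decode_pretranspose is not part of this PR text, so it is left as a placeholder:

```python
import torch

def bench_us(fn, iters=1000, warmup=10):
    """Average per-call latency in microseconds (1000-iteration / 10-warmup setup as above)."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # elapsed_time is in ms

# Hypothetical usage -- decode_args stands in for the real q/k/v/state/gate
# tensors whose layout is defined in flashinfer/gdn_decode.py:
# from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose
# print(bench_us(lambda: gated_delta_rule_decode_pretranspose(*decode_args)))
```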

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • Refactor
    • Optimized GPU kernel memory management sequences in decoding operations with improved synchronization and data loading patterns
    • Consolidated kernel execution logic by unifying batch-size handling, removing conditional branching for more consistent performance
    • Enhanced thread-level resource efficiency through refined writeback operations in the decoding pipeline

@gemini-code-assist
Contributor

Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the GDN decode pretranspose kernel by implementing two key optimizations. It reduces memory latency by strategically reordering gate value reads and simplifies the kernel selection process by standardizing on a single, more efficient 8-CTA architecture, leading to notable speedups for larger batch sizes.

Highlights

  • Memory Latency Hiding: Gate memory reads (r_A_log, r_a, r_dt_bias, r_b) were moved to occur earlier in the gdn_decode_kernel_small_batch_pretranspose function, specifically before the cute.arch.barrier(), to hide memory latency during synchronization.
  • Kernel Architecture Unification: The conditional logic that previously switched between small_batch (8-CTA) and big_batch (1-CTA) kernels based on batch size was removed. The small_batch (8-CTA) kernel is now always used in gated_delta_rule_decode_pretranspose, as benchmarks demonstrated its superior performance across all batch sizes.


Changelog
  • flashinfer/gdn_decode.py
    • Moved the initialization of r_A_log, r_a, r_dt_bias, and r_b from global memory to immediately before the cute.arch.barrier() call within gdn_decode_kernel_small_batch_pretranspose.
    • Removed the if B <= 32 conditional statement in gated_delta_rule_decode_pretranspose that previously selected between small_batch and big_batch kernels, now always assigning run_gdn_decode_kernel_small_batch_pretranspose to run_func (see the sketch below).
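In code terms, the dispatch in gated_delta_rule_decode_pretranspose presumably reduces to a single assignment (a sketch of the shape of the change, using only the function names mentioned above):

```python
# Removed: batch-size-dependent kernel choice
# if B <= 32:
#     run_func = run_gdn_decode_kernel_small_batch_pretranspose
# else:
#     run_func = run_gdn_decode_kernel_big_batch_pretranspose

# Now: the 8-CTA small-batch kernel is used for every batch size
run_func = run_gdn_decode_kernel_small_batch_pretranspose
```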
Activity
  • The pull request author, ameynaik-hub, created this pull request to optimize the GDN decode pretranspose kernel.
  • The author provided detailed benchmark results demonstrating performance improvements for various batch sizes, with up to ~17% improvement for batch size 64.
  • The author confirmed correctness verification across a range of batch sizes (1, 2, 4, 8, 16, 32, 64, 128, 256, 512).

@coderabbitai
Contributor

coderabbitai Bot commented Feb 19, 2026

📝 Walkthrough


The pull request optimizes the GDN decode kernel by consolidating kernel selection logic to always use the small-batch pretranspose variant, adjusting memory access patterns by deferring GMEM reads, and simplifying the final writeback condition based on thread index boundaries.

Changes

| Cohort / File(s) | Summary |
|------------------|---------|
| GDN Decode Kernel Optimization<br>flashinfer/gdn_decode.py | Removed conditional batch-size-based kernel selection; moved A_log, a, dt_bias, and b GMEM reads from initialization to the post-index-calculation phase; simplified the final writeback condition from a computed range check to tidx < V in the small-batch pretranspose kernel. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

v0.6.2

Suggested reviewers

  • cyx-6
  • nvmbreughe
  • yzh119

Poem

🐰 Hops through kernels with delight,
Batches merged, logic made tight,
Memory reads dance at the right time,
Simpler writebacks, code sublime!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|------------|--------|-------------|
| Title check | ✅ Passed | Title clearly summarizes the main optimization: switching the GDN decode pretranspose kernel to always use the 8-CTA architecture for better performance across all batch sizes. |
| Description check | ✅ Passed | Description provides clear technical details of both optimizations, includes comprehensive benchmark results validating improvements, confirms correctness testing, and aligns with the repository template structure. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |




Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces two performance optimizations for the GDN decode pretranspose kernel. First, it moves the gate value reads from global memory to just before a barrier, which helps hide memory latency. Second, it removes the conditional logic that selected different kernels based on batch size, now always using the 8-CTA small_batch kernel, as benchmarks show it performs better for all batch sizes. The changes are logical and supported by the provided performance data. My only suggestion is to remove the now-unused big_batch kernel and its launcher function to clean up the codebase.

Comment thread flashinfer/gdn_decode.py
Comment on lines 1105 to 1109

    # Choose kernel based on batch size
    if B <= 32:
        run_func = run_gdn_decode_kernel_small_batch_pretranspose
    else:
        run_func = run_gdn_decode_kernel_big_batch_pretranspose
Contributor


Severity: medium

With the removal of this conditional logic, the run_gdn_decode_kernel_big_batch_pretranspose function and the gdn_decode_kernel_big_batch_pretranspose kernel it calls are no longer used. To improve maintainability, this dead code should be removed from the file.

@yzh119
Collaborator

yzh119 commented Feb 19, 2026

/bot run

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gdn_decode.py (1)

410-416: ⚠️ Potential issue | 🟡 Minor

Fix the comment to reflect that only a subset of threads write per CTA in the small_batch kernel.

The writeback condition at line 415 is correct: if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: appropriately restricts writes to this CTA's assigned V-range. However, the comment at line 412 is misleading — it states "All threads write (V=128, NUM_THREADS=128)", which is only true for the big_batch kernel (line 685, using if tidx < V:).

In the small_batch kernel with 8 CTAs per state, only V/8 (≈16) threads per block satisfy the range condition. Update the comment to:

# Each CTA writes its assigned V-range (1/NUM_BLOCKS_PER_STATE of V)
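For intuition, the partitioning this guard implements looks roughly as follows; the constants and the TILE_V arithmetic are inferred from the 8-CTA / V=128 description above rather than copied from the source:

```python
# Illustrative numbers: 8 CTAs cooperate on one state, V = 128 output columns.
NUM_BLOCKS_PER_STATE = 8
V = 128
TILE_V = V // NUM_BLOCKS_PER_STATE   # 16 columns per CTA (assumed layout)

# Inside the small_batch kernel, each CTA's [start_v_tiles, end_v_tiles) slice
# bounds its share of V, so only about V / NUM_BLOCKS_PER_STATE threads per
# block pass the guard and perform the write:
#     if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V:
#         o[...] = sOutput[...]   # each CTA writes its assigned V-range
```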
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 410 - 416, The comment incorrectly
claims "All threads write (V=128, NUM_THREADS=128)"; update it to reflect that
only the threads in this CTA write its assigned V-range by replacing the comment
above the writeback (near the block using tidx, start_v_tiles, TILE_V,
end_v_tiles, sOutput and o) with something like: "# Each CTA writes its assigned
V-range (1/NUM_BLOCKS_PER_STATE of V)" so it accurately documents that only a
subset of threads per block perform the write in the small_batch kernel.
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

1106-1107: gdn_decode_kernel_big_batch_pretranspose and run_gdn_decode_kernel_big_batch_pretranspose are now dead code.

The switch to always using run_gdn_decode_kernel_small_batch_pretranspose makes the big_batch pretranspose kernel (line 420) and its runner (line 794) unreachable from gated_delta_rule_decode_pretranspose. Since these functions are never called externally and only reference each other internally, they accumulate maintenance overhead without providing value.

Consider removing both functions or adding a # NOTE: unused — kept for reference comment with a TODO to delete them.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1106 - 1107, The big-batch
pretranspose kernel and its runner are now unreachable; remove
gdn_decode_kernel_big_batch_pretranspose and
run_gdn_decode_kernel_big_batch_pretranspose (they only call each other and are
never referenced from gated_delta_rule_decode_pretranspose) to eliminate dead
code, or alternatively add a clear comment above each function like "# NOTE:
unused — kept for reference; TODO: delete" so reviewers know they are
intentionally retained; update any imports/exports and tests that might
reference these symbols to avoid breakage.

@ameynaik-hub
Contributor Author

@kahyunnam

@flashinfer-bot
Collaborator

GitLab MR !330 has been created, and the CI pipeline #44415275 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44415275: 9/20 passed

@ameynaik-hub
Contributor Author

#2610 will help with failures.

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@kahyunnam
Member

/bot run

@flashinfer-bot
Collaborator

GitLab MR !330 has been created, and the CI pipeline #44736192 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44736192: 10/20 passed

@ameynaik-hub
Contributor Author

@kahyunnam @yzh119 I don't think the failures are related to my code. Looks like a tinygemm issue?
https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/pipelines/44736192

@kahyunnam
Member

> @kahyunnam @yzh119 I don't think the failures are related to my code. Looks like a tinygemm issue? https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/pipelines/44736192

I think this is a known issue; we can probably ignore it. @jimmyzho is working on this but is OOTO.

@yzh119 yzh119 merged commit 1589ebb into flashinfer-ai:main Feb 26, 2026
20 checks passed