Perf: Optimize GDN decode pretranspose kernel for all batch sizes #2588
yzh119 merged 1 commit into flashinfer-ai:main from
Conversation
Two optimizations:

1. Early gate reads: Move gate GMEM reads (r_A_log, r_a, r_dt_bias, r_b) to immediately before the barrier, hiding memory latency during sync.
2. Always use 8-CTA architecture: Remove the batch-size conditional that switched to big_batch (1 CTA) for B > 32. The 8-CTA small_batch kernel performs better for ALL batch sizes.

Benchmark results (1000 iters, warmup=10, q_heads=16, k_heads=16, v_heads=32, head_size=128, bfloat16, qk_l2norm=ON):

| Batch | Before (us) | After (us) | Improvement |
|-------|-------------|------------|-------------|
| 1 | 3.74 | 3.74 | ~same |
| 2 | 4.22 | 4.22 | ~same |
| 4 | 5.38 | 5.38 | ~same |
| 8 | 7.68 | 7.65 | ~same |
| 16 | 12.90 | 12.90 | ~same |
| 32 | 23.04 | 23.04 | ~same |
| 64 | 51.57 | 42.56 | ~17% |
| 128 | 92.13 | 81.31 | ~12% |
| 256 | 170.18 | 158.98 | ~7% |
| 512 | 334.21 | 314.05 | ~6% |

Correctness verified on all batch sizes (1, 2, 4, 8, 16, 32, 64, 128, 256, 512).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the performance of the GDN decode pretranspose kernel by implementing two key optimizations. It reduces memory latency by strategically reordering gate value reads and simplifies the kernel selection process by standardizing on a single, more efficient 8-CTA architecture, leading to notable speedups for larger batch sizes.
Activity
📝 Walkthrough

The pull request optimizes the GDN decode kernel by consolidating kernel selection logic to always use the small-batch pretranspose variant, adjusting memory access patterns by deferring GMEM reads, and simplifying the final writeback condition based on thread index boundaries.
Code Review
This pull request introduces two performance optimizations for the GDN decode pretranspose kernel. First, it moves the gate value reads from global memory to just before a barrier, which helps hide memory latency. Second, it removes the conditional logic that selected different kernels based on batch size, now always using the 8-CTA small_batch kernel, as benchmarks show it performs better for all batch sizes. The changes are logical and supported by the provided performance data. My only suggestion is to remove the now-unused big_batch kernel and its launcher function to clean up the codebase.
```python
# Choose kernel based on batch size
if B <= 32:
    run_func = run_gdn_decode_kernel_small_batch_pretranspose
else:
    run_func = run_gdn_decode_kernel_big_batch_pretranspose
```
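This is the batch-size dispatch that the PR removes. For context, a minimal sketch of how the selection reads after the change; only the function names come from this thread, and the wrapper's signature is assumed:

```python
def gated_delta_rule_decode_pretranspose(*args, **kwargs):
    # Sketch only: the B <= 32 branch is gone, and the 8-CTA small_batch
    # kernel is launched for every batch size, since benchmarks show it is
    # at least as fast as the 1-CTA big_batch variant even for B > 32.
    run_func = run_gdn_decode_kernel_small_batch_pretranspose
    return run_func(*args, **kwargs)
```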
/bot run
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gdn_decode.py (1)
410-416: ⚠️ Potential issue | 🟡 Minor

Fix the comment to reflect that only a subset of threads write per CTA in the small_batch kernel.

The writeback condition at line 415 is correct: `if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V:` appropriately restricts writes to this CTA's assigned V-range. However, the comment at line 412 is misleading; it states "All threads write (V=128, NUM_THREADS=128)", which is only true for the big_batch kernel (line 685, using `if tidx < V:`). In the small_batch kernel with 8 CTAs per state, only V/8 (≈16) threads per block satisfy the range condition. Update the comment to:

# Each CTA writes its assigned V-range (1/NUM_BLOCKS_PER_STATE of V)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 410 - 416, The comment incorrectly claims "All threads write (V=128, NUM_THREADS=128)"; update it to reflect that only the threads in this CTA write its assigned V-range by replacing the comment above the writeback (near the block using tidx, start_v_tiles, TILE_V, end_v_tiles, sOutput and o) with something like: "# Each CTA writes its assigned V-range (1/NUM_BLOCKS_PER_STATE of V)" so it accurately documents that only a subset of threads per block perform the write in the small_batch kernel.
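To check the reviewer's arithmetic, here is a small standalone sketch using only the values quoted above (V = 128, NUM_THREADS = 128, 8 CTAs per state) and one plausible tiling in which each CTA owns a contiguous TILE_V-wide slice; the actual tile layout in gdn_decode.py may differ:

```python
# Hypothetical parameterization for illustration only; the numbers are the ones
# quoted in the review comment, not read from flashinfer/gdn_decode.py.
V = 128
NUM_THREADS = 128
NUM_BLOCKS_PER_STATE = 8
TILE_V = V // NUM_BLOCKS_PER_STATE  # 16 output columns per CTA (assumed)

for cta in range(NUM_BLOCKS_PER_STATE):
    start_v_tiles, end_v_tiles = cta, cta + 1  # this CTA's assigned tile range (assumed)
    writers = [
        tidx
        for tidx in range(NUM_THREADS)
        if start_v_tiles * TILE_V <= tidx < end_v_tiles * TILE_V
    ]
    # Only V / NUM_BLOCKS_PER_STATE = 16 of the 128 threads pass the writeback
    # condition, which is why the "All threads write" comment is misleading.
    assert len(writers) == V // NUM_BLOCKS_PER_STATE
```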
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)
1106-1107: `gdn_decode_kernel_big_batch_pretranspose` and `run_gdn_decode_kernel_big_batch_pretranspose` are now dead code.

The switch to always using `run_gdn_decode_kernel_small_batch_pretranspose` makes the big_batch pretranspose kernel (line 420) and its runner (line 794) unreachable from `gated_delta_rule_decode_pretranspose`. Since these functions are never called externally and only reference each other internally, they accumulate maintenance overhead without providing value.

Consider removing both functions or adding a `# NOTE: unused — kept for reference` comment with a TODO to delete them.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 1106 - 1107, The big-batch pretranspose kernel and its runner are now unreachable; remove gdn_decode_kernel_big_batch_pretranspose and run_gdn_decode_kernel_big_batch_pretranspose (they only call each other and are never referenced from gated_delta_rule_decode_pretranspose) to eliminate dead code, or alternatively add a clear comment above each function like "# NOTE: unused — kept for reference; TODO: delete" so reviewers know they are intentionally retained; update any imports/exports and tests that might reference these symbols to avoid breakage.
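If the big_batch functions are kept rather than deleted, one lightweight way to make their status explicit is a deprecation warning in the unused runner. This is a sketch of the reviewer's alternative, not code from the PR; the runner's real signature and body are assumed:

```python
import warnings

def run_gdn_decode_kernel_big_batch_pretranspose(*args, **kwargs):
    # NOTE: unused; kept for reference. TODO: delete.
    # The dispatch now always selects the small_batch pretranspose kernel.
    warnings.warn(
        "run_gdn_decode_kernel_big_batch_pretranspose is no longer selected by "
        "gated_delta_rule_decode_pretranspose; prefer the small_batch variant.",
        DeprecationWarning,
        stacklevel=2,
    )
    ...  # original big_batch launch logic would remain here
```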
[FAILED] Pipeline #44415275: 9/20 passed

#2610 will help with failures.

/bot run

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

[FAILED] Pipeline #44736192: 10/20 passed
@kahyunnam @yzh119 I don't think the failures are related to my code. Looks like a tinygemm issue?
I think this is a known issue; we can probably ignore it. @jimmyzho is working on this but is OOTO.
Two optimizations:
Early gate reads: Move gate GMEM reads (r_A_log, r_a, r_dt_bias, r_b) to immediately before the barrier, hiding memory latency during sync.
Always use 8-CTA architecture: Remove batch size conditional that switched to big_batch (1 CTA) for B > 32. The 8-CTA small_batch kernel performs better for ALL batch sizes.
Benchmark results (1000 iters, warmup=10, q_heads=16, k_heads=16, v_heads=32, head_size=128, bfloat16, qk_l2norm=ON):
| Batch | Before (us) | After (us) | Improvement |
|-------|-------------|------------|-------------|
| 1 | 3.74 | 3.74 | ~same |
| 2 | 4.22 | 4.22 | ~same |
| 4 | 5.38 | 5.38 | ~same |
| 8 | 7.68 | 7.65 | ~same |
| 16 | 12.90 | 12.90 | ~same |
| 32 | 23.04 | 23.04 | ~same |
| 64 | 51.57 | 42.56 | ~17% |
| 128 | 92.13 | 81.31 | ~12% |
| 256 | 170.18 | 158.98 | ~7% |
| 512 | 334.21 | 314.05 | ~6% |
Correctness verified on all batch sizes (1, 2, 4, 8, 16, 32, 64, 128, 256, 512).
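For reference, a minimal timing-loop sketch matching the iteration counts above (1000 iterations, warmup of 10, average microseconds per call); the decode call is passed in as a placeholder rather than reproducing the exact flashinfer invocation and tensor shapes:

```python
import torch

def time_decode(decode_fn, *args, iters=1000, warmup=10):
    """Return the average time in microseconds per call of decode_fn(*args)."""
    for _ in range(warmup):
        decode_fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        decode_fn(*args)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; convert to microseconds per iteration.
    return start.elapsed_time(end) * 1000.0 / iters
```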
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes