[GDN] Change Attention State Layout from [N, HV, K, V] to [N, HV, V, K] #20283
ispobock merged 5 commits into sgl-project:main
Conversation
/tag-and-rerun-ci
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on an optimization within the GDN attention mechanism: reordering the dimensions of the recurrent state. By swapping the K and V dimensions, the kernels achieve more efficient memory access on GPUs, directly translating to notable improvements in both time-to-first-token and overall inference latency. The benefit is particularly pronounced in decode scenarios, where memory access patterns are crucial for performance.
Code Review
This pull request refactors the attention state memory layout from [N, HV, K, V] to [N, HV, V, K]. The change is well motivated by the goal of improving memory access patterns and throughput, especially via coalesced GPU memory access, and the provided benchmarking results show improvements in both TTFT and E2E latency. The modifications across the Triton kernels (chunk_delta_h.py, chunk_o.py, fused_recurrent.py) correctly adapt the tensor shapes, strides, block pointers, and matrix-multiplication operands to the new layout. The docstring updates accurately reflect these internal changes.
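To make the operand swap concrete, here is a minimal PyTorch sketch (not the PR's Triton code; shapes and variable names are illustrative, not the kernel's actual tile sizes or identifiers) showing that storing the state transposed only flips the operands of the outer-product update:

```python
import torch

# Illustrative shapes only: N sequences, HV heads, K/V head dims.
N, HV, K, V = 2, 4, 256, 8

# Old layout: recurrent state stored as [N, HV, K, V].
s_kv = torch.randn(N, HV, K, V)
# New layout: the same state stored transposed, [N, HV, V, K].
s_vk = s_kv.transpose(-2, -1).contiguous()

# One outer-product state update, k: [N, HV, K], v: [N, HV, V].
k = torch.randn(N, HV, K)
v = torch.randn(N, HV, V)

# [K, V] layout: s += k^T v.
s_kv = s_kv + k.unsqueeze(-1) * v.unsqueeze(-2)
# [V, K] layout: the identical update with operands swapped, s += v^T k.
s_vk = s_vk + v.unsqueeze(-1) * k.unsqueeze(-2)

# The two layouts remain exact transposes of each other.
assert torch.allclose(s_kv.transpose(-2, -1), s_vk)
print("layouts agree")
```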
Fixed the KDA CI failure. Both KDA and GDN now use the VK layout.
Kimi Linear model test passed.
/rerun-failed-ci
/tag-and-rerun-ci
Verified that this PR works correctly with mamba.
gsm8k: no accuracy drop with mamba.
Added a gpqa test with mamba extra_buffer enabled; accuracy shows no drop.
Added a test case for the new VK layout.
/rerun-failed-ci
Motivation
To improve memory access patterns and throughput, this PR transposes the recurrent state memory layout in GDN attention from [N, HV, K, V] to [N, HV, V, K].
The KV swap aligns the long edge of the state tile (the K dimension) with the memory-contiguous direction, significantly improving the efficiency of the GPU's coalesced memory access and letting each memory transaction fetch more useful data. This effect is also noticeable in the decode scenario, when BV is limited to 8 (see the stride sketch after the table below).
| | Original [K, V] | After swap [V, K] |
| --- | --- | --- |
| Tile shape (decode) | [256, 8] | [8, 256] |
| Contiguous elements per row | 8 | 256 |
| Number of rows | 256 | 8 |
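As a quick illustration (a minimal PyTorch sketch using the decode tile sizes from the table above; it is not the kernel code), the tensor strides show the contiguity difference directly:

```python
import torch

K, V = 256, 8  # decode tile sizes: BV is limited to 8

tile_kv = torch.empty(K, V)  # original [K, V] tile
tile_vk = torch.empty(V, K)  # swapped [V, K] tile

# The innermost stride is 1 in both cases; what changes is the row length:
print(tile_kv.shape, tile_kv.stride())  # torch.Size([256, 8]), (8, 1)   -> 8 contiguous elements per row
print(tile_vk.shape, tile_vk.stride())  # torch.Size([8, 256]), (256, 1) -> 256 contiguous elements per row
```

With the [V, K] layout, a row read at decode time covers 256 contiguous elements instead of 8, which is what lets the GPU coalesce the accesses.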
Both GDN's and KDA's SSM kernels are adapted to the VK layout. This change covers all the decode/extend/target_verify APIs.
Modifications
Accuracy Tests
gpqa: no accuracy drop.
gsm8k: no accuracy drop.
Benchmarking and Profiling
Server:
Benchmark:
TTFT speedup: (14993 - 13842) / 14993 ≈ 7.7%
E2E speedup: (23203 - 21247) / 23203 ≈ 8.4%
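A quick check of the arithmetic (plain Python; the values are copied from the numbers above):

```python
ttft_before, ttft_after = 14993, 13842
e2e_before, e2e_after = 23203, 21247

print(f"TTFT speedup: {(ttft_before - ttft_after) / ttft_before:.1%}")  # 7.7%
print(f"E2E speedup:  {(e2e_before - e2e_after) / e2e_before:.1%}")     # 8.4%
```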
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci