Replace QH16 bf16 kernel with a new one that does not use ptr_RP #2999
ce19134 to 46d6983
The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16, non-persistent, decode qseqlen=1) writes its output directly via ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2 TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND max_seqlen_q==1 (the exact legacy kernel condition). Other non-persistent kernels (v3, stage1) use split-reduce and expect ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass. bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: ROCm#2999 (broken: tile mismatch, 85% wrong output)

Co-authored-by: Cursor <cursoragent@cursor.com>
The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16, non-persistent, decode qseqlen=1) writes its output directly via ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2 TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND max_seqlen_q==1 (the exact legacy kernel condition). Other non-persistent kernels (v3, stage1) use split-reduce and expect ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass. bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: #2999 (broken: tile mismatch, 85% wrong output)

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: azaidy <aliasger.zaidy@amd.com>
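The Python-side gate described in the commit message can be sketched roughly as follows. This is only an illustration of the stated condition (nhead==32 and max_seqlen_q==1), not the code from the diff: `DecodePlan`, `uses_legacy_direct_write`, and `plan_decode` are hypothetical names introduced here, and the real aiter call sites are not shown in this conversation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecodePlan:
    """Hypothetical summary of how one decode call is set up."""
    stage1_target: str  # buffer stage 1 writes into: "output" or "logits"
    run_stage2: bool    # whether the split-reduce stage runs afterwards


def uses_legacy_direct_write(nhead: int, max_seqlen_q: int) -> bool:
    # Gate quoted in the commit message: nhead == 32 and max_seqlen_q == 1.
    return nhead == 32 and max_seqlen_q == 1


def plan_decode(nhead: int, max_seqlen_q: int) -> DecodePlan:
    if uses_legacy_direct_write(nhead, max_seqlen_q):
        # Legacy kernel writes the final result directly, so the output
        # buffer is reused for logits and stage 2 is skipped.
        return DecodePlan(stage1_target="output", run_stage2=False)
    # All other kernels write split data and rely on the reduce stage.
    return DecodePlan(stage1_target="logits", run_stage2=True)
```

Keeping the condition in a single predicate makes it easier to check that the Python gate and the C++ gate stay in sync.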
LGTM.
Clean follow-up to merged #2729 / #2983 -- adds the matching bf16 nhead=32 path to the QH16 kernel family using the same ptr_R/logits convention as the QH64 kernel. Diff is +26/-3 across 4 files with a precise dispatch guard (q.dtype==bf16 AND kv.dtype==bf16 AND nhead==32); behavior on every other path is preserved. CI all green.
cc @frida-andersson @xaguilar-amd for a courtesy MLA-area LGTM since this lives next to your merged MLA fixes -- happy to merge once one of you takes a quick look.
LGTM too.
LGTM
This now has to be rebased on top of main in the following order due to conflicts this PR
4f7d467 to da27d6d
It is ready to be merged once approved.
Motivation
#2729 introduced a new QH64 kernel that does not write directly to ptr_RP and instead writes split data into ptr_R/logits.
As #2983 states, other kernels such as MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co do not follow the same convention and write into a null pointer instead.
Technical Details
This change introduces a new kernel for nhead=32 bf16 that uses the same convention as the QH64 kernel. However, I have not been able to find a kernel with the 32x32x16 MFMA layout, so I am using the one with 16x16x32 instead.
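To illustrate the split-data convention, here is a minimal sketch of the standard log-sum-exp reduction that combines per-split partial attention outputs into the final result. It assumes scalar values and plain Python lists for readability; the actual layout of the ptr_R/logits buffers and the aiter stage2 implementation are not shown in this PR, and the function names below are hypothetical.

```python
import math


def attend(scores, values):
    """Softmax-weighted sum over one chunk; returns (output, log-sum-exp)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / denom
    return out, m + math.log(denom)


def reduce_splits(partials):
    """Combine per-split (output, lse) pairs into the final output."""
    lse_max = max(lse for _, lse in partials)
    # Each split is weighted by its share of the global softmax denominator.
    scales = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(scales)
    return sum(s * out for s, (out, _) in zip(scales, partials)) / total


def split_attend(scores, values, num_splits):
    """Stage 1 per chunk, then the stage 2 reduction over all chunks."""
    chunk = math.ceil(len(scores) / num_splits)
    partials = [
        attend(scores[i:i + chunk], values[i:i + chunk])
        for i in range(0, len(scores), chunk)
    ]
    return reduce_splits(partials)
```

Under this scheme the result should match unsplit attention up to floating-point rounding for any number of splits, which is the kind of property a comparison against a torch reference would exercise.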
Test Plan
Run a new test in aiter and make sure it passes against the torch reference.
Run DeepSeek in TP4 and make sure it does not crash.
Test Result
Submission Checklist