Skip to content

Replace QH16 bf16 kernel with a new one that does not use ptr_RP#2999

Merged
valarLip merged 2 commits into
mainfrom
mla_nheads32_fault_fix
May 18, 2026
Merged

Replace QH16 bf16 kernel with a new one that does not use ptr_RP#2999
valarLip merged 2 commits into
mainfrom
mla_nheads32_fault_fix

Conversation

@JohnNikolay84

Copy link
Copy Markdown
Contributor

Motivation

#2729 has introduced a new QH64 kernel that is not writing directly to ptr_RP and instead is writing split data into ptr_R/logits.

As this #2983 states other kernels like MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co do not follow the same logic and write into a null pointer instead.

Technical Details

This change is introducing a new kernel for nhead=32 bf16 that is using same convention as QH64 kernel. However I have not been able to find a kernel with mfma layouts 32x32x16, instead I am using the one with 16x16x32.

Test Plan

Run a new test in aiter and make sure it pass torch reference.
Run DeepSeek in TP4 and make sure it is not crashing.

Test Result

image

Submission Checklist

@JohnNikolay84 JohnNikolay84 self-assigned this May 1, 2026
@JohnNikolay84 JohnNikolay84 requested review from a team and fangche123 May 1, 2026 11:37
@github-actions

github-actions Bot commented May 1, 2026

Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2999 --add-label <label>

@JohnNikolay84 JohnNikolay84 requested review from Zzz9990 and valarLip May 1, 2026 11:48
@JohnNikolay84 JohnNikolay84 force-pushed the mla_nheads32_fault_fix branch from ce19134 to 46d6983 Compare May 4, 2026 13:38
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request May 5, 2026
The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16,
non-persistent, decode qseqlen=1) writes its output directly via
ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and
out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2
TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND
  max_seqlen_q==1 (the exact legacy kernel condition). Other
  non-persistent kernels (v3, stage1) use split-reduce and expect
  ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when
  nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass.
bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: ROCm#2999 (broken — tile mismatch, 85% wrong output)
Co-authored-by: Cursor <cursoragent@cursor.com>
valarLip pushed a commit that referenced this pull request May 6, 2026
The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16,
non-persistent, decode qseqlen=1) writes its output directly via
ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and
out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2
TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND
  max_seqlen_q==1 (the exact legacy kernel condition). Other
  non-persistent kernels (v3, stage1) use split-reduce and expect
  ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when
  nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass.
bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: #2999 (broken — tile mismatch, 85% wrong output)

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: azaidy <aliasger.zaidy@amd.com>
Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026
The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16,
non-persistent, decode qseqlen=1) writes its output directly via
ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and
out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2
TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND
  max_seqlen_q==1 (the exact legacy kernel condition). Other
  non-persistent kernels (v3, stage1) use split-reduce and expect
  ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when
  nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass.
bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: #2999 (broken — tile mismatch, 85% wrong output)

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: azaidy <aliasger.zaidy@amd.com>
ChuanLi1101
ChuanLi1101 previously approved these changes May 14, 2026

@ChuanLi1101 ChuanLi1101 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Clean follow-up to merged #2729 / #2983 -- adds the matching bf16 nhead=32 path to the QH16 kernel family using the same ptr_R/logits convention as the QH64 kernel. Diff is +26/-3 across 4 files with a precise dispatch guard (q.dtype==bf16 AND kv.dtype==bf16 AND nhead==32); behavior on every other path is preserved. CI all green.

cc @frida-andersson @xaguilar-amd for a courtesy MLA-area LGTM since this lives next to your merged MLA fixes -- happy to merge once one of you takes a quick look.

@xaguilar-amd

Copy link
Copy Markdown
Contributor

LGTM too.

@frida-andersson

Copy link
Copy Markdown
Contributor

LGTM

@JohnNikolay84

Copy link
Copy Markdown
Contributor Author

This now has to be rebased on top of main in the following order due to conflicts

this PR
Frida's 2983 PR reverted
HEAD

@JohnNikolay84

JohnNikolay84 commented May 15, 2026

Copy link
Copy Markdown
Contributor Author

It is ready to be merged now once approved.

@JohnNikolay84 JohnNikolay84 requested a review from ChuanLi1101 May 18, 2026 07:56
@valarLip valarLip merged commit 1eb1c9b into main May 18, 2026
43 of 45 checks passed
@valarLip valarLip deleted the mla_nheads32_fault_fix branch May 18, 2026 10:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants