Skip to content

fix: ck_moe_stage1 split-K buffer overflow from padding scatter (alternative to #2508)#2547

Closed
ChuanLi1101 wants to merge 1 commit into
mainfrom
chuan/fix-ck-moe-stage1-splitk-scatter
Closed

fix: ck_moe_stage1 split-K buffer overflow from padding scatter (alternative to #2508)#2547
ChuanLi1101 wants to merge 1 commit into
mainfrom
chuan/fix-ck-moe-stage1-splitk-scatter

Conversation

@ChuanLi1101

@ChuanLi1101 ChuanLi1101 commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

Summary

Fix out-of-bounds buffer overflow in \ck_moe_stage1\ when splitK is enabled.

Root cause

The CK MoE kernel uses \sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])\ as its M dimension. The kernel launches tile-based blocks covering the entire M range and scatters results to the output buffer. The output buffer must be large enough to accommodate \sorted_size\ rows.

The original code allocated only (token_num, topk, w1.shape[1])\ = \ oken_num * topk\ rows (a 3D tensor). For the padding entries in \sorted_token_ids, the sentinel value (topk << 24 | token_num)\ decodes to scatter position \ oken_num * topk + topk, which exceeds the allocated buffer. Additionally, the kernel expects the output buffer to span at least \sorted_size\ rows to match its tile-based computation grid.

Fix

  • Compute \sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])\ (matching the C++ wrapper logic)
  • Allocate a 2D fp32 buffer of shape (sorted_size, w1.shape[1])\ instead of the undersized 3D (token_num, topk, w1.shape[1])\
  • After the kernel, slice only the valid rows \ mp_out[:token_num*topk, :]\ before passing to \silu_and_mul\ / \gelu_and_mul\

Verification

Tested on MI355X (gfx950) with multiple token/topk/expert configurations:
\\

SplitK Scatter Fix Verification

[tok=1 topk=8 E=256] OK shape=torch.Size([1, 8, 256]) nan=False inf=False
[tok=2 topk=8 E=256] OK shape=torch.Size([2, 8, 256]) nan=False inf=False
[tok=4 topk=8 E=256] OK shape=torch.Size([4, 8, 256]) nan=False inf=False
[tok=16 topk=8 E=256] OK shape=torch.Size([16, 8, 256]) nan=False inf=False
[tok=1 topk=4 E=64] OK shape=torch.Size([1, 4, 256]) nan=False inf=False
[tok=3 topk=6 E=128] OK shape=torch.Size([3, 6, 256]) nan=False inf=False

Results: 6 passed, 0 failed out of 6
ALL TESTS PASSED!
\\

Comparison to PR #2508

PR #2508 uses \sorted_token_ids.shape[0]\ rows (safe but over-allocates). This PR uses \sorted_size\ (the exact M dimension the C++ wrapper computes), which is the minimal correct size. Both are valid; this PR is tighter on memory.

Test plan

  • Verified on MI355X with 6 different token/topk/expert configs
  • All tests pass with no NaN/Inf in output
  • Non-splitK path is unchanged (tmp_out = out)

@ChuanLi1101 ChuanLi1101 requested a review from a team March 31, 2026 06:20
@github-actions

Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-355 Run Triton tests on MI355 in addition to MI325
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2547 --add-label <label>

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor
@ChuanLi1101 ChuanLi1101 force-pushed the chuan/fix-ck-moe-stage1-splitk-scatter branch from 1ff7e70 to ab58051 Compare March 31, 2026 07:46
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 1, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding
so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul.

Upstream: ROCm#2547
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 1, 2026
Allow callers to supply a pre-allocated (M, model_dim) buffer for
moe_sorting instead of torch.empty each forward, for DSv32/vLLM integration.

Keeps ck_moe_stage1 split-K fix from ROCm#2547.

docs: update dsv32-opt-branch provenance (moe_buf + ROCm#2547).
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 30, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding
so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul.

Upstream: ROCm#2547
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 30, 2026
Allow callers to supply a pre-allocated (M, model_dim) buffer for
moe_sorting instead of torch.empty each forward, for DSv32/vLLM integration.

Keeps ck_moe_stage1 split-K fix from ROCm#2547.

docs: update dsv32-opt-branch provenance (moe_buf + ROCm#2547).
Made-with: Cursor
sunway513 added a commit that referenced this pull request May 1, 2026
@ChuanLi1101

Copy link
Copy Markdown
Contributor Author

Superseded by #2551 (merged to main on 2026-03-31 by @rbrugaro, commit e47cc0e). #2551 implements the same fix with two improvements over this PR: (1) uses orch.empty instead of orch.zeros to avoid double-zeroing (CK kernel zeros the buffer via hipMemsetAsync when KBatch > 1), and (2) keeps the .view(dtypes.fp32) call on the sliced �alid_out. Closing as duplicate.
A follow-up PR will address the same pattern in cktile_moe_stage1, which currently has a WARNING comment on main flagging the same undersized-buffer bug.

@ChuanLi1101 ChuanLi1101 closed this May 1, 2026
sunway513 added a commit that referenced this pull request May 4, 2026
…e.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions
azaidy added a commit that referenced this pull request May 4, 2026
aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sunway513 added a commit that referenced this pull request May 5, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR #2862 (CK bump for stride fix in CK-TILE blockscale)
- PR #2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR #2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256).
Key changes:
- SplitK usage dropped from 613 to 88 CK shapes (splitK > 0)
- All shapes validated via --run_config (1482/1482 OK)
- E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
is the same SMEM descriptor path as the unmasked variant, just with a
VCC mask setup.

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

Lingpeng explicitly rejected this change (from excluded PR #2457).
Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

@xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this
bulk merge — they don't need it for the uplift.

Verified that #2464 is the only PR in this bulk merge touching
aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the
branch and origin/main on those files is exactly #2464's +9/-1 and
+5/-0, with no other PR content mixed in. Restoring both files to
origin/main therefore drops #2464 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sunway513 added a commit that referenced this pull request May 5, 2026
…st1 (#3005)

Squash-merged from main commit 2c855fb.

Includes 8 atomic Silo PRs (4 bug fixes + 3 features + 1):
Bug fixes:
- #2457 MoE dispatch fix for Quark W4A6 (MXFP4 weights with QuantType.No)
- #2464 CK MoE tuner cascading bugs
- #2547 ck_moe_stage1 split-K buffer overflow (memory safety)
- #2866 pa_mqa_logits OOB stores fix (memory safety)
Features:
- #2423 Triton optimized decode for Qwen3-Next (GDN, conv1d, fused FP8 quant)
- #2541 SplitK support for CK/CKTile Block-Scale GEMMs
- #2687 Allow preallocated MoE sorting buffer

Conflict resolutions (3 files):
- .github/workflows/aiter-test.yaml (3 blocks): took HEAD to preserve
  release/v0.1.13's prebuilt-image-extract CI flow. Theirs would replace
  it with main's inline  flow which has not
  been validated for this release branch.
- .github/workflows/vllm_benchmark.yaml (1 block): took HEAD for the same
  CI architecture preservation reason.
- aiter/ops/triton/gated_delta_net/__init__.py (1 block): took HEAD.
  Theirs would expose  and
  , but those functions exist only on main
  (added in a separate commit not part of this PR), so taking theirs
  would break import on this release branch with NameError. The new
  fused_rearrange_sigmoid_gdr.py wrapper that #3005 introduces is
  importable directly from its module path; not exporting it via
  __init__.py simply means library consumers must use the longer import
  path. Acceptable trade-off vs broken imports.

28 files changed, +5433/-47 (3 of original 31 dropped to HEAD per above).

Driver: vLLM 0.21 freeze 2026-05-08 — Silo customers need these kernel
fixes (especially #2547 / #2866 memory safety) on the AITER release
wheel, not nightly.

Verification gates added before tag:
- ATOM 5-model accuracy unchanged within +/- 0.005 vs v0.1.13-rc1
- New Qwen3-Next decode codepath smoke (GDN + causal_conv1d_single_token
  + fused_fp8_quant must JIT-compile and produce coherent output)
- Memory safety regression check on Kimi-K2.5-MXFP4 (exercises ck_moe
  stage1) and DeepSeek-V3.2 (exercises pa_mqa_logits)
- Perf delta sample on Kimi/MiniMax/DSv3.2 c=1 + c=64 vs rc1 baseline

(cherry picked from commit 2c855fb)
Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR #2862 (CK bump for stride fix in CK-TILE blockscale)
- PR #2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR #2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256).
Key changes:
- SplitK usage dropped from 613 to 88 CK shapes (splitK > 0)
- All shapes validated via --run_config (1482/1482 OK)
- E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
is the same SMEM descriptor path as the unmasked variant, just with a
VCC mask setup.

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

Lingpeng explicitly rejected this change (from excluded PR #2457).
Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

@xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this
bulk merge — they don't need it for the uplift.

Verified that #2464 is the only PR in this bulk merge touching
aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the
branch and origin/main on those files is exactly #2464's +9/-1 and
+5/-0, with no other PR content mixed in. Restoring both files to
origin/main therefore drops #2464 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant