Skip to content

[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten#31380

Merged
tjtanaa merged 18 commits intovllm-project:mainfrom
EmbeddedLLM:fix-rocm-qwen3-next-inference
Jan 9, 2026
Merged

[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten#31380
tjtanaa merged 18 commits intovllm-project:mainfrom
EmbeddedLLM:fix-rocm-qwen3-next-inference

Conversation

@vllmellm
Copy link
Copy Markdown
Contributor

@vllmellm vllmellm commented Dec 26, 2025

Purpose

Fixes #26473

This PR refactors the rocm_attn backend kernels to support models with non-power-of-2 block sizes, specifically the Qwen/Qwen3-Next-80B-A3B-Thinking model.

The core of this update is a Dynamic Dispatching Mechanism:Standard Path ($2^n$): For models with power-of-2 block sizes (16, 32, 64, 128, etc.), the kernel retains the legacy bitwise-optimization logic to ensure maximum performance and zero regression.Generalized Path ($non-2^n$): For non-standard models (e.g., Qwen3-544), the kernel switches to a generalized arithmetic addressing logic, bypassing the Triton "Not a power of 2" constraint.

Technical Implementation

  1. Dual-Path Execution Model
    We introduced an is_pow2 check at the Python launch level. This allows the compiler to choose the optimal execution path:
  • Bitwise Branch: Uses block_id << shift for zero-latency addressing.
  • Arithmetic Branch: Uses absolute linear addressing (abs_token_idx // PHYSICAL_BLOCK_SIZE) to handle arbitrary stride and alignment requirements.
  1. Universal Physical Addressing
    We refactored the following kernels to handle 5D (K-cache) and 4D (V-cache) addressing without relying on bit-shifting:
  • vllm/attention/ops/triton_reshape_and_cache_flash.py: Updated cache population for arbitrary strides.
  • vllm/attention/ops/chunked_prefill_paged_decode.py: Refactored mixed prefill-decode kernels to handle non-linear offsets.
  • vllm/attention/ops/prefix_prefill.py: Optimized context-heavy thinking processes with unified addressing.

Test Plan

The fix was validated on a dual AMD MI300X system using vLLM V1 engine.

Backend: VLLM_ATTENTION_BACKEND="ROCM_ATTN".

Unit Tests: All 164 cases in tests/kernels/attention/test_prefix_prefill.py PASSED and test_qwen3_nonstandard_block_size function PASSED with block_size=544.

End-to-End: Verified Qwen3-Thinking with GSM8K using reasoning_parser=deepseek_r1.

Regression Testing: Conducted accuracy and performance benchmarks on standard models (Qwen3-235B-FP8 and Llama-4-Scout-17B) to ensure don't regression for power-of-2 block sizes (16, 32).

Test Result

benchmark

  1. Qwen/Qwen3-Next-80B-A3B-Thinking - The Primary Fix
  2. Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 - Test regression situation
  3. Qwen/Qwen2.5-7B-Instruct - Test regression situation
  4. meta-llama/Llama-3.1-8B - Test regression situation
  5. RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic - Test regression situation
Screenshot 2026-01-08 at 17 44 25

Currently, ROCM can successfully perform attention comparisons on two backends of the qwen3-next architecture.

TRITON_ATTN VS ROCM_ATTN

Screenshot 2026-01-28 at 11 28 23 Screenshot 2026-01-28 at 11 28 36

lm_eval

  1. Qwen/Qwen3-Next-80B-A3B-Thinking - The Primary Fix
  2. Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 - Test regression situation
  3. RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic - Test regression situation
Screenshot 2026-01-08 at 17 44 05

Kernel Validation

test_prefix_prefill.py:164 PASSED, 224 SKIPPED.

tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[chunked_prefill_paged_decode-cuda:1-fp8_dtype0-128-64-64] SKIPPED
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[chunked_prefill_paged_decode-cuda:1-fp8_esm2_dtype0-128-64-64] SKIPPED
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[chunked_prefill_paged_decode-cuda:1-fp8_esm2_dtype0-24-64-64] SKIPPED
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[chunked_prefill_paged_decode-cuda:1-fp8_dtype0-24-64-64] SKIPPED
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[context_attention_fwd-cuda:0-auto-dtype0-24-24-64] SKIPPED (need ...)
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[context_attention_fwd-cuda:1-auto-dtype0-24-24-64] SKIPPED (need ...)
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[context_attention_fwd-cuda:0-fp8_dtype0-24-64-64] SKIPPED (need ...)
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[context_attention_fwd-cuda:1-fp8_dtype0-24-64-64] SKIPPED (need ...)
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[context_attention_fwd-cuda:0-fp8_esm2-dtype0-24-64-64] SKIPPED (need ...)
tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_alibi_f32[context_attention_fwd-cuda:1-fp8_esm2-dtype0-24-64-64] SKIPPED (need ...)
...
...
...

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

================== 164 passed, 224 skipped, 2 warnings in 209.55s (0:03:29)==================

root@tw043:/app/vllm_source#

Note

Enables ROCm attention to work with non-power-of-2 KV block sizes by dynamically dispatching to Triton kernels and introducing generalized physical addressing.

  • Generalizes Triton kernels (chunked_prefill_paged_decode, prefix_prefill) to accept PHYSICAL_BLOCK_SIZE and compute 5D (K) / 4D (V) offsets from logical token indices; adds optional processing for block-table pointers, fixes sm_scale dimension, and improves numerical stability (masking with -inf, alpha/L init, small epsilon in division)
  • Extends triton_reshape_and_cache_flash to support head-major cache layout (5D K / 4D V) with new stride params and uses it in ROCm backend when block size is non-power-of-2; otherwise retain existing HIP path
  • In ROCm backend, detect block size and route reshape/cache + attention accordingly; for non-pow2, force Triton path, keep pow2 behavior unchanged
  • Tests: allow passing block_size into attention tests and add test_qwen3_nonstandard_block_size (544) gated on ROCm

Written by Cursor Bugbot for commit de977a7. This will update automatically on new commits. Configure here.

…ng model under rocm_atten.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot added qwen Related to Qwen models rocm Related to AMD ROCm v1 labels Dec 26, 2025
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the ROCm attention kernels to support non-power-of-two block sizes, which is required for models like Qwen3-Thinking. The changes involve moving from bitwise addressing to a more general linear addressing scheme in Triton kernels.

My review identified a few critical issues. The Triton kernels in chunked_prefill_paged_decode.py and prefix_prefill.py incorrectly hardcode a value of 8 for memory layout calculations, which will cause memory corruption for data types where this assumption does not hold (e.g., float32 or fp8). I also found a code quality issue in triton_reshape_and_cache_flash.py with a redundant line of code.

These critical issues need to be addressed to ensure correctness and prevent crashes.

@vllmellm vllmellm changed the title [Bugfix][ROCm]Fix Qwen3-Thinking inference and optimize non-standard block size (544) support under rocm_atten [Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten Dec 26, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm
Copy link
Copy Markdown
Contributor Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for non-power-of-2 block sizes in ROCm attention kernels, which is crucial for models like Qwen3-Next-80B-A3B-Thinking. The changes primarily involve refactoring addressing logic from bitwise operations to generalized arithmetic, which is a solid approach for universality. The PR also includes several important numerical stability fixes in the attention kernels.

My review identified a critical issue in the pointer-to-index conversion logic in prefix_prefill.py where an incorrect mask could lead to erroneous address calculations. Additionally, there's an opportunity to unify and improve the pointer conversion logic across chunked_prefill_paged_decode.py and prefix_prefill.py for better robustness and maintainability. The other changes, including the core logic for non-standard block sizes and various bug fixes, appear correct and well-implemented.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm
Copy link
Copy Markdown
Contributor Author

vllmellm commented Dec 26, 2025

@codex review

@vllmellm
Copy link
Copy Markdown
Contributor Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors ROCm attention kernels to support non-power-of-2 block sizes, which is required for models like Qwen3-Thinking. The core change is a shift from bitwise addressing to a more general arithmetic-based addressing for KV cache access. The PR also includes a critical bug fix for the sm_scale calculation and several important numerical stability improvements that prevent NaN propagation in edge cases. My review focuses on these correctness and stability fixes, and also points out a potential fragility in using magic numbers to detect physical pointers in block tables. Overall, the changes are well-implemented and significantly improve the robustness and capability of the ROCm attention backend.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm marked this pull request as ready for review December 27, 2025 08:09
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@vllmellm
Copy link
Copy Markdown
Contributor Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant refactoring to the ROCm attention kernels to support non-power-of-2 block sizes, which is necessary for models like Qwen3-Next-80B. The changes primarily involve moving from bitwise addressing to a more general absolute linear addressing for the KV cache, which is a solid approach. The modifications in chunked_prefill_paged_decode.py, prefix_prefill.py, and triton_reshape_and_cache_flash.py are consistent and correctly implement this new addressing logic.

Additionally, the PR includes several important correctness fixes, such as preventing NaN values in softmax calculations, avoiding division by zero, and correcting the sm_scale calculation. These are all valuable improvements.

My main concern, detailed in a specific comment, is the use of a fragile heuristic to detect whether the block table contains pointers or indices in prefix_prefill.py. Replacing this with an explicit flag would improve robustness and consistency across the codebase.

vllmellm and others added 3 commits December 29, 2025 03:31
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@tjtanaa
Copy link
Copy Markdown
Collaborator

tjtanaa commented Dec 29, 2025

@vllmellm Please check, the changes are affecting the AMD CI.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm requested a review from mgoin as a code owner December 30, 2025 04:14
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@tjtanaa
Copy link
Copy Markdown
Collaborator

tjtanaa commented Dec 31, 2025

@ganyi1996ppo @zejunchen-zejun can you take a look at this PR? Thank you.

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 6, 2026
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm requested a review from WoosukKwon as a code owner January 8, 2026 10:24
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm force-pushed the fix-rocm-qwen3-next-inference branch from 0133f5e to ddf33fa Compare January 9, 2026 04:32
Copy link
Copy Markdown
Collaborator

@tjtanaa tjtanaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Amazing work. Thank you so much

@tjtanaa tjtanaa merged commit 1a19e9c into vllm-project:main Jan 9, 2026
53 checks passed
@vllmellm
Copy link
Copy Markdown
Contributor Author

vllmellm commented Jan 9, 2026

Hi @tjtanaa , Thanks for your approval.

akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…non-standard block size (544) support under rocm_atten (vllm-project#31380)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…non-standard block size (544) support under rocm_atten (vllm-project#31380)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
@mergify mergify bot added the bug Something isn't working label Jan 28, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…non-standard block size (544) support under rocm_atten (vllm-project#31380)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
khairulkabir1661 added a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…lock size (PR vllm-project#31380)

Backported from vllm-project#31380

This PR enables ROCm attention backend to support non-power-of-2 KV block
sizes, specifically fixing Qwen3-Next-80B-A3B-Thinking model inference.

Key Changes:
- Generalized Triton kernels to handle non-standard block sizes (e.g., 544)
- Added dynamic dispatching based on is_pow2 check for optimal performance
- Updated chunked_prefill_paged_decode.py with PHYSICAL_BLOCK_SIZE support
- Updated prefix_prefill.py with generalized physical addressing
- Extended triton_reshape_and_cache_flash.py for head-major cache layout
- Added test_qwen3_nonstandard_block_size test case

Impact:
- Fixes Qwen3-Next-80B model inference on ROCm
- Maintains performance for power-of-2 block sizes
- GSM8K accuracy improved from 0.67% to 96%

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
khairulkabir1661 added a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…lock size (PR vllm-project#31380)

Backported from vllm-project#31380

This PR enables ROCm attention backend to support non-power-of-2 KV block
sizes, specifically fixing Qwen3-Next-80B-A3B-Thinking model inference.

Key Changes:
- Generalized Triton kernels to handle non-standard block sizes (e.g., 544)
- Added dynamic dispatching based on is_pow2 check for optimal performance
- Updated chunked_prefill_paged_decode.py with PHYSICAL_BLOCK_SIZE support
- Updated prefix_prefill.py with generalized physical addressing
- Extended triton_reshape_and_cache_flash.py for head-major cache layout
- Added test_qwen3_nonstandard_block_size test case

Impact:
- Fixes Qwen3-Next-80B model inference on ROCm
- Maintains performance for power-of-2 block sizes
- GSM8K accuracy improved from 0.67% to 96%

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
khairulkabir1661 added a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…lock size (PR vllm-project#31380)

Backported from vllm-project#31380

This PR enables ROCm attention backend to support non-power-of-2 KV block
sizes, specifically fixing Qwen3-Next-80B-A3B-Thinking model inference.

Key Changes:
- Generalized Triton kernels to handle non-standard block sizes (e.g., 544)
- Added dynamic dispatching based on is_pow2 check for optimal performance
- Updated chunked_prefill_paged_decode.py with PHYSICAL_BLOCK_SIZE support
- Updated prefix_prefill.py with generalized physical addressing
- Extended triton_reshape_and_cache_flash.py for head-major cache layout
- Added test_qwen3_nonstandard_block_size test case

Impact:
- Fixes Qwen3-Next-80B model inference on ROCm
- Maintains performance for power-of-2 block sizes
- GSM8K accuracy improved from 0.67% to 96%

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
khairulkabir1661 added a commit to khairulkabir1661/vllm that referenced this pull request Mar 26, 2026
…lock size (PR vllm-project#31380)

Backported from vllm-project#31380

This PR enables ROCm attention backend to support non-power-of-2 KV block
sizes, specifically fixing Qwen3-Next-80B-A3B-Thinking model inference.

Key Changes:
- Generalized Triton kernels to handle non-standard block sizes (e.g., 544)
- Added dynamic dispatching based on is_pow2 check for optimal performance
- Updated chunked_prefill_paged_decode.py with PHYSICAL_BLOCK_SIZE support
- Updated prefix_prefill.py with generalized physical addressing
- Extended triton_reshape_and_cache_flash.py for head-major cache layout
- Added test_qwen3_nonstandard_block_size test case

Impact:
- Fixes Qwen3-Next-80B model inference on ROCm
- Maintains performance for power-of-2 block sizes
- GSM8K accuracy improved from 0.67% to 96%

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
Arist12 added a commit to amdpilot-org/amdpilot-evals that referenced this pull request Apr 4, 2026
New bugfix eval instances curated from sglang, vllm, and aiter repos:

- aiter-mxfp4-rounding-fix (ROCm/aiter#2249)
- sglang-json-nonfinite-fix (sgl-project/sglang#20714)
- vllm-corrupt-image-400 (vllm-project/vllm#38253)
- vllm-mxfp4-moe-fallback (vllm-project/vllm#35893)
- vllm-rocm-attn-blocksize-qwen35 (vllm-project/vllm#35923)
- vllm-rocm-cross-attn-dispatch (vllm-project/vllm#38450)
- vllm-rocm-fused-moe-fix (vllm-project/vllm#36100)
- vllm-rocm-lru-cache-fix (vllm-project/vllm#37547)
- vllm-rocm-nonpow2-blocksize (vllm-project/vllm#31380)

All instances validated end-to-end: test FAILS without fix, PASSES
with fix. vllm instances use rocm/vllm-dev base image.
Arist12 added a commit to amdpilot-org/amdpilot-evals that referenced this pull request Apr 6, 2026
* feat(evals): add optimization instances, fix bugs, update model config

- Add aiter-sigmoid-fastmath eval instance (Hard, aiter PR #1879)
- Add aiter-mla-reduce-optimize eval instance (Very Hard, aiter PR #1896)
- Fix curate_eval.py EVALS_DIR path bug (was evals/evals/instances)
- Add AMDPILOT_MODEL_URL env var support in curate_eval.py, run_issue.py
- Update all task.yaml model endpoints from hardcoded internal IP to
  localhost default (override via AMDPILOT_MODEL_URL env var)
- Update README with correct instance count and harness classifications

Made-with: Cursor

* refactor(config): remove dev-specific model name, GPU from task.yaml

- Remove hardcoded model: "qwen-3.5" from all 9 task.yaml files
  (model now comes from AMDPILOT_MODEL env var)
- Remove hardcoded gpu: "0" from all task.yaml container sections
  (GPU now comes from AMDPILOT_GPU env var or --gpu CLI flag)
- Fix classify_pr/classify_issue returning "fix" instead of "bugfix"
- Remove model name from generated task.yaml templates

Made-with: Cursor

* fix(eval): remove data leak from sglang-fused-moe-fix test harness

Replaced source-level inspection (checking for _is_cuda guards around
get_global_server_args calls) with pure behavioral tests: module imports
without NameError, function is accessible, and fresh subprocess import
succeeds. The old harness explicitly told the agent what fix to apply.

Made-with: Cursor

* fix(evals): fix data leak, JIT rebuild, and test harness bugs in optimization evals

- All 4 optimization Dockerfiles: delete pre-compiled JIT .so modules
  to force rebuild from reverted source (fixes incorrect cached binaries)
- All 4 optimization Dockerfiles: squash git history after source revert
  to prevent agent from viewing optimized solution via git diff/log
- aiter-mla-reduce-optimize: rewrite performance test to use persistent-mode
  MLA decode which actually exercises the reduce.cu kernel (old test
  measured PyTorch SDPA which is unrelated to the reduce kernel)
- aiter-sigmoid-fastmath: fix test to use aiter.ops.aiter_operator.sigmoid
  instead of non-existent generic activation() function

Made-with: Cursor

* fix(evals): adjust test thresholds and fix edge cases

- sigmoid: lower threshold to 18us (baseline ~22us, optimized ~15us),
  fix 1D tensor edge case, relax silu_and_mul tolerance
- mla-reduce: adjust threshold to 25us (baseline ~28us, optimized ~23us)
- all Dockerfiles: add git remote removal for complete data leak prevention

Made-with: Cursor

* fix(evals): adjust moe-align threshold from 140us to 175us

The PR's optimization achieves ~170us for E=64 (our test config), making
the original 140us target unreachable. The new 175us threshold requires
meaningful optimization (~10% over baseline ~192us) while being achievable.

Made-with: Cursor

* feat(scripts): use platform auto-detection for base image in run_issue.py

Replace hardcoded BASE_IMAGES dict with auto-detection from
amdpilot.orchestrator.platform. The issue resolver now automatically
selects the correct rocm/sgl-dev image for the host GPU and ROCm version.

Made-with: Cursor

* fix(evals): add JIT cache deletion to all aiter task descriptions

The agent was modifying kernel .cu files but changes had no effect
because AITER's JIT-compiled .so modules were not being rebuilt.
Updated all 4 aiter optimization task descriptions to include the
correct rebuild command with explicit JIT cache deletion.

Root cause of MOE align failure: agent tried 19+ optimization
strategies but none took effect because the old cached binary
was always used.

Made-with: Cursor

* feat(eval): add sglang-glm5-optimize instance for MI355X

Adapted from xsun_wip branch glm5-optimize job for 8× MI355X GPUs.
Includes Dockerfile with sglang config patch for glm_moe_dsa model type,
pre-built benchmark script, and focused task description.

* fix(eval): use sglang 20260311 base image for GLM-5 AMD support

The 20260311 image auto-selects tilelang NSA decode backend on AMD GPUs,
fixing the flash_attn_with_kvcache NameError. Upgrade transformers in
Dockerfile for native glm_moe_dsa config loading. Remove manual patch.

* feat(eval): strengthen glm5 task to require source-level optimizations

Expanded task description with concrete kernel-level targets: attention
backend tuning, MoE dispatch profiling, all-reduce analysis, CUDA graph
capture checks. Config-only tuning is explicitly marked as insufficient.

* feat(sglang-kimi-moe-tune): add Tier 0 profiling requirement to test harness (#2)

Restructured scoring from 10+10+80 to 15+10+10+65:
- Tier 0 (15 pts): Profiling evidence -- checks for rocprof output files
  (results.stats.csv) and profiling references in optimization_state.json
- Tier 3: Reduced from 80 to 65 pts, capped at 50% without profiling

Without profiling, max achievable score drops from 100 to ~52.5.
This forces the agent to actually run rocprof/torch.profiler instead of
skipping straight to blind config tuning.

Co-authored-by: jhinpan <311651cb+jhinpan@users.noreply.github.com>

* fix(sglang-glm5-optimize): improve task description with backend detection guidance

* fix(sglang-glm5-optimize): add benchmark timeout guidance for long-loading models

* refactor(glm5): simplify task description, remove overfitting

Remove step-by-step optimization walkthrough and leaked targets/baselines.
Keep only environment description, benchmark instructions, and rules.
Optimization knowledge belongs in agent skills, targets in supervisor stages.

* feat(glm5): add fast benchmark script and enable thinking

- Add bench_glm5_fast.sh: uses --disable-cuda-graph for faster iteration
  during profiling. Uses `| tee` instead of $() for real-time output
  streaming when run in background.
- Enable kimi_cli.thinking: true — required for Qwen3.5 thinking model
  to properly return content via reasoning_content field.
- Add fast_profile_command to task.yaml pointing to the fast script.

* fix(glm5): simplify task spec — remove unnecessary fast benchmark

Remove bench_glm5_fast.sh and fast_profile_command. The agent can use
--disable-cuda-graph on its own if needed. Simplify task_description.md
to match the original xsun_wip style. Keep thinking:true (required for
Qwen3.5 on sglang).

* fix(glm5): correct benchmark timing and add backend info

Benchmark takes ~25 min (not 50-60 min). Document that bench_one_batch
supports backend selection flags (--attention-backend, etc.) so agents
know backends can be configured.

* fix(glm5): source bench_config.env for reproducible verification

The benchmark script now sources /workspace/bench_config.env if present.
This lets the agent persist environment variables (e.g. backend selection)
in a file that both the agent's run and the orchestrator's verification
run will use, ensuring consistent results.

* fix(glm5): update benchmark time estimate for local NVMe

With model weights on local NVMe instead of NFS, first benchmark
run takes ~5 minutes (was ~25 minutes on NFS). Update the executor
guidance accordingly.

* chore(glm5): document local NVMe volume path in task.yaml

* feat(evals): add sglang-qwen-vl-optimize task instance

Qwen3-VL serving throughput optimization on MI355X. SGLang has a 33%
regression vs vLLM (1235 vs 1648 tok/s). Self-contained bench_serving
benchmark, PYTHONPATH fix for sglang.benchmark.datasets, fork cloned
to /workspace/sglang-fork/ to avoid namespace shadowing.

* fix(evals): add timeouts to bench_qwen_vl.sh to prevent hangs

bench_serving blocks indefinitely when the sglang server enters a stuck
graceful-shutdown state (common with aiter backend on VL models). Three
fixes:

- Wrap both warmup and benchmark bench_serving calls with `timeout`
  (default 900s, configurable via BENCH_SERVING_TIMEOUT)
- Use kill -9 in cleanup instead of SIGTERM (hung servers ignore SIGTERM)
- Kill ALL sglang child processes (scheduler, detokenizer) and free the
  port with fuser on cleanup, not just the launch_server parent
- Bump recommended agent timeout from 1200 to 2400s in task description

* fix(evals): lock triton backend in bench_qwen_vl.sh, reframe task as regression fix

The previous benchmark allowed the agent to override ATTENTION_BACKEND
via bench_config.env, enabling a bypass (switch to aiter = 2000 tok/s)
instead of fixing the actual triton regression (1235 tok/s).

- Hardcode ATTENTION_BACKEND="triton" in the benchmark script
- Remove ATTENTION_BACKEND from bench_config.env support
- Rewrite task description: fix must be source-level changes to the
  triton attention path, not a backend switch
- Update investigation areas to focus on triton kernel tuning, CUDA
  graph interaction, and VL-specific decode inefficiencies

* add sglang-kimi-k25-optimize eval instance

* update kimi-k25 task: switch executor to Kimi-K2.5, remove result leakage

- Update model_endpoint to moonshotai/Kimi-K2.5
- Rewrite task_description.md: remove all prior-run result leakage,
  make optimization approach fully flexible (no backend restrictions),
  update delivery branch to v3

* fix eval tasks: remove deprecated instance, fix test harness

- Remove aiter-moe-align-optimize instance (deprecated)
- Fix vllm-ck-mxfp4-moe test harness

* feat: LLM-powered SFT data curation pipeline

Three-phase pipeline that internalizes nudge agent signals into executor
trajectories for on-policy SFT training:

Phase 1 (regex): Structurally identify and remove _steer tool calls +
nudge tool results from the JSONL trajectory.

Phase 2 (Claude opus): For each nudge, call Claude via AMD Gateway to
rewrite the executor's thinking so it reads as independent reasoning.
The LLM sees the full nudge content, prior context, and executor response
as plain text (~3K chars per call). Never sees JSONL directly.

Phase 3 (Claude opus validation): Final check that zero nudge traces
remain. Only flags nudge-specific references — supervisor hints from
retry_with_hints are preserved as legitimate inter-trial context.

Tested on kimi-k25-optimize (25 nudges, 9 trials) and glm5-optimize
(16 nudges, 2 trials). Zero nudge traces in curated output.

* feat(evals): add 9 validated eval instances from merged PRs

New bugfix eval instances curated from sglang, vllm, and aiter repos:

- aiter-mxfp4-rounding-fix (ROCm/aiter#2249)
- sglang-json-nonfinite-fix (sgl-project/sglang#20714)
- vllm-corrupt-image-400 (vllm-project/vllm#38253)
- vllm-mxfp4-moe-fallback (vllm-project/vllm#35893)
- vllm-rocm-attn-blocksize-qwen35 (vllm-project/vllm#35923)
- vllm-rocm-cross-attn-dispatch (vllm-project/vllm#38450)
- vllm-rocm-fused-moe-fix (vllm-project/vllm#36100)
- vllm-rocm-lru-cache-fix (vllm-project/vllm#37547)
- vllm-rocm-nonpow2-blocksize (vllm-project/vllm#31380)

All instances validated end-to-end: test FAILS without fix, PASSES
with fix. vllm instances use rocm/vllm-dev base image.

* Add 14 new AMD GPU eval instances, fix 7 existing, drop 2 broken

New instances (14):
- aiter: asm-pa-headsize-fix, lru-cache-pollution, nonpow2-blocksize-crash, splitk-buffer-fix
- sglang: cutedsl-lazy-import, fp8-w8a8-gfx950-tune, kscale-vscale-fix, mla-ps-kernel-guard, shuffle-weight-attrs
- vllm: aiter-import-fix, cache-stride-fix, mla-last-page-len, mla-nhead-fix, quark-dtype-fix, spec-decode-dispatch

Fixed existing (7):
- sglang-qwen35-rope-fix, sglang-rotary-crash: Dockerfile + test harness improvements
- vllm-ck-mxfp4-moe, vllm-encoder-rocm: base image + metadata fixes
- vllm-mxfp4-moe-fallback, vllm-rocm-attn-blocksize-qwen35, vllm-rocm-nonpow2-blocksize: test harness rewrites

Dropped (2):
- sglang-fused-moe-fix: bug untestable in container (triton deps fail before buggy line)
- sglang-kimi-moe-tune: optimization, not a bug fix

All 28 remaining instances validated: score < 100 without fix, score = 100 with fix.

* Add 5 validated vLLM ROCm eval instances (follow-up)

premature-cuda-init, dynamo-arch-crash, cache-blocksize-backend,
aiter-headsize-fallback, slidingwin-cudagraph-fix — all validated
pre-fix <100%, post-fix 100%.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Fix data leak: clean up FETCH_HEAD and git history in all Dockerfiles

The merge commit (containing the fix diff) was accessible via
FETCH_HEAD and git reflog after checkout. An agent could trivially
cat .git/FETCH_HEAD then git show to see the solution.

Fix: remove FETCH_HEAD, delete origin remote, expire reflog, and
gc prune after checkout in all 33 affected Dockerfiles.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Fix splitk harness buffer formula to match aiter _moe_sorting_impl (#6)

* Fix splitk harness buffer formula to match aiter _moe_sorting_impl

The test harness check 6a used a ceil-based formula
(ceil(token_num*topk/block_m)*block_m) to compute sorted_token_ids
length, but the actual aiter kernel uses a different formula from
_moe_sorting_impl: token_num*topk + num_experts*block_m - topk.

These formulas diverge on many parameter combinations (e.g.,
DeepSeek V3 decode: tn=1, tk=8, bm=4, ne=8 gives ceil=8 vs
aiter=32), making the harness unreliable.

Changes:
- Align check 6 sorted_len formula with _moe_sorting_impl
- Add check 7: diverging formula regression (DeepSeek decode params)
  where ceil formula incorrectly reports no overflow but actual
  formula correctly detects it (32 > 8)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Check 7: assert actual formula divergence (ceil vs aiter)

Per review: Check 7 previously only computed the aiter formula,
so it would pass even if the harness reverted to ceil. Now computes
both paths and asserts the divergence explicitly:
- ceil_overflow=False (ceil(8/4)*4 = 8, no overflow)
- aiter_overflow=True (8 + 8*4 - 8 = 32, overflow detected)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* Add aiter-blockscale-stream-fix eval instance

FP8 GEMM kernel does not respect caller HIP stream context,
breaking non-default stream callers. Source inspection + runtime
harness validates the fix (aiter PR #2520).

* fix: strip solution leakage from 33 eval instances

Remove Fix: docstrings, Affected Files sections, and inline file path
references that disclosed solutions to the agent. Instances now describe
symptoms only, requiring independent diagnosis.

Changes across 33 instances (23 LEAK + 10 BORDERLINE):
- 18 harness docstrings: stripped Fix: paragraphs
- 26 task descriptions: removed Affected Files sections
- 10 task descriptions: replaced inline file paths with generic refs
- 7 instances: individual task description and docstring rewrites

* fix(eval): harden sglang-speculative-decode-fix harness for intermittent bug

- Increase prompts from 8 to 20 to reduce false positive probability
  (from ~2.3% to <10^-8 per round with 62.5% per-prompt coherent rate)
- Require 2 consecutive passing rounds for score 100.0
- Remove root-cause hint from task description (purely behavioral now)

Previous false positive: agent scored 100.0 with zero code changes on
unmodified codebase because 8/8 prompts happened to pass by chance.

* fix(eval): restart server between rounds for independent evidence

Address review feedback from @alex:
- Each round now starts a fresh server and shuts it down afterward,
  ensuring rounds are independent (different GPU init, stream state)
- Remove "non-speculative inference under TP=2" clue from task
  description — not source-grounded in the original issue report

---------

Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
Co-authored-by: jhinpan <311651cb+jhinpan@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug][ROCm]: Failed to send request to Qwen3-Next

2 participants