Skip to content

[AMD] fused qk gemma norm kernels to reduce four kernels #23575

Merged
HaiShaw merged 1 commit intosgl-project:mainfrom
HaiShaw:fuse/qk_norm_for_qwen3_5
Apr 25, 2026
Merged

[AMD] fused qk gemma norm kernels to reduce four kernels #23575
HaiShaw merged 1 commit intosgl-project:mainfrom
HaiShaw:fuse/qk_norm_for_qwen3_5

Conversation

@kkHuang-amd
Copy link
Copy Markdown
Collaborator

@kkHuang-amd kkHuang-amd commented Apr 23, 2026

Co-author: @hubertlu-tw

Motivation

From the profiling data, apply_qk_norm function will bring 4 kernels launch on ROCm platform compared two kernels overlapped on CUDA platform. In order to reduce the e2e time cost, fused 4 kernels into one triton kernel

image

Modifications

models/utils.py
Add triton kernel implementation for fused kernel
models/qwen3_5.py
Check the path of hip to apply the fused triton kernel

Accuracy Tests

Server launch command

SGLANG_USE_AITER_UNIFIED_ATTN=1 SGLANG_USE_AITER=1 \
python3 -m sglang.launch_server \
  --model-path /dockerx/data/models/Qwen3.5-397B-A17B-FP8/ --tp 8 \
  --attention-backend aiter --trust-remote-code \
  --chunked-prefill-size 32768 \
  --model-loader-extra-config '{"enable_multithread_load": true}' \
  --watchdog-timeout 1200 --mem-fraction-static 0.9 \
  --host 0.0.0.0 --port 8000 --disable-radix-cache \
  --enable-aiter-allreduce-fusion --max-running-requests 128 \
  --page-size 16

Speed Tests and Profiling

Concurrency total token throughput before total token throughput after Ratio
4 864.11 883.77 +2.2%
8 1582.05 1624.43 +2.6%
16 2888.91 2945.39 +1.9%
32 5005.01 5021.38 +0.3%
64 7609.9 7634.51 +0.3%

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a fused Triton kernel for Gemma RMSNorm to optimize the QK normalization process in Qwen 3.5 models on HIP-supported hardware. The changes aim to improve efficiency by handling both query and key normalization in a single pass. Feedback indicates that the kernel's hardcoded output types may cause data corruption with float32 inputs and that the use of reshape in the wrapper function might lead to unnecessary memory copies, contradicting the performance goals mentioned in the documentation.

Comment thread python/sglang/srt/models/utils.py
Comment on lines +518 to +522
Passes input strides to the kernel so non-contiguous tensors (e.g. from
qkv.split()) are read correctly without an extra .contiguous() copy.
"""
q_flat = q.reshape(-1, head_dim)
k_flat = k.reshape(-1, head_dim)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring claims to avoid an extra .contiguous() copy by passing strides, but q.reshape(-1, head_dim) will internally trigger a copy if the tensor is non-contiguous (which is common for slices from qkv.split()). To truly avoid a copy, the kernel should be designed to accept the original multi-dimensional tensor and its strides, or you should use view and handle potential contiguity errors explicitly.

@kkHuang-amd kkHuang-amd marked this pull request as ready for review April 24, 2026 01:53
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@kkHuang-amd kkHuang-amd changed the title [Opt] fused qk gemma norm kernels to reduce four kernels [AMD] fused qk gemma norm kernels to reduce four kernels Apr 24, 2026
@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented Apr 24, 2026

/tag-and-rerun-ci

@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented Apr 25, 2026

@amd-bot ci-status

@HaiShaw HaiShaw merged commit 393252f into sgl-project:main Apr 25, 2026
141 of 175 checks passed
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
…#23575)

Co-authored-by: root <root@smci355-ccs-aus-g12-26.cs-aus.dcgpu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants