[AMD] fused qk gemma norm kernels to reduce four kernels #23575
HaiShaw merged 1 commit into sgl-project:main
…se + 2 norm) into one triton kernel
Code Review
This pull request introduces a fused Triton kernel for Gemma RMSNorm to optimize the QK normalization process in Qwen 3.5 models on HIP-supported hardware. The changes aim to improve efficiency by handling both query and key normalization in a single pass. Feedback indicates that the kernel's hardcoded output types may cause data corruption with float32 inputs and that the use of reshape in the wrapper function might lead to unnecessary memory copies, contradicting the performance goals mentioned in the documentation.
Passes input strides to the kernel so non-contiguous tensors (e.g. from
qkv.split()) are read correctly without an extra .contiguous() copy.
"""
q_flat = q.reshape(-1, head_dim)
k_flat = k.reshape(-1, head_dim)
The docstring claims to avoid an extra .contiguous() copy by passing strides, but q.reshape(-1, head_dim) copies whenever the requested shape cannot be expressed as a view of the input, which is the usual case for slices from qkv.split(). To truly avoid a copy, the kernel should accept the original multi-dimensional tensor and its strides, or the wrapper should use view and handle contiguity errors explicitly.
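The stride condition behind this comment can be sketched in plain Python. This is a minimal illustration, not code from the PR: the helper name `flatten_heads_is_view` and the sizes used below are hypothetical, and it assumes a 2-D slice of shape (num_tokens, num_heads * head_dim) whose last dimension has stride 1.

```python
def flatten_heads_is_view(num_tokens, num_heads, head_dim, row_stride):
    """Return True if a (num_tokens, num_heads * head_dim) tensor with
    strides (row_stride, 1) can become (num_tokens * num_heads, head_dim)
    without copying.

    Splitting the last dimension into (num_heads, head_dim) is always a
    view. Merging the token and head dimensions is a view only when
    stepping one token equals stepping num_heads heads, i.e. when
    row_stride == num_heads * head_dim. With at most one token the row
    stride is never used, so any stride works.
    """
    if num_tokens <= 1:
        return True
    return row_stride == num_heads * head_dim


# Hypothetical sizes: 16 query heads and 4 key/value heads of width 128,
# packed as [q | k | v] in each row, so one row spans 2048 + 512 + 512.
packed_row = 16 * 128 + 4 * 128 + 4 * 128

print(flatten_heads_is_view(8, 16, 128, 16 * 128))    # contiguous q
print(flatten_heads_is_view(8, 16, 128, packed_row))  # q sliced from qkv
```

Under these assumed sizes the sliced q keeps the packed row stride, so flattening it to (-1, head_dim) cannot be a view and reshape has to materialize a copy.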
/tag-and-rerun-ci
@amd-bot ci-status |
…#23575) Co-authored-by: root <root@smci355-ccs-aus-g12-26.cs-aus.dcgpu>
Co-author: @hubertlu-tw
Motivation
Profiling shows that the apply_qk_norm function launches four kernels on the ROCm platform, whereas on the CUDA platform it launches two kernels that overlap. To reduce end-to-end time, this PR fuses the four kernels into one Triton kernel.
Modifications
models/utils.py
Add the Triton implementation of the fused kernel
models/qwen3_5.py
Use the fused Triton kernel when running on the HIP path
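As a rough reference for what a fused QK Gemma RMSNorm computes, the following pure-Python sketch normalizes every query and key head in one pass over a packed buffer, using a row stride instead of copying. It is not the PR's Triton kernel: the function names, the [q | k | v] row layout, and the default eps are assumptions made for illustration. The only part taken as given is the Gemma-style RMSNorm formula, x / sqrt(mean(x^2) + eps) * (1 + weight).

```python
import math


def gemma_rmsnorm_row(buf, offset, head_dim, weight, eps):
    """Gemma-style RMSNorm of one head read from a flat buffer."""
    xs = buf[offset:offset + head_dim]
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in xs) / head_dim + eps)
    # Gemma scales by (1 + weight) rather than by weight alone.
    return [x * inv_rms * (1.0 + w) for x, w in zip(xs, weight)]


def fused_qk_gemma_norm_ref(qkv, num_tokens, row_stride, q_heads, k_heads,
                            head_dim, q_weight, k_weight, eps=1e-6):
    """Normalize all q and k heads in a single loop over tokens.

    Assumed layout (hypothetical): each token row is [q | k | v] and
    consecutive rows are row_stride elements apart, so q and k are read
    in place from the packed buffer.
    """
    q_out, k_out = [], []
    k_base = q_heads * head_dim  # offset of the first key head in a row
    for t in range(num_tokens):
        row = t * row_stride
        for h in range(q_heads):
            q_out.append(gemma_rmsnorm_row(
                qkv, row + h * head_dim, head_dim, q_weight, eps))
        for h in range(k_heads):
            k_out.append(gemma_rmsnorm_row(
                qkv, row + k_base + h * head_dim, head_dim, k_weight, eps))
    return q_out, k_out


# One token, one q head, one k head, head_dim 2, followed by 2 v values.
q, k = fused_qk_gemma_norm_ref(
    [3.0, 4.0, 1.0, 1.0, 9.0, 9.0],
    num_tokens=1, row_stride=6, q_heads=1, k_heads=1, head_dim=2,
    q_weight=[0.0, 0.0], k_weight=[1.0, 1.0], eps=0.0)
```

A reference like this is mainly useful for checking a fused kernel's output against the unfused path on small inputs; it says nothing about performance.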
Accuracy Tests
Server launch command
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci