
[Diffusion] Add qknorm rope fuse kernel #21440

Merged
BBuf merged 16 commits into main from add_qknorm_rope_fuse_kernel on Mar 27, 2026
Conversation

@BBuf (Collaborator) commented Mar 26, 2026

Summary

Made with Codex and this skill: https://github.com/sgl-project/sglang/tree/main/python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-ako4all-kernel

$Radixark03 SGLang $SGLang AKO4ALL Kernel $Sglang Diffusion Benchmark Profile Help me continue optimizing diffusion kernels in sglang diffusion on top of the AKO4ALL framework; before running any model or benchmark, make sure the GPU in use is completely idle. I need you to optimize a common pattern in diffusion models: QK norm + RoPE fusion. If you look at the diffusion model implementations, this pattern already appears extensively, but it currently calls the jit_kernel QK norm and the FlashInfer RoPE implementation separately, with no fusion. Implement this fused kernel in jit_kernel.

This PR adds a new JIT CUDA kernel that fuses QK RMSNorm and RoPE into a single in-place kernel for diffusion models.

It also wires the fused path into the main diffusion DiT implementations that already use the QK norm + RoPE pattern, while keeping the existing split path as a fallback.

(Attached screenshots: qknorm_rope_flux, qknorm_rope_flux2, qknorm_rope_qwen_edit, qknorm_rope_qwen, qknorm_rope_zimage)

What Changed

  • Added a new fused JIT kernel for QK RMSNorm + RoPE:

    • python/sglang/jit_kernel/csrc/elementwise/qknorm_rope.cuh
    • python/sglang/jit_kernel/qknorm_rope.py
  • Added a shared runtime helper:

    • python/sglang/multimodal_gen/runtime/layers/layernorm.py
    • apply_qk_norm_rope(...) uses the fused kernel when the shape/dtype/layout is supported, and falls back to split QK norm + FlashInfer RoPE otherwise.
  • Integrated the fused path into diffusion model implementations:

    • python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
    • python/sglang/multimodal_gen/runtime/models/dits/flux.py
    • python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
    • python/sglang/multimodal_gen/runtime/models/dits/zimage.py
  • Added correctness coverage and a dedicated micro benchmark:

    • python/sglang/jit_kernel/tests/test_qknorm_rope.py
    • python/sglang/jit_kernel/benchmark/bench_qknorm_rope.py

Key Optimization Points

  • Fuse QK RMSNorm and RoPE into a single kernel to remove an extra read/write pass over Q and K.
  • Keep the operation fully in-place on supported CUDA paths.
  • Reuse a shared runtime entry point so model code does not need model-specific kernel handling.
  • Add segmented position offset support so the fused path also works for FLUX / FLUX.2 dual-stream attention blocks, where text and image tokens use different RoPE position ranges.
  • Keep a safe fallback to the existing split implementation for unsupported cases.
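As a hedged sketch of the fallback behavior (the predicate name and exact conditions below are illustrative, not the actual API in `layernorm.py`), the dispatch amounts to a support check before choosing the fused or split path:

```python
def fused_qknorm_rope_supported(head_dim, rope_dim, dtype, contiguous):
    """Illustrative gating predicate (hypothetical, for exposition).

    The fused kernel is JIT-specialized on head_dim/rope_dim/dtype, so
    any shape/dtype/layout outside the supported set is routed to the
    existing split QK norm + FlashInfer RoPE implementation instead.
    """
    return (
        contiguous                      # kernel operates in place on contiguous q/k
        and head_dim % 2 == 0           # RoPE rotates pairs of elements
        and 0 < rope_dim <= head_dim    # rotary dim cannot exceed head dim
        and rope_dim % 2 == 0
        and dtype in ("float16", "bfloat16")
    )
```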

Fused QKNorm+RoPE Kernel Design

The new kernel fuses Q/K RMSNorm and RoPE into a single warp-level in-place CUDA kernel.

Implementation highlights:

  • Each warp processes one (token, head) work item.
  • Input values and RMSNorm weights are loaded with vectorized packed loads.
  • RMSNorm is computed fully within a warp using warp-level reduction, without shared memory.
  • The normalized values stay in registers and are immediately consumed by RoPE.
  • RoPE is applied in-register:
    • pairwise rotation for the standard layout
    • __shfl_xor_sync-based lane exchange for the Neox layout
  • Results are packed back and written in place.
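For reference, the per-(token, head) math the highlights above describe can be sketched in NumPy. This is a simplified model of the computation, not the CUDA code; `eps` and the interleaved/Neox layout conventions follow common RoPE implementations and are assumptions here:

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    # fp32 accumulation mirrors the kernel's numerically stable reduction
    x32 = x.astype(np.float32)
    inv_rms = 1.0 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    return x32 * inv_rms * weight

def rope(x, cos, sin, is_neox):
    # cos/sin hold rope_dim // 2 angles; dims beyond rope_dim pass through
    d = 2 * cos.shape[-1]
    rot, rest = x[..., :d], x[..., d:]
    if is_neox:   # halves are paired (the kernel exchanges lanes via __shfl_xor_sync)
        x1, x2 = rot[..., : d // 2], rot[..., d // 2:]
    else:         # standard layout: pairwise (even, odd) rotation
        x1, x2 = rot[..., 0::2], rot[..., 1::2]
    o1, o2 = x1 * cos - x2 * sin, x2 * cos + x1 * sin
    if is_neox:
        out = np.concatenate([o1, o2], axis=-1)
    else:
        out = np.empty_like(rot)
        out[..., 0::2], out[..., 1::2] = o1, o2
    return np.concatenate([out, rest], axis=-1)

def fused_qknorm_rope_reference(x, weight, cos, sin, is_neox=False):
    # the fused kernel keeps the normalized values in registers between
    # these two steps instead of round-tripping through global memory
    return rope(rmsnorm(x, weight), cos, sin, is_neox)
```

A reference like this is also the natural oracle for the correctness test against the in-place CUDA kernel.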

Key optimizations:

  • Eliminates the extra global memory round trip between split QKNorm and RoPE.
  • Merges Q and K processing into one kernel launch.
  • Uses vectorized loads/stores to reduce memory instructions.
  • Uses fp32 accumulation for RMSNorm for numerical stability.
  • Uses occupancy-aware launch sizing and JIT specialization on
    head_dim, rope_dim, is_neox, and dtype.

This PR also adds position_offset support in the shared runtime helper so the fused path can be used for segmented RoPE ranges in FLUX / FLUX.2 dual-stream attention blocks.
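To illustrate what position_offset enables (the helper name below is hypothetical, not the runtime API): in a dual-stream block, text and image tokens index disjoint RoPE position ranges, so each segment can be handled by the fused kernel with an offset that selects the right rows of the shared cos/sin cache:

```python
import numpy as np

def segment_positions(num_text, num_image, image_offset):
    # Hypothetical sketch: text tokens take positions [0, num_text),
    # image tokens start at image_offset; a per-segment position_offset
    # lets one fused-kernel launch per segment read the correct
    # cos/sin cache rows without materializing a merged position tensor.
    text = np.arange(num_text)
    image = image_offset + np.arange(num_image)
    return text, image
```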

Micro Benchmark

All numbers below compare the split path (jit qknorm + flashinfer rope) vs the new fused JIT kernel.

Shape notation:

  • q/k shape = [B*T, H, D]
  • rope_dim is the applied rotary dimension
| Case | Shape | Split (ms) | Fused (ms) | Speedup |
| --- | --- | --- | --- | --- |
| flux_1024 | B=1, T=4096, H=24, D=128, rope_dim=128 | 0.059520 | 0.043072 | 1.3819x |
| qwen_image_1024 | B=1, T=4096, H=32, D=128, rope_dim=128 | 0.081152 | 0.055008 | 1.4753x |
| qwen_image_partial | B=1, T=4096, H=32, D=128, rope_dim=64 | 0.079680 | 0.054560 | 1.4604x |
| zimage_1024 | B=1, T=4096, H=30, D=128, rope_dim=128 | 0.074112 | 0.051008 | 1.4529x |
| batch2_medium | B=2, T=2048, H=24, D=128, rope_dim=128 | 0.059488 | 0.043232 | 1.3760x |

Weighted micro benchmark speedup: 1.4387x
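The per-case speedups in the table follow directly from the reported timings (split divided by fused); the weighting used for the 1.4387x aggregate is not stated here, so only the per-case ratios are reproduced:

```python
# Timings (ms) copied from the micro benchmark table above
split = {
    "flux_1024": 0.059520, "qwen_image_1024": 0.081152,
    "qwen_image_partial": 0.079680, "zimage_1024": 0.074112,
    "batch2_medium": 0.059488,
}
fused = {
    "flux_1024": 0.043072, "qwen_image_1024": 0.055008,
    "qwen_image_partial": 0.054560, "zimage_1024": 0.051008,
    "batch2_medium": 0.043232,
}
speedup = {case: split[case] / fused[case] for case in split}
```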

End-to-End Denoise Stage

| Model | Split Denoise (s) | Fused Denoise (s) | Delta | Speedup |
| --- | --- | --- | --- | --- |
| qwen | 14.43 | 12.36 | -14.35% | 1.1675x |
| qwen-edit | 28.62 | 28.26 | -1.26% | 1.0127x |
| flux | 6.495 | 6.421 | -1.14% | 1.0116x |
| flux2 | 22.314 | 22.311 | -0.01% | 1.0001x |
| zimage | 0.723 | 0.712 | -1.47% | 1.0149x |

Commands used for the end-to-end denoise benchmark:

```shell
# Qwen Image (split)
sglang generate \
  --model-path=Qwen/Qwen-Image-2512 \
  --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \
  --negative-prompt=" " \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=50 \
  --guidance-scale=4.0 \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/qwen_split.json

# Qwen Image (fused)
sglang generate \
  --model-path=Qwen/Qwen-Image-2512 \
  --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \
  --negative-prompt=" " \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=50 \
  --guidance-scale=4.0 \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/qwen_fused.json

# Qwen Image Edit (split)
sglang generate \
  --model-path=Qwen/Qwen-Image-Edit-2511 \
  --prompt="Transform into anime style" \
  --negative-prompt=" " \
  --image-path=<ASSET_DIR>/cat.png \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=50 \
  --guidance-scale=4.0 \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/qwen_edit_split.json

# Qwen Image Edit (fused)
sglang generate \
  --model-path=Qwen/Qwen-Image-Edit-2511 \
  --prompt="Transform into anime style" \
  --negative-prompt=" " \
  --image-path=<ASSET_DIR>/cat.png \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=50 \
  --guidance-scale=4.0 \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/qwen_edit_fused.json

# FLUX.1-dev (split)
sglang generate \
  --model-path=black-forest-labs/FLUX.1-dev \
  --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=50 \
  --guidance-scale=4.0 \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/flux_split.json

# FLUX.1-dev (fused)
sglang generate \
  --model-path=black-forest-labs/FLUX.1-dev \
  --prompt="A futuristic cyberpunk city at night, neon lights reflecting on wet streets, highly detailed, 8k" \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=50 \
  --guidance-scale=4.0 \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/flux_fused.json

# FLUX.2-dev (split)
sglang generate \
  --model-path=black-forest-labs/FLUX.2-dev \
  --prompt="A Logo With Bold Large Text: SGL Diffusion" \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --dit-layerwise-offload false \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload true \
  --vae-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/flux2_split.json

# FLUX.2-dev (fused)
sglang generate \
  --model-path=black-forest-labs/FLUX.2-dev \
  --prompt="A Logo With Bold Large Text: SGL Diffusion" \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --dit-layerwise-offload false \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload true \
  --vae-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/flux2_fused.json

# Z-Image-Turbo (split)
sglang generate \
  --model-path=Tongyi-MAI/Z-Image-Turbo \
  --prompt="A fantasy landscape with mountains and a river, detailed, vibrant colors" \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=9 \
  --guidance-scale=0.0 \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/zimage_split.json

# Z-Image-Turbo (fused)
sglang generate \
  --model-path=Tongyi-MAI/Z-Image-Turbo \
  --prompt="A fantasy landscape with mountains and a river, detailed, vibrant colors" \
  --log-level=info \
  --seed=42 \
  --width=1024 \
  --height=1024 \
  --num-inference-steps=9 \
  --guidance-scale=0.0 \
  --dit-cpu-offload false \
  --text-encoder-cpu-offload false \
  --save-output \
  --warmup \
  --enable-torch-compile \
  --perf-dump-path outputs/qknorm_rope_pr/zimage_fused.json
```

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions Bot added diffusion SGLang Diffusion jit-kernel labels Mar 26, 2026
BBuf commented Mar 26, 2026

/tag-and-rerun-ci

BBuf commented Mar 26, 2026

@mickqian @yingluosanqian It's ready now.

@BBuf BBuf changed the title Add qknorm rope fuse kernel [Diffusion] Add qknorm rope fuse kernel Mar 26, 2026
Comment thread python/sglang/jit_kernel/csrc/elementwise/qknorm_rope.cuh Outdated
Comment thread python/sglang/jit_kernel/csrc/diffusion/qknorm_rope.cuh
Comment thread python/sglang/jit_kernel/tests/test_qknorm_rope.py Outdated
Comment thread python/sglang/jit_kernel/csrc/elementwise/qknorm_rope.cuh Outdated
```python
    head_dim=self.head_dim,
    allow_inplace=True,
)
if cos_sin_cache is not None:
```
Collaborator: could we use a helper function to generalize this logic and put it in layernorm.py?
Collaborator Author: done
@mickqian (Collaborator) left a comment: excellent

@yhyang201 (Collaborator)

/tag-and-rerun-ci

@BBuf BBuf merged commit d633ab7 into main Mar 27, 2026
143 of 209 checks passed
@BBuf BBuf deleted the add_qknorm_rope_fuse_kernel branch March 27, 2026 06:27
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026