
[Diffusion] Speed up Qwen select01 Triton modulation kernels #21318

Merged
BBuf merged 1 commit into main from optimize_qwen_select01_kernel on Mar 25, 2026

Conversation

@BBuf (Collaborator) commented Mar 24, 2026

Summary

This PR keeps the Qwen select01 Triton kernel version that showed a stable end-to-end win in Qwen-Image denoise.

The final change set:

  • switches modulation loads to pointer-select so each row only loads the chosen scale/shift/gate branch
  • pins both select01 launches to num_warps=4, num_stages=4
  • drops later scalar-base / 8w1s / residual-only experiments from the active code path because they did not produce stable model-level gains

Made with Codex (AKO4ALL framework and SGLang Diffusion SKILL).

Implementation

The optimized kernels are:

  • fuse_layernorm_scale_shift_gate_select01_kernel
  • fuse_residual_layernorm_scale_shift_gate_select01_kernel

Main code changes (a minimal pointer-select sketch follows this list):

  • build branch-specific pointer tensors for scale0/1, shift0/1, and gate0/1
  • select pointers with tl.where(idx, ...)
  • issue only one load for each modulation tensor on the chosen branch
  • keep the validated 4w4s launch config for both kernels
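
The bullets above describe the actual kernels in python/sglang/jit_kernel/diffusion/triton/scale_shift.py. As a rough stand-in, a minimal sketch of the pointer-select pattern could look like the following; the names, the (rows, N) shapes, the plain LayerNorm without affine weights, and the omitted gate path are simplifying assumptions, not the repository code:

```python
# Minimal sketch of the pointer-select pattern (not the repository kernel).
import torch
import triton
import triton.language as tl


@triton.jit
def _ln_scale_shift_select01_kernel(
    x_ptr, y_ptr,
    scale0_ptr, scale1_ptr,
    shift0_ptr, shift1_ptr,
    idx_ptr,                       # one 0/1 flag per row: which modulation branch to use
    N, eps,
    BLOCK_N: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    x = tl.load(x_ptr + row * N + cols, mask=mask, other=0.0).to(tl.float32)

    # LayerNorm over the hidden dimension (no affine weights in this sketch).
    mean = tl.sum(x, axis=0) / N
    var = tl.sum(tl.where(mask, (x - mean) * (x - mean), 0.0), axis=0) / N
    x_hat = (x - mean) / tl.sqrt(var + eps)

    # Pointer-select: build both branch pointer blocks, pick one with tl.where,
    # then issue a single load per modulation tensor instead of loading both.
    sel = tl.load(idx_ptr + row) != 0
    scale = tl.load(tl.where(sel, scale1_ptr + cols, scale0_ptr + cols),
                    mask=mask, other=0.0).to(tl.float32)
    shift = tl.load(tl.where(sel, shift1_ptr + cols, shift0_ptr + cols),
                    mask=mask, other=0.0).to(tl.float32)

    y = x_hat * (1.0 + scale) + shift
    tl.store(y_ptr + row * N + cols, y.to(y_ptr.dtype.element_ty), mask=mask)


def ln_scale_shift_select01(x, scale0, scale1, shift0, shift1, idx, eps=1e-6):
    rows, N = x.shape
    y = torch.empty_like(x)
    # Pinned launch config from this PR: num_warps=4, num_stages=4.
    _ln_scale_shift_select01_kernel[(rows,)](
        x, y, scale0, scale1, shift0, shift1, idx, N, eps,
        BLOCK_N=triton.next_power_of_2(N), num_warps=4, num_stages=4,
    )
    return y
```

The key step is the `tl.where` over pointer blocks: each row resolves its branch once and then issues a single load per modulation tensor, which is what later shows up as fewer loads and executed instructions in the profile.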

Validation

Correctness (an illustrative eager-reference check is sketched after these commands):

  • python -m py_compile python/sglang/jit_kernel/diffusion/triton/scale_shift.py
  • pytest -q python/sglang/jit_kernel/tests/test_qwen_image_modulation.py
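
As referenced above, an eager-mode PyTorch check in the same spirit could look like this; it is not the repository test, `ln_scale_shift_select01` is the sketch wrapper from the Implementation section, and the shapes and tolerances are assumptions:

```python
# Illustrative numerics check against an eager PyTorch reference.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rows, N = 64, 3072
x = torch.randn(rows, N, device="cuda", dtype=torch.float16)
scale0, scale1, shift0, shift1 = (
    torch.randn(N, device="cuda", dtype=torch.float16) for _ in range(4)
)
idx = torch.randint(0, 2, (rows,), device="cuda", dtype=torch.int32)

def ref(x, scale0, scale1, shift0, shift1, idx, eps=1e-6):
    x_hat = F.layer_norm(x.float(), (N,), eps=eps)
    sel = idx.bool()[:, None]          # pick branch 0 or 1 per row
    scale = torch.where(sel, scale1, scale0).float()
    shift = torch.where(sel, shift1, shift0).float()
    return x_hat * (1.0 + scale) + shift

y = ln_scale_shift_select01(x, scale0, scale1, shift0, shift1, idx)
torch.testing.assert_close(y.float(), ref(x, scale0, scale1, shift0, shift1, idx),
                           rtol=2e-2, atol=2e-2)
```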

Performance:

| Benchmark | Baseline | New | Delta |
| --- | --- | --- | --- |
| AKO aggregate microbench | 0.039589 ms | 0.036297 ms | 1.090681x |
| layernorm (2,2048,3072) | 0.036896 ms | 0.031680 ms | 1.164647x |
| residual (2,2048,3072) | 0.049248 ms | 0.046528 ms | 1.058459x |
| Qwen DenoisingStage | 12788.15 ms | 12432.77 ms | -355.39 ms (-2.8%) |
| Qwen E2E | 12838.08 ms | 12526.77 ms | -311.30 ms (-2.4%) |
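
The microbench rows come from the internal AKO harness; a rough stand-in for the layernorm (2,2048,3072) case using `triton.testing.do_bench` and the sketch wrapper would look like this (absolute numbers will not match the table):

```python
# Rough microbench stand-in; not the repository's AKO harness, and timings
# will differ by GPU and Triton version.
import torch
import triton

x = torch.randn(2 * 2048, 3072, device="cuda", dtype=torch.float16)
scale0, scale1, shift0, shift1 = (
    torch.randn(3072, device="cuda", dtype=torch.float16) for _ in range(4)
)
idx = torch.randint(0, 2, (x.shape[0],), device="cuda", dtype=torch.int32)

ms = triton.testing.do_bench(
    lambda: ln_scale_shift_select01(x, scale0, scale1, shift0, shift1, idx)
)
print(f"fused layernorm select01: {ms:.6f} ms")
```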

Nsight Compute

A representative ncu check on the layernorm select01 kernel at (2,2048,3072) shows:

  • gpu__time_duration.avg: 35.744 us -> 28.704 us
  • launch__registers_per_thread: 96 -> 72

The optimized kernel cuts single-launch latency by 19.7% versus the baseline (35.74 us -> 28.70 us) while also increasing L2 and DRAM throughput. This indicates the win is not just a launch-parameter artifact: the kernel does less wasted work and uses the memory hierarchy more effectively.
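
An ncu invocation of roughly this shape reproduces such a check; the kernel-name regex and script name below are placeholders, not the exact command used:

```
ncu --kernel-name "regex:select01" --launch-count 1 \
    --metrics gpu__time_duration.avg,launch__registers_per_thread \
    python bench_select01.py
```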

[Nsight Compute screenshot: duration and memory throughput]

Executed and issued instruction counts both drop by about 18.5%, which is consistent with the kernel rewrite: the pointer-select path avoids loading and computing both modulation branches before selecting one.

[Nsight Compute screenshot: executed/issued instruction counts]

This kernel is register-limited, and lowering register pressure improves occupancy materially. In this tuned version, registers per thread drop from 96 to 72, which raises both theoretical and achieved occupancy and improves latency hiding.
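
A back-of-envelope view of the register-limited occupancy, assuming a 64K-entry per-SM register file and the 4-warp blocks used here (real limits and allocation granularity vary by GPU, so treat the numbers as indicative only):

```python
# Register-limited resident blocks/warps per SM, under assumed hardware limits.
REGFILE_PER_SM = 65536            # 32-bit registers per SM on many recent NVIDIA GPUs
THREADS_PER_BLOCK = 4 * 32        # num_warps=4

def reg_limited_blocks(regs_per_thread: int) -> int:
    return REGFILE_PER_SM // (regs_per_thread * THREADS_PER_BLOCK)

for regs in (96, 72):
    blocks = reg_limited_blocks(regs)
    print(f"{regs} regs/thread -> {blocks} blocks -> {blocks * 4} resident warps per SM")
# -> roughly 20 resident warps at 96 regs/thread vs. 28 at 72 regs/thread
```

The exact achieved occupancy also depends on shared memory and per-SM block limits, but the direction matches the profile: fewer registers per thread allow more resident warps and better latency hiding.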

[Nsight Compute screenshot: occupancy and register usage]


Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Qwen-Image denoise operations by refining the underlying Triton modulation kernels. The changes focus on reducing redundant memory operations and standardizing kernel launch parameters, leading to a measurable speedup in both microbenchmarks and end-to-end denoising stages. The improvements are validated through correctness tests and detailed performance profiling, demonstrating a more efficient utilization of GPU resources.

Highlights

  • Performance Optimization: Optimized Qwen select01 Triton modulation kernels by switching modulation loads to a pointer-select mechanism, ensuring only the chosen scale/shift/gate branch is loaded per row.
  • Kernel Configuration Pinning: Pinned the launch configurations for both fuse_layernorm_scale_shift_gate_select01_kernel and fuse_residual_layernorm_scale_shift_gate_select01_kernel to num_warps=4 and num_stages=4 for stable performance gains.
  • Code Simplification: Removed experimental scalar-base, 8w1s, and residual-only code paths that did not yield stable model-level gains, streamlining the active codebase.


@BBuf (Collaborator, Author) commented Mar 24, 2026

/tag-and-rerun-ci

@gemini-code-assist (bot) left a comment

Code Review

This pull request implements a significant performance optimization for the Qwen select01 Triton modulation kernels. By switching to pointer-select for modulation loads, the kernels now only load the necessary scale/shift/gate branch, avoiding redundant memory accesses. Pinning num_warps=4 and num_stages=4 further refines the kernel launch configuration. The provided Nsight Compute analysis and benchmarks clearly demonstrate the positive impact of these changes, showing reduced GPU time, lower register usage, and improved throughput. The changes are well-justified and directly address the goal of speeding up these kernels.

@yhyang201 (Collaborator) commented:

/tag-and-rerun-ci

@yhyang201 (Collaborator) commented:

/rerun-failed-ci

BBuf merged commit 68f7f00 into main on Mar 25, 2026 (342 of 391 checks passed).
BBuf deleted the optimize_qwen_select01_kernel branch on March 25, 2026 at 12:48.
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026