[Diffusion] Optimize diffusion Triton rotary embedding by processing multiple heads per token #21387

Merged

BBuf merged 5 commits into main from optimize-diffusion-triton-rotary on Mar 26, 2026
Conversation

@BBuf (Collaborator) commented Mar 25, 2026

Motivation

This was made with the AKO4ALL kernel optimization skill and CodeX (gpt5.4 high).

This PR optimizes the Triton apply_rotary_embedding kernel used by diffusion models.

The main idea is to process multiple heads for the same token in one kernel program instead of launching one program per (token, head). This improves cos/sin row reuse, reduces redundant loads, and lowers launch overhead. The kernel now uses a 2D launch layout and autotunes both BLOCK_HEADS and BLOCK_HS_HALF.
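
Schematically, the new layout looks something like the sketch below. This is a minimal illustration, not the PR's actual kernel: the tensor layout, the rotate-half (NeoX-style) rotation convention, and all names (rotary_multihead_kernel, x_ptr, and so on) are assumptions, and it masks the tail of head_size // 2 instead of looping over it the way an autotuned BLOCK_HS_HALF kernel might.

    import triton
    import triton.language as tl

    @triton.jit
    def rotary_multihead_kernel(
        x_ptr,       # [num_tokens, num_heads, head_size], contiguous, updated in place
        cos_ptr,     # [num_tokens, head_size // 2]
        sin_ptr,     # [num_tokens, head_size // 2]
        num_heads,
        head_size,
        BLOCK_HEADS: tl.constexpr,
        BLOCK_HS_HALF: tl.constexpr,
    ):
        token = tl.program_id(0)       # grid axis 0: one program per token
        head_block = tl.program_id(1)  # grid axis 1: a block of heads

        heads = head_block * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)
        hs = tl.arange(0, BLOCK_HS_HALF)
        half = head_size // 2
        head_mask = heads < num_heads
        hs_mask = hs < half

        # The cos/sin row is loaded once per program and broadcast across
        # all BLOCK_HEADS heads -- the reuse the PR description refers to.
        cos = tl.load(cos_ptr + token * half + hs, mask=hs_mask, other=0.0)
        sin = tl.load(sin_ptr + token * half + hs, mask=hs_mask, other=0.0)

        # One (BLOCK_HEADS, BLOCK_HS_HALF) tile over the first and second
        # halves of each head (rotate-half convention assumed).
        base = token * num_heads * head_size + heads[:, None] * head_size
        mask = head_mask[:, None] & hs_mask[None, :]
        x1 = tl.load(x_ptr + base + hs[None, :], mask=mask, other=0.0)
        x2 = tl.load(x_ptr + base + half + hs[None, :], mask=mask, other=0.0)

        tl.store(x_ptr + base + hs[None, :], x1 * cos[None, :] - x2 * sin[None, :], mask=mask)
        tl.store(x_ptr + base + half + hs[None, :], x2 * cos[None, :] + x1 * sin[None, :], mask=mask)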

HunyuanVideo

Scope        Baseline        Optimized       Speedup
Denoise      43321.437 ms    42857.272 ms    1.0108x
End-to-end   60402.167 ms    59642.583 ms    1.0127x

End-to-end: ~1.3% speedup (60402.167 ms / 59642.583 ms ≈ 1.0127x).

sglang generate --model-path hunyuanvideo-community/HunyuanVideo --text-encoder-cpu-offload --pin-cpu-memory --prompt "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." --save-output --num-frames 65 --width 848 --height 480 --num-inference-steps 30 --seed=42 --save-output --warmup --enable-torch-compile true --dit-layerwise-offload true --dit-cpu-offload false --vae-cpu-offload false --text-encoder-cpu-offload true --warmup --enable-torch-compile
  • main:
hunyuan_rotary_baseline_20260325.mp4
  • pr:
hunyuan_rotary_tuned_20260325.mp4

Microbenchmark

Shape                   Baseline       Optimized      Speedup
(1, 115200, 24, 128)    1.667968 ms    0.536736 ms    3.1076x
(1, 32760, 12, 128)     0.242592 ms    0.082656 ms    2.9350x
(2, 4096, 24, 128)      0.124384 ms    0.042016 ms    2.9604x
(1, 4096, 32, 64)       0.084768 ms    0.026144 ms    3.2424x
(4096, 24, 128)         0.065504 ms    0.024800 ms    2.6413x
aggregate               1.348186 ms    0.434541 ms    3.1026x
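
For reference, per-shape timings like these can be collected with Triton's built-in benchmarking helper. A minimal sketch, assuming fp16 CUDA inputs and that apply_rotary_embedding(x, cos, sin) is the kernel wrapper under test (the PR does not show its harness, and the wrapper's actual signature may differ):

    import torch
    import triton.testing

    shapes = [(1, 115200, 24, 128), (1, 32760, 12, 128),
              (2, 4096, 24, 128), (1, 4096, 32, 64), (4096, 24, 128)]

    for shape in shapes:
        x = torch.randn(*shape, device="cuda", dtype=torch.float16)
        num_heads, head_size = shape[-2], shape[-1]
        num_tokens = x.numel() // (num_heads * head_size)
        cos = torch.randn(num_tokens, head_size // 2, device="cuda", dtype=torch.float16)
        sin = torch.randn(num_tokens, head_size // 2, device="cuda", dtype=torch.float16)
        # do_bench reports a runtime estimate in ms over many warmed-up runs.
        ms = triton.testing.do_bench(lambda: apply_rotary_embedding(x, cos, sin))
        print(f"{shape}: {ms:.6f} ms")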

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization to the Triton apply_rotary_embedding kernel, a critical component in diffusion models. By refactoring the kernel to process multiple attention heads concurrently for each token, the changes aim to enhance GPU resource utilization, minimize redundant memory loads, and reduce kernel launch overhead. The result is a more efficient rotary embedding computation, demonstrated by notable performance gains in targeted microbenchmarks and a measurable speedup in a full diffusion model workload.

Highlights

  • Kernel Optimization Strategy: The apply_rotary_embedding Triton kernel was optimized to process multiple attention heads for a single token within one kernel launch, moving away from a per-(token, head) launch.
  • Performance Improvements: This change led to significant speedups in microbenchmarks (up to 3.24x) and a modest 1.3% end-to-end speedup for HunyuanVideo.
  • Triton Autotuning Enhancements: The kernel now utilizes a 2D launch layout and includes BLOCK_HEADS in its autotuning configurations, alongside BLOCK_HS_HALF, to improve cos/sin row reuse and reduce overhead (a host-side launch sketch follows below).
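
A hypothetical host-side wrapper makes the 2D launch layout concrete. Names and signature are illustrative, not the PR's API; it pairs with the kernel sketch under "Motivation" above and assumes a contiguous input so the kernel can update it in place:

    import torch
    import triton

    def apply_rotary_multihead(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                               BLOCK_HEADS: int = 4, BLOCK_HS_HALF: int = 64) -> torch.Tensor:
        # Fold any leading batch/sequence dims into a single token axis.
        num_heads, head_size = x.shape[-2], x.shape[-1]
        x_flat = x.reshape(-1, num_heads, head_size)
        num_tokens = x_flat.shape[0]
        # 2D grid: axis 0 walks tokens, axis 1 walks ceil(num_heads / BLOCK_HEADS)
        # head blocks, instead of one program per (token, head).
        grid = (num_tokens, triton.cdiv(num_heads, BLOCK_HEADS))
        rotary_multihead_kernel[grid](
            x_flat, cos, sin, num_heads, head_size,
            BLOCK_HEADS=BLOCK_HEADS, BLOCK_HS_HALF=BLOCK_HS_HALF,
        )
        return x

With @triton.autotune in play, the grid would instead be a callable over the selected meta-parameters, e.g. lambda meta: (num_tokens, triton.cdiv(num_heads, meta["BLOCK_HEADS"])).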


@BBuf (Collaborator, Author) commented Mar 25, 2026

/tag-and-rerun-ci

@gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the Triton rotary embedding kernel to improve performance by parallelizing across heads. This includes updating the kernel's program ID structure, modifying tensor views, and adjusting the kernel launch grid and stride arguments. The review suggests re-introducing larger BLOCK_HS_HALF values in the autotune configurations to prevent potential suboptimal performance for models with large head_size.

Comment on lines 9 to 15
    configs=[
        triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
        triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
        triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HEADS": 2, "BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HEADS": 4, "BLOCK_HS_HALF": 32}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 4, "BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 8, "BLOCK_HS_HALF": 64}, num_warps=8),
    ],
Severity: medium

The new autotune configurations have a maximum BLOCK_HS_HALF of 64, while the previous version included values up to 256. For models with a large head_size, this could result in suboptimal performance because the kernel would need more loop iterations. It would be beneficial to re-introduce configurations with larger BLOCK_HS_HALF values (e.g., 128, 256) to give the autotuner more choices, especially for combinations with smaller BLOCK_HEADS values. This will help ensure optimal performance across a wider range of model architectures.

    configs=[
        triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HEADS": 2, "BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HEADS": 4, "BLOCK_HS_HALF": 32}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 2, "BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 4, "BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 8, "BLOCK_HS_HALF": 64}, num_warps=8),
        triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 128}, num_warps=4),
        triton.Config({"BLOCK_HEADS": 2, "BLOCK_HS_HALF": 128}, num_warps=8),
        triton.Config({"BLOCK_HEADS": 4, "BLOCK_HS_HALF": 128}, num_warps=8),
        triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 256}, num_warps=8),
    ],
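
For context on how such a config list takes effect, here is a hedged sketch of the decorator wiring; the actual key and parameter names in the PR may differ:

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 32}, num_warps=2),
            triton.Config({"BLOCK_HEADS": 4, "BLOCK_HS_HALF": 64}, num_warps=4),
            triton.Config({"BLOCK_HEADS": 1, "BLOCK_HS_HALF": 256}, num_warps=8),
        ],
        # `key` caches the winning config per (num_heads, head_size), so a
        # large-head model can still settle on the big BLOCK_HS_HALF entries.
        key=["num_heads", "head_size"],
    )
    @triton.jit
    def rotary_multihead_kernel(x_ptr, cos_ptr, sin_ptr, num_heads, head_size,
                                BLOCK_HEADS: tl.constexpr, BLOCK_HS_HALF: tl.constexpr):
        ...  # body as in the sketch above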

@BBuf (Collaborator, Author) commented Mar 25, 2026

/rerun-failed-ci

BBuf merged commit 6f2b51a into main on Mar 26, 2026
87 of 128 checks passed
BBuf deleted the optimize-diffusion-triton-rotary branch on March 26, 2026 at 00:59
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
