
[NPU][diffusion] npu support enable_torch_compile for torchair backend on diffusion models #20687

Merged
sglang-npu-bot merged 3 commits into sgl-project:main from Alisehen:codex/denoising-torch-compile-fix on Mar 18, 2026

Conversation

@Alisehen Alisehen (Contributor) commented Mar 16, 2026

Motivation

Support torch.compile for NPU.

Modifications

Explicitly use the torchair backend for torch.compile on the NPU path in the denoising stage.
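
In condensed form: build the torch.compile kwargs once, branch on the platform, and make a single compile call. A minimal sketch of the logic under review (the platform helpers are passed in here because their real import paths are internal to sglang; see the actual diff in the review below):

        import os
        from typing import Any, Callable

        def maybe_enable_torch_compile(
            module, is_npu: bool, get_compiler_backend: Callable[[], Any]
        ) -> None:
            # Sketch only: mirrors the PR's branch, not the exact sglang code.
            compile_kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}
            if is_npu:
                # NPU path: compile explicitly with the torchair backend;
                # the PR pins dynamic=False on this path.
                compile_kwargs["backend"] = get_compiler_backend()
                compile_kwargs["dynamic"] = False
            else:
                # Other platforms keep the existing inductor mode selection.
                compile_kwargs["mode"] = os.environ.get(
                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                )
            # nn.Module.compile wraps torch.compile in-place (PyTorch >= 2.x).
            module.compile(**compile_kwargs)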

Benchmarking and Profiling

910B, Qwen-image

# base: 23.38s
sglang generate \
  --model-path=/model/qwen-image/ \
  --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' \
  --width=1024 \
  --height=1024 \
  --vae-precision bf16 \
  --num-inference-steps=20 \
  --warmup
# torch.compile 21.01s
sglang generate \
  --model-path=/model/qwen-image/ \
  --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' \
  --width=1024 \
  --height=1024 \
  --vae-precision bf16 \
  --num-inference-steps=20 \
  --enable-torch-compile \
  --warmup

910B, FLUX.1

# base: 16.91s
sglang generate \
  --model-path=/model/FLUX.1/ \
  --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' \
  --width=1024 \
  --height=1024 \
  --vae-precision bf16 \
  --num-inference-steps=28 \
  --warmup
# torch.compile 15.58s
sglang generate \
  --model-path=/model/FLUX.1/ \
  --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' \
  --width=1024 \
  --height=1024 \
  --vae-precision bf16 \
  --num-inference-steps=28 \
  --enable-torch-compile \
  --warmup

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the diffusion (SGLang Diffusion) label on Mar 16, 2026
@gemini-code-assist gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for torch.compile on NPU devices within the diffusion model denoising stage. By leveraging the torchair backend specifically for NPU, it aims to optimize the performance of diffusion model generation, as evidenced by benchmark results showing a reduction in inference time.

Highlights

  • NPU torch.compile support: Enabled torch.compile functionality for NPU (Neural Processing Unit) devices within the diffusion model denoising stage.
  • Torchair backend integration: Configured torch.compile to explicitly use the torchair backend when running on NPU, ensuring optimal compilation for this hardware.
  • Performance improvement: Demonstrated a performance improvement of approximately 10% for diffusion model generation on NPU when torch.compile is enabled, reducing inference time from 23.38s to 21.01s.


Changelog
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
    • Imported get_compiler_backend for dynamic backend selection.
    • Modified the _maybe_enable_torch_compile method to conditionally apply torch.compile settings. For NPU platforms, it now explicitly uses the torchair backend and sets dynamic=False. For other platforms, it retains the existing torch._inductor configuration and mode selection.

@Alisehen Alisehen (Contributor, Author) commented

@ping1jing2

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for torch.compile on NPU devices by using the torchair backend. The implementation correctly isolates the NPU-specific compilation logic. However, the current structure could be improved to be more robust and extensible for other hardware backends like HPU, which are also handled by get_compiler_backend. I've suggested a refactoring to make the backend selection more generic.

Comment on lines +141 to +162
        compile_kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}

        if current_platform.is_npu():
            backend = get_compiler_backend()
            compile_kwargs["backend"] = backend
            compile_kwargs["dynamic"] = False
            logger.info("Compiling transformer with torchair backend on NPU")
        else:
            try:
                import torch._inductor.config as _inductor_cfg

                _inductor_cfg.reorder_for_compute_comm_overlap = True
            except ImportError:
                pass
            mode = os.environ.get(
                "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
            )
            compile_kwargs["mode"] = mode
            logger.info(f"Compiling transformer with mode: {mode}")

        # TODO(triple-mu): support customized fullgraph and dynamic in the future
        module.compile(**compile_kwargs)
@gemini-code-assist gemini-code-assist bot (Contributor) commented (severity: medium)

The current implementation correctly adds support for NPU, but the if/else structure is a bit rigid. It specifically checks for NPU and assumes the else block is for inductor-based backends (like CUDA). This doesn't account for other potential backends that get_compiler_backend might return, such as for HPU.

A more robust approach would be to fetch the backend name first and then use a dispatcher-like pattern to set the appropriate compilation arguments. This would make the code more extensible for future hardware support.

Here is a suggested refactoring:

        compile_kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}
        backend = get_compiler_backend()

        if current_platform.is_npu():
            compile_kwargs["backend"] = backend
            compile_kwargs["dynamic"] = False
            logger.info("Compiling transformer with torchair backend on NPU")
        elif current_platform.is_hpu():
            compile_kwargs["backend"] = backend
            logger.info("Compiling transformer with hpu_backend on HPU")
        else:  # Default to inductor for CUDA, etc.
            try:
                import torch._inductor.config as _inductor_cfg

                _inductor_cfg.reorder_for_compute_comm_overlap = True
            except ImportError:
                pass
            mode = os.environ.get(
                "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
            )
            compile_kwargs["mode"] = mode
            logger.info(f"Compiling transformer with mode: {mode}")

        # TODO(triple-mu): support customized fullgraph and dynamic in the future
        module.compile(**compile_kwargs)

@ping1jing2 ping1jing2 self-assigned this Mar 17, 2026
@ping1jing2 ping1jing2 (Collaborator) commented

/tag-and-rerun-ci

            return
        compile_kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}
@zhuyijie88 zhuyijie88 (Contributor) commented Mar 17, 2026

how about _maybe_enable_torch_compile of MOVADenoisingStage in mova.py?

@Alisehen Alisehen (Contributor, Author) replied Mar 17, 2026

Thanks, I’ve aligned it with the denoising-stage change and now explicitly use the torchair backend on NPU there as well.

@Alisehen Alisehen force-pushed the codex/denoising-torch-compile-fix branch from 177dd2f to cf57f0d on March 17, 2026 at 08:48
        compile_kwargs: dict[str, object] = {"fullgraph": False, "dynamic": None}

        if current_platform.is_npu():
@ping1jing2 ping1jing2 (Collaborator) commented

please avoid this

@Alisehen Alisehen (Contributor, Author) replied

The platform split itself is hard to avoid here because the torch.compile configuration is inherently different on NPU vs non-NPU. We need an explicit backend selection on NPU, while other platforms should continue to use the existing inductor-mode path. I’m happy to move _maybe_enable_torch_compile to a shared helper, but some form of platform-specific branching is still required.
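
For illustration, a hypothetical shared helper along those lines (not part of this PR; build_compile_kwargs is an invented name, and the platform check and backend accessor are passed in rather than imported, since their real paths are internal to sglang):

        import os
        from typing import Any, Callable

        def build_compile_kwargs(
            is_npu: bool, get_compiler_backend: Callable[[], Any]
        ) -> dict[str, Any]:
            # Hypothetical helper: both DenoisingStage and MOVADenoisingStage
            # could call this instead of duplicating the platform branch.
            kwargs: dict[str, Any] = {"fullgraph": False, "dynamic": None}
            if is_npu:
                # NPU: explicit torchair backend with static shapes.
                kwargs["backend"] = get_compiler_backend()
                kwargs["dynamic"] = False
            else:
                # Default inductor path (CUDA, etc.): keep the existing
                # comm-overlap tweak and mode selection.
                try:
                    import torch._inductor.config as _inductor_cfg

                    _inductor_cfg.reorder_for_compute_comm_overlap = True
                except ImportError:
                    pass
                kwargs["mode"] = os.environ.get(
                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                )
            return kwargs

Each stage would then reduce to module.compile(**build_compile_kwargs(...)), keeping the unavoidable platform branching in one place.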

@sglang-npu-bot sglang-npu-bot merged commit c7a7174 into sgl-project:main Mar 18, 2026
68 of 71 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026

Labels

diffusion (SGLang Diffusion), run-ci
