
[AMD] [diffusion] feat: enable AITer GroupNorm for VAE decode on ROCm#20170

Merged
HaiShaw merged 2 commits into sgl-project:main from yctseng0211:enable_aiter_gn
Mar 12, 2026

Conversation

@yctseng0211 (Collaborator) commented Mar 9, 2026

Summary

Replace nn.GroupNorm with AITer's CK GroupNorm in the VAE decode stage on ROCm, gated by SGLANG_USE_ROCM_VAE=1 (off by default).

Motivation

Profiling Z-Image-Turbo VAE decode on MI300X/MI35X shows that PyTorch's native RowwiseMomentsCUDAKernel is poorly optimized for AMD GPUs: 1.57 ms per call vs. AITer's 0.13 ms (~12x faster).

Results (Z-Image-Turbo, 1024×1024, MI300X)

max_concurrency=8, num_prompts=64

| Metric | Baseline | AITer GroupNorm | Change |
| --- | --- | --- | --- |
| GroupNorm kernel time | 51.4 ms | 6.0 ms | -88% |
| VAE decode total | ~205 ms | ~162 ms | -21% |
| End-to-end throughput | 0.745 | 0.770 | +3–5% |

Modifications

  • decoding.py: recursively swap nn.GroupNorm → aiter.GroupNorm after VAE load
  • Zero-copy weight sharing; safe fallback on error

Design note

GroupNorm lives in diffusers' Decoder / ResnetBlock2D / UNetMidBlock2D, not in SGLang code.
We use runtime module replacement (named_children + setattr) rather than patching diffusers directly so that:

  • The change is opt-in (SGLANG_USE_ROCM_VAE=1) and has no effect on non-ROCm platforms
  • Diffusers upgrades won't break or overwrite the optimization
  • Only the VAE decode stage is affected; other pipelines remain untouched
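
As a rough sketch (not the PR's exact code), the runtime replacement described above can be done with named_children + setattr. AiterGroupNorm below is a stand-in for aiter.ops.groupnorm.GroupNorm, which is only available on ROCm:

```python
# Hypothetical sketch of the runtime module-replacement approach described
# above; `AiterGroupNorm` is a stand-in, not the real aiter class.
import torch
import torch.nn as nn


class AiterGroupNorm(nn.GroupNorm):
    """Stand-in for aiter's CK GroupNorm; same math, placeholder kernel."""


def replace_groupnorm(module: nn.Module, gn_cls=AiterGroupNorm) -> int:
    """Recursively swap nn.GroupNorm children for gn_cls, sharing weights."""
    count = 0
    for name, child in module.named_children():
        # `type(...) is` (not isinstance) avoids re-wrapping subclasses.
        if type(child) is nn.GroupNorm:
            replacement = gn_cls(
                child.num_groups, child.num_channels,
                eps=child.eps, affine=child.affine,
            )
            if child.affine:
                # Zero-copy: reuse the original Parameter objects.
                replacement.weight = child.weight
                replacement.bias = child.bias
            setattr(module, name, replacement)
            count += 1
        else:
            count += replace_groupnorm(child, gn_cls)
    return count


# Tiny VAE-like module to demonstrate the swap.
block = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.GroupNorm(4, 8))
n = replace_groupnorm(block)
print(n)  # 1
```

Because the replacement shares the original weight and bias Parameters, no extra memory is allocated and checkpoint loading is unaffected.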

Accuracy Tests

```shell
export SGLANG_USE_ROCM_VAE=1
sglang generate --model-path=Tongyi-MAI/Z-Image-Turbo --log-level=info --prompt='A fantasy landscape with mountains and a river, detailed, vibrant colors' --width=1024 --height=1024 --num-inference-steps=9 --guidance-scale=0.0 --seed=42 --save-output --enable-torch-compile --warmup --dit-cpu-offload false --text-encoder-cpu-offload false
```

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions Bot added the diffusion SGLang Diffusion label Mar 9, 2026
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the VAE decoding stage by enabling an optimized AITer GroupNorm implementation specifically for ROCm platforms. The change aims to improve performance by dynamically swapping out standard PyTorch GroupNorm layers with AITer-optimized versions when a specific environment variable is enabled. This provides a configurable performance boost without altering the core model architecture by integrating the replacement logic into the model loading process.

Highlights

  • AITer GroupNorm Integration: Introduced functionality to conditionally replace torch.nn.GroupNorm with aiter.ops.groupnorm.GroupNorm within the VAE (Variational Autoencoder) stage.
  • Platform and Environment Variable Control: The GroupNorm replacement is activated only when the SGLANG_USE_AITER_VAE environment variable is set to True and the current platform is ROCm (HIP).
  • Recursive Module Replacement: Added a new utility function, _replace_groupnorm_with_aiter, which recursively traverses a PyTorch module to identify and replace GroupNorm layers with their AITer counterparts.
  • Controlled Application: A flag _aiter_gn_applied was added to the VAEDecodingStage to ensure that the GroupNorm replacement process is executed only once during the model's lifecycle.
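
A minimal sketch of the gating these highlights describe (names follow the summary; the platform/ROCm check is omitted and the replacement is passed in as a callable here, so this is illustrative, not the real stage class):

```python
# Sketch of the opt-in, run-once gating described above. Names follow the
# PR summary; all structure here is an illustrative assumption.
import os


def get_bool_env_var(name: str) -> bool:
    # Simplified stand-in for SGLang's get_bool_env_var helper.
    return os.getenv(name, "false").lower() in ("1", "true", "yes")


class VAEDecodingStage:
    def __init__(self, vae, replace_fn):
        self.vae = vae
        self._replace_fn = replace_fn   # e.g. _replace_groupnorm_with_aiter
        self._aiter_gn_applied = False  # ensure we swap at most once

    def _maybe_apply_aiter_groupnorm(self):
        if self._aiter_gn_applied or not get_bool_env_var("SGLANG_USE_ROCM_VAE"):
            return
        self._aiter_gn_applied = True
        try:
            self._replace_fn(self.vae)
        except Exception:
            pass  # safe fallback: keep the stock nn.GroupNorm modules


os.environ["SGLANG_USE_ROCM_VAE"] = "1"
calls = []
stage = VAEDecodingStage(vae=object(), replace_fn=calls.append)
stage._maybe_apply_aiter_groupnorm()
stage._maybe_apply_aiter_groupnorm()  # second call is a no-op
print(len(calls))  # 1
```

Setting the applied flag before attempting the swap means a failing replacement is not retried on every load, and the try/except keeps the model usable even if aiter is missing.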


Changelog
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py
    • Added _is_hip and _use_aiter_vae global variables to determine if AITer VAE optimization should be applied.
    • Implemented _replace_groupnorm_with_aiter function to recursively replace torch.nn.GroupNorm with aiter.ops.groupnorm.GroupNorm.
    • Introduced _aiter_gn_applied attribute in VAEDecodingStage to track if AITer GroupNorm has been applied.
    • Created _maybe_apply_aiter_groupnorm method to conditionally apply the AITer GroupNorm replacement to the VAE.
    • Called _maybe_apply_aiter_groupnorm within the load_model method of VAEDecodingStage.
    • Imported get_bool_env_var for environment variable parsing.
Activity
  • The pull request description provides a standard template, indicating no specific human activity or comments have been recorded yet.
@gemini-code-assist Bot left a comment
Code Review

This pull request optimizes VAE decoding on ROCm platforms by enabling AITer's GroupNorm implementation. It achieves this by recursively replacing torch.nn.GroupNorm modules with their AITer counterparts when the SGLANG_USE_AITER_VAE environment variable is set, and includes a fallback mechanism. No vulnerabilities or significant security issues were found in the changes. A suggestion has been made to improve the robustness of the error handling during the module replacement process to prevent potential model inconsistencies.

Comment thread python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py Outdated
@yctseng0211 changed the title from "[AMD] diffusion - enable aiter groupnorm in VAE stage on rocm" to "[AMD] [diffusion] feat: enable AITer GroupNorm for VAE decode on ROCm" Mar 9, 2026
@yctseng0211 yctseng0211 marked this pull request as ready for review March 12, 2026 08:22
@yctseng0211 yctseng0211 requested review from HaiShaw and yichiche March 12, 2026 08:23

HaiShaw commented Mar 12, 2026

/tag-and-rerun-ci

@HaiShaw left a review comment:
Suggest changing SGLANG_USE_AITER_VAE to SGLANG_USE_ROCM_VAE, so the env var can be reused for future ROCm-based VAE optimizations.


HaiShaw commented Mar 12, 2026

@yctseng0211 tested on both gfx942 and gfx950?

@yctseng0211 (Author) replied:

> @yctseng0211 tested on both gfx942 and gfx950?

Yes, tested on MI325, MI300, and MI350.

@HaiShaw left a review comment:
LGTM

@HaiShaw HaiShaw added the amd label Mar 12, 2026
@yhyang201 (Collaborator) commented:

@mickqian All CI (Nvidia + AMD) passed and PR is approved, ready for merge

— SGLDHelper bot

@yichiche left a review comment:
LGTM.

@HaiShaw HaiShaw merged commit 78a467c into sgl-project:main Mar 12, 2026
69 of 74 checks passed
```python
logger = init_logger(__name__)

_is_hip = current_platform.is_hip()
_use_aiter_vae = get_bool_env_var("SGLANG_USE_ROCM_VAE") and _is_hip
```
A collaborator commented:
Please register these env vars via envs.py.
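
For illustration only (SGLang's actual envs.py API may differ; the typed-accessor class below is an assumption), centralizing the env var could look like:

```python
# Illustrative sketch of registering SGLANG_USE_ROCM_VAE in an
# envs.py-style module. `EnvBool` is a hypothetical helper, not
# SGLang's real API.
import os


class EnvBool:
    def __init__(self, name: str, default: bool = False):
        self.name, self.default = name, default

    def get(self) -> bool:
        raw = os.getenv(self.name)
        if raw is None:
            return self.default
        return raw.lower() in ("1", "true", "yes")


# Registered once here; consumers import this instead of reading os.environ.
SGLANG_USE_ROCM_VAE = EnvBool("SGLANG_USE_ROCM_VAE", default=False)

os.environ["SGLANG_USE_ROCM_VAE"] = "1"
print(SGLANG_USE_ROCM_VAE.get())  # True
```

Registering the variable in one place keeps the default, the name, and the parsing rule consistent across every call site.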

```python
_use_aiter_vae = get_bool_env_var("SGLANG_USE_ROCM_VAE") and _is_hip


def _replace_groupnorm_with_aiter(module: torch.nn.Module) -> int:
```
@mickqian commented Mar 13, 2026:
These patterns (including _is_hip) should be avoided.
Decoding is a critical core stage of the system. If we allow platform-specific conditionals to accumulate here, the complexity will grow rapidly and become increasingly hard to manage.

Over time, this will significantly hurt maintainability and extensibility.

Suggestions:

  • Introduce a Platform abstraction
  • Consider a plugin-based architecture

Discussion is welcome.

A collaborator replied:

The codebase already has a well-established pattern for isolating platform-specific logic via Platform polymorphic methods:

  • Platform.get_attn_backend_cls_str() — platform decides which attention backend to use
  • Platform.enable_dit_layerwise_offload_for_wan_by_default() — platform decides default behavior
  • Platform.get_device_communicator_cls() — platform decides communicator

We can directly reuse this Strategy pattern:

1. Add optimize_vae() hook to Platform base class

```python
# interface.py
class Platform:
    ...
    @classmethod
    def optimize_vae(cls, vae: torch.nn.Module) -> torch.nn.Module:
        """Apply platform-specific optimizations to VAE after loading."""
        return vae
```

2. RocmPlatform overrides it with AITer GroupNorm replacement

```python
# rocm.py
class RocmPlatform(Platform):
    ...
    @classmethod
    def optimize_vae(cls, vae: torch.nn.Module) -> torch.nn.Module:
        if not envs.SGLANG_USE_ROCM_VAE:
            return vae
        try:
            from aiter.ops.groupnorm import GroupNorm as AITerGroupNorm
            count = cls._replace_groupnorm(vae, AITerGroupNorm)
            logger.info("replaced %d GroupNorm with AITer GroupNorm", count)
        except Exception as e:
            logger.warning("failed to apply AITer GroupNorm: %s", e)
        return vae

    @staticmethod
    def _replace_groupnorm(module, aiter_gn_cls):
        count = 0
        for name, child in module.named_children():
            if isinstance(child, torch.nn.GroupNorm):
                replacement = aiter_gn_cls(child.num_groups, child.num_channels, ...)
                replacement.weight = child.weight  # zero-copy weight sharing
                replacement.bias = child.bias
                setattr(module, name, replacement)
                count += 1
            else:
                count += RocmPlatform._replace_groupnorm(child, aiter_gn_cls)
        return count
```

3. Call it from VAELoader — one line

VAELoader already does similar model-level optimizations (e.g. _convert_conv3d_weights_to_channels_last_3d). Just add one line before returning:

```python
# vae_loader.py
vae = current_platform.optimize_vae(vae)
return vae
```

4. decoding.py — zero changes needed

The core decoding stage stays completely platform-agnostic. No _is_hip, no _use_aiter_vae, no _replace_groupnorm_with_aiter.

Why this is the right approach:

  • Follows the existing architecture exactly (same pattern as get_attn_backend_cls_str)
  • Separation of concerns: ROCm code lives only in rocm.py; core stages don’t know about platforms
  • Extensible: future NPU/MUSA VAE optimizations just override optimize_vae() in their own platform file — no core code changes needed
  • Env var SGLANG_USE_ROCM_VAE is registered in envs.py and consumed internally by RocmPlatform
  • Minimal diff: ~3 files, a few lines each
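
To make the dispatch concrete, here is a torch-free miniature of the pattern above; every class here is a stand-in, not the real Platform API:

```python
# Torch-free miniature of the Platform strategy sketched above.
# All names are illustrative stand-ins.
class Platform:
    @classmethod
    def optimize_vae(cls, vae):
        return vae  # default: no platform-specific optimization


class RocmPlatform(Platform):
    @classmethod
    def optimize_vae(cls, vae):
        vae.groupnorm_backend = "aiter"  # stands in for the module swap
        return vae


class FakeVAE:
    groupnorm_backend = "torch"


# The loader only ever sees `current_platform`; ROCm specifics
# stay inside RocmPlatform.
current_platform = RocmPlatform
vae = current_platform.optimize_vae(FakeVAE())
print(vae.groupnorm_backend)  # aiter
```

Because the loader depends only on the base-class hook, adding a new accelerator means overriding optimize_vae in that platform's file with no change to core code.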

@yctseng0211 (Author) replied:
@mickqian Thanks for the detailed review and suggestions, I'll follow the Platform abstraction pattern and address this in a follow-up PR.

@yctseng0211 (Author) commented:
Hi @mickqian, could you help review #20496?
This PR moves platform-specific AITer GroupNorm replacement out of decoding.py into the Platform polymorphic pattern:

  • Register SGLANG_USE_ROCM_VAE in envs.py
  • Add optimize_vae() hook to Platform base class
  • Override in RocmPlatform with AITer GroupNorm replacement
  • Call from VAELoader at load time
  • Remove all platform-specific logic from decoding.py

Thanks.

liubiyongge pushed a commit to liubiyongge/sglang that referenced this pull request Mar 13, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Mar 15, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026