[AMD] [diffusion] feat: enable AITer GroupNorm for VAE decode on ROCm #20170
HaiShaw merged 2 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the VAE decoding stage by enabling an optimized AITer GroupNorm implementation specifically for ROCm platforms. It aims to improve performance by dynamically swapping standard PyTorch GroupNorm layers for AITer-optimized versions when a specific environment variable is enabled. This provides a configurable performance boost without altering the core model architecture, by integrating the replacement logic into the model loading process.
Activity
Code Review
This pull request optimizes VAE decoding on ROCm platforms by enabling AITer's GroupNorm implementation. It achieves this by recursively replacing torch.nn.GroupNorm modules with their AITer counterparts when the SGLANG_USE_AITER_VAE environment variable is set, and includes a fallback mechanism. No vulnerabilities or significant security issues were found in the changes. A suggestion has been made to improve the robustness of error handling during module replacement, to prevent potential model inconsistencies.
53c23f2 to 5631434
/tag-and-rerun-ci
HaiShaw left a comment:

I suggest changing SGLANG_USE_AITER_VAE to SGLANG_USE_ROCM_VAE; that env var can then be reused for future ROCm-based VAE optimizations.
@yctseng0211 tested on both gfx942 and gfx950?

@mickqian All CI (Nvidia + AMD) passed and the PR is approved, ready for merge. (SGLDHelper bot)
    logger = init_logger(__name__)

    _is_hip = current_platform.is_hip()
    _use_aiter_vae = get_bool_env_var("SGLANG_USE_ROCM_VAE") and _is_hip
please register them via envs.py
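sglang centralizes environment flags in an `envs.py` module, and the diff already reads the flag through `get_bool_env_var`. The exact registration API in sglang may differ; a minimal sketch of the pattern, with an illustrative `get_bool_env_var` implementation, looks like this:

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Parse common truthy spellings ("1", "true", "yes") of an env var."""
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


# Registered once in envs.py, consumed everywhere else as envs.SGLANG_USE_ROCM_VAE
# instead of re-reading os.environ at each call site.
SGLANG_USE_ROCM_VAE = get_bool_env_var("SGLANG_USE_ROCM_VAE")
```

Centralizing the flag keeps the spelling and default in one place and makes it discoverable alongside the other SGLANG_* variables.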
    _use_aiter_vae = get_bool_env_var("SGLANG_USE_ROCM_VAE") and _is_hip

    def _replace_groupnorm_with_aiter(module: torch.nn.Module) -> int:
These patterns (including _is_hip) should be avoided.
Decoding is a critical core stage of the system. If we allow platform-specific conditionals to accumulate here, the complexity will grow rapidly and become increasingly hard to manage.
Over time, this will significantly hurt maintainability and extensibility.
Suggestions:
- Introduce a Platform abstraction
- Consider a plugin-based architecture

Discussion is welcome.
The codebase already has a well-established pattern for isolating platform-specific logic via Platform polymorphic methods:
- Platform.get_attn_backend_cls_str(): the platform decides which attention backend to use
- Platform.enable_dit_layerwise_offload_for_wan_by_default(): the platform decides default behavior
- Platform.get_device_communicator_cls(): the platform decides the communicator
We can directly reuse this Strategy pattern:

1. Add an optimize_vae() hook to the Platform base class:

    # interface.py
    class Platform:
        ...
        @classmethod
        def optimize_vae(cls, vae: torch.nn.Module) -> torch.nn.Module:
            """Apply platform-specific optimizations to VAE after loading."""
            return vae

2. RocmPlatform overrides it with the AITer GroupNorm replacement:
    # rocm.py
    class RocmPlatform(Platform):
        ...
        @classmethod
        def optimize_vae(cls, vae: torch.nn.Module) -> torch.nn.Module:
            if not envs.SGLANG_USE_ROCM_VAE:
                return vae
            try:
                from aiter.ops.groupnorm import GroupNorm as AITerGroupNorm
                count = cls._replace_groupnorm(vae, AITerGroupNorm)
                logger.info("replaced %d GroupNorm with AITer GroupNorm", count)
            except Exception as e:
                logger.warning("failed to apply AITer GroupNorm: %s", e)
            return vae

        @staticmethod
        def _replace_groupnorm(module, aiter_gn_cls):
            count = 0
            for name, child in module.named_children():
                if isinstance(child, torch.nn.GroupNorm):
                    replacement = aiter_gn_cls(child.num_groups, child.num_channels, ...)
                    replacement.weight = child.weight  # zero-copy weight sharing
                    replacement.bias = child.bias
                    setattr(module, name, replacement)
                    count += 1
                else:
                    count += RocmPlatform._replace_groupnorm(child, aiter_gn_cls)
            return count

3. Call it from VAELoader (one line):
VAELoader already does similar model-level optimizations (e.g. _convert_conv3d_weights_to_channels_last_3d). Just add one line before returning:

    # vae_loader.py
    vae = current_platform.optimize_vae(vae)
    return vae

4. decoding.py: zero changes needed.
The core decoding stage stays completely platform-agnostic. No _is_hip, no _use_aiter_vae, no _replace_groupnorm_with_aiter.
Why this is the right approach:
- Follows the existing architecture exactly (same pattern as get_attn_backend_cls_str)
- Separation of concerns: ROCm code lives only in rocm.py; core stages don't know about platforms
- Extensible: future NPU/MUSA VAE optimizations just override optimize_vae() in their own platform file, with no core code changes needed
- The env var SGLANG_USE_ROCM_VAE is registered in envs.py and consumed internally by RocmPlatform
- Minimal diff: ~3 files, a few lines each
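One nice side effect of factoring the traversal out as `_replace_groupnorm(module, aiter_gn_cls)` is that the recursion can be unit-tested without a GPU or AITer installed, by passing a stand-in replacement class. The sketch below is illustrative, not sglang code: `FakeModule`, `FakeGroupNorm`, and `FakeAiterGroupNorm` are hypothetical stand-ins that mimic just enough of `torch.nn.Module` (`named_children` plus attribute assignment) to exercise the same logic:

```python
class FakeModule:
    """Minimal stand-in mimicking torch.nn.Module's named_children()/setattr."""
    def __init__(self, **children):
        self._names = list(children)
        for name, child in children.items():
            setattr(self, name, child)

    def named_children(self):
        return [(n, getattr(self, n)) for n in self._names]


class FakeGroupNorm(FakeModule):
    pass


class FakeAiterGroupNorm(FakeModule):
    pass


def replace_groupnorm(module, target_cls, replacement_cls):
    """Recursively swap every target_cls child for a replacement_cls instance."""
    count = 0
    for name, child in module.named_children():
        if isinstance(child, target_cls):
            setattr(module, name, replacement_cls())
            count += 1
        else:
            count += replace_groupnorm(child, target_cls, replacement_cls)
    return count


# A nested model: one GroupNorm at the top level, one inside a sub-block.
model = FakeModule(
    norm1=FakeGroupNorm(),
    block=FakeModule(norm2=FakeGroupNorm(), other=FakeModule()),
)
replaced = replace_groupnorm(model, FakeGroupNorm, FakeAiterGroupNorm)
print(replaced)  # 2
```

The same test shape works against the real `RocmPlatform._replace_groupnorm` by passing any class in place of the AITer GroupNorm.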
@mickqian Thanks for the detailed review and suggestions, I'll follow the Platform abstraction pattern and address this in a follow-up PR.
Hi @mickqian, could you help review PR #20496?
This PR moves platform-specific AITer GroupNorm replacement out of decoding.py into the Platform polymorphic pattern:
- Register SGLANG_USE_ROCM_VAE in envs.py
- Add optimize_vae() hook to Platform base class
- Override in RocmPlatform with AITer GroupNorm replacement
- Call from VAELoader at load time
- Remove all platform-specific logic from decoding.py
Thanks.
…sgl-project#20170) Co-authored-by: HaiShaw <hixiao@gmail.com>
Summary

Replace nn.GroupNorm with aiter's CK GroupNorm in the VAE decode stage on ROCm, controlled by SGLANG_USE_ROCM_VAE=1 (off by default).

Motivation

Profiling Z-Image-Turbo VAE decode on MI300X/MI35X shows that PyTorch's native RowwiseMomentsCUDAKernel is poorly optimized for AMD GPUs: 1.57 ms per call vs AITer's 0.13 ms (~12x faster).

Results (Z-Image-Turbo, 1024×1024, MI300X)

max_concurrency=8, num_prompts=64
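As a quick sanity check on the reported kernel numbers, the per-call speedup works out to roughly 12x:

```python
# Per-call GroupNorm latencies reported in the PR description (MI300X profiling).
pytorch_ms = 1.57  # PyTorch native RowwiseMomentsCUDAKernel
aiter_ms = 0.13    # AITer CK GroupNorm
speedup = pytorch_ms / aiter_ms
print(f"{speedup:.1f}x")  # 12.1x, consistent with the ~12x claim
```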
Modifications

- decoding.py: recursively swap nn.GroupNorm → aiter.GroupNorm after VAE load

Design note

GroupNorm lives in diffusers' Decoder/ResnetBlock2D/UNetMidBlock2D, not in SGLang code. We use runtime module replacement (named_children + setattr) rather than patching diffusers directly, so that the optimization stays opt-in (SGLANG_USE_ROCM_VAE=1) and has no effect on non-ROCm platforms.

Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci