
[Diffusion] LTX-2 Support PR3 #19151

Merged
mickqian merged 8 commits into sgl-project:main from gmixiaojin:LTX-2
Feb 24, 2026

Conversation

@gmixiaojin
Contributor

@gmixiaojin commented Feb 22, 2026

Motivation

Support LTX-2 Video & Audio Joint model.

This PR mainly fixes bugs in the SGLang LTX-2 pipelines.

How to use

generate

  • T2V
sglang generate \
  --model-path LTX-2_Model \
  --prompt "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot. [Audio Description] Subtle ambient space atmosphere with deep bass hum, gentle cracking and crunching sounds as the eggshell breaks apart, soft whooshing of lunar dust particles floating in low gravity, quiet breathing sounds from inside the spacesuit, a faint heartbeat rhythm building tension, ethereal cinematic synth pads creating a sense of wonder and discovery, all blended with vast empty silence of space." \
  --negative-prompt "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." \
  --width 768 --height 512 --num-frames 121 --fps 24 \
  --num-inference-steps 40 --guidance-scale 4.0 \
  --seed 1024 \
  --output-path ./output/ --output-file-name ltx2_sample.mp4
  • TI2V
sglang generate \
  --model-path LTX-2_Model \
  --prompt "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot. [Audio Description] Subtle ambient space atmosphere with deep bass hum, gentle cracking and crunching sounds as the eggshell breaks apart, soft whooshing of lunar dust particles floating in low gravity, quiet breathing sounds from inside the spacesuit, a faint heartbeat rhythm building tension, ethereal cinematic synth pads creating a sense of wonder and discovery, all blended with vast empty silence of space." \
  --image-path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" \
  --negative-prompt "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." \
  --width 768 --height 512 --num-frames 121 --fps 24 \
  --num-inference-steps 40 --guidance-scale 4.0 \
  --seed 1024 \
  --output-path ./output/ --output-file-name ltx2_sample.mp4
  • SP, TP & Cache DiT (8 GPUs)
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
SGLANG_CACHE_DIT_ENABLED=true \
SGLANG_CACHE_DIT_SCM_PRESET=medium \
sglang generate \
  --model-path LTX-2_Model \
  --num-gpus 8 \
  --tp-size 2 \
  --sp-degree 4 \
  --ulysses-degree 2 \
  --ring-degree 2 \
  --prompt "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot. [Audio Description] Subtle ambient space atmosphere with deep bass hum, gentle cracking and crunching sounds as the eggshell breaks apart, soft whooshing of lunar dust particles floating in low gravity, quiet breathing sounds from inside the spacesuit, a faint heartbeat rhythm building tension, ethereal cinematic synth pads creating a sense of wonder and discovery, all blended with vast empty silence of space." \
  --image-path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" \
  --negative-prompt "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." \
  --width 768 --height 512 --num-frames 121 --fps 24 \
  --num-inference-steps 40 --guidance-scale 4.0 \
  --seed 1024 \
  --output-path ./output/ --output-file-name ltx2_sample.mp4

Bug fixes

configs/models/dits/ltx_2.py

  • Changed audio_positional_embedding_max_pos default from [2048] to [20] to match Official LTX-2 audio positional embedding configuration

configs/pipeline_configs/ltx_2.py

  • Added prepare_latent_shape() method to return packed latent shape [B, seq, C] directly, matching Official's latent preparation
  • Changed prepare_audio_latent_shape() return shape from (B, C, L, M) to (B, L, C*M) to match Official's packed audio latent format
  • Changed prepare_sigmas() from np.linspace(1.0, 1.0/steps, steps) to torch.linspace(1.0, 0.0, steps+1)[:-1] to match Official's sigma schedule (compared in the snippet after this list)
  • Added early return in maybe_pack_latents() to skip packing if latents are already 3D [B, seq, C]
  • Added early return in maybe_pack_audio_latents() to skip packing if audio latents are already 3D [B, T, C*F]
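
The two sigma schedules are mathematically identical (both run from 1.0 down to 1/steps in steps of 1/steps), so the prepare_sigmas fix is purely about arithmetic precision: numpy computes in float64 while torch defaults to float32. A standalone comparison, independent of the pipeline code:

import numpy as np
import torch

steps = 40
old = np.linspace(1.0, 1.0 / steps, steps)      # float64 arithmetic
new = torch.linspace(1.0, 0.0, steps + 1)[:-1]  # float32, matching Official

# Same endpoints and spacing, but ULP-level rounding differences remain,
# enough to break bit-exact alignment once compounded over 40 denoising steps.
print(np.abs(old - new.numpy().astype(np.float64)).max())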

runtime/models/adapter/ltx_2_connector.py

  • Removed .float() upcast in apply_split_rotary_emb to keep computation in native dtype (bfloat16) instead of upcasting to float32
  • Added dtype parameter to LTX2RotaryPosEmbed1d.forward() to allow explicit dtype control for cos/sin freqs
  • Added dtype casting for cos_freqs and sin_freqs when dtype is specified
  • Changed self.rope() call to pass dtype=hidden_states.dtype for consistent precision
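
A minimal sketch of the dtype plumbing described above; the module name, shapes, and frequency computation here are illustrative assumptions, not the PR's exact LTX2RotaryPosEmbed1d code:

import torch
import torch.nn as nn

class RotaryPosEmbed1dSketch(nn.Module):
    """Hypothetical 1D RoPE with explicit dtype control over cos/sin freqs."""

    def __init__(self, dim, theta=10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, positions, dtype=None):
        # positions: [B, L] -> angles: [B, L, dim // 2], computed in fp32
        angles = positions.unsqueeze(-1).float() * self.inv_freq
        cos_freqs, sin_freqs = angles.cos(), angles.sin()
        if dtype is not None:
            # Cast only at the end: the trig runs in fp32, the output matches
            # the caller's compute dtype (e.g. hidden_states.dtype).
            cos_freqs, sin_freqs = cos_freqs.to(dtype), sin_freqs.to(dtype)
        return cos_freqs, sin_freqs

# Caller side, mirroring the "pass dtype=hidden_states.dtype" change:
# cos, sin = self.rope(coords, dtype=hidden_states.dtype)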

runtime/models/dits/ltx_2.py

  • Removed .float() upcast in apply_split_rotary_emb to keep rotary embedding computation in native dtype
  • Added self.coords_dtype attribute: bfloat16 for video, float32 for audio, to control coordinate precision per modality
  • Added explicit coords = coords.to(self.coords_dtype) before coordinate averaging
  • Changed RoPE output to explicitly cast to torch.bfloat16 instead of returning in computed dtype
  • Changed rope_double_precision to read from HuggingFace config first (hf_config.get(...)) with arch config as fallback
  • Changed adaln_single calls to pass hidden_dtype=hidden_states.dtype for explicit dtype control in the AdaLN timestep embeddings
  • Changed timestep * ts_ca_mult to timestep * self.av_ca_timestep_scale_multiplier to avoid intermediate float64 computation
  • Wrapped self.norm_out() and self.audio_norm_out() in torch.autocast(enabled=False) to disable autocast during final layer norm, ensuring deterministic precision
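
The autocast wrap in the last bullet follows a standard pattern; a self-contained sketch (the real modules are the DiT's norm_out/audio_norm_out rather than a bare LayerNorm, and the fp32 upcast is an assumption):

import torch

norm = torch.nn.LayerNorm(16)                 # stand-in for norm_out
x = torch.randn(2, 16, dtype=torch.bfloat16)

with torch.autocast(device_type="cpu", enabled=False):
    # With autocast disabled, the norm computes in a fixed precision instead
    # of whatever dtype the surrounding autocast context would select.
    out = norm(x.float()).to(x.dtype)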

runtime/models/encoders/gemma_3.py

  • Added _rotate_half() helper function implementing HuggingFace's rotate_half formula
  • Added CPU-precomputed _hf_inv_freq buffer in Gemma3Attention.__init__() to match HuggingFace's exact inverse frequency computation
  • Added custom rotary_emb() method that uses _hf_inv_freq with rotate_half formula, replacing SGLang's shared RotaryEmbedding which uses a slightly different numerical formula
  • Changed attention mask dtype from torch.float32 to query.dtype for consistent precision
  • Changed float("-inf") mask fill to torch.finfo(query.dtype).min to avoid overflow in bfloat16
  • Changed embed_scale multiplication to use explicit torch.tensor(self.embed_scale, device, dtype) to avoid precision mismatch
  • Commented out position_ids = position_ids + 1 to match Official LTX-2's position indexing
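
For reference, HuggingFace's rotate_half convention and a CPU-precomputed inverse-frequency buffer look like this; head_dim and base are illustrative values, and the PR's actual wiring into Gemma3Attention differs in detail:

import torch

def _rotate_half(x):
    # HF convention: split the head dim in half, swap halves, negate one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, base = 256, 1_000_000.0  # illustrative values
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

def rotary_emb(positions):
    # positions: [B, L]; angles are duplicated across both halves of head_dim
    angles = positions[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()

def apply_rope(q, cos, sin):
    return q * cos + _rotate_half(q) * sin

q = torch.randn(1, 8, head_dim)
cos, sin = rotary_emb(torch.arange(8)[None, :])
q_rot = apply_rope(q, cos, sin)

# The mask change pairs with this: fill with the dtype's finite minimum
# rather than float("-inf"), which bfloat16 handles less gracefully.
mask_value = torch.finfo(torch.bfloat16).min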

runtime/pipelines/ltx_2_pipeline.py

  • Added LTX2FlowMatchScheduler class that overrides _time_shift_exponential to use torch float32 arithmetic instead of numpy float64, matching Official LTX-2's scheduler precision
  • Added initialize_pipeline() method to replace the default FlowMatchEulerDiscreteScheduler with LTX2FlowMatchScheduler
  • Changed prepare_mu() to use fixed token count 4096 (matching Official's MAX_SHIFT_ANCHOR default) instead of computing from actual latent dimensions
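
A sketch of the scheduler override, assuming diffusers' FlowMatchEulerDiscreteScheduler exposes the _time_shift_exponential hook with the usual exp(mu) / (exp(mu) + (1/t - 1)**sigma) form:

import math

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

class LTX2FlowMatchSchedulerSketch(FlowMatchEulerDiscreteScheduler):
    def _time_shift_exponential(self, mu, sigma, t):
        # Same formula as the base class, but evaluated in torch float32
        # instead of numpy float64, matching Official LTX-2's precision.
        t = torch.as_tensor(t, dtype=torch.float32)
        exp_mu = math.exp(mu)
        return exp_mu / (exp_mu + (1.0 / t - 1.0) ** sigma)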

runtime/pipelines_core/stages/denoising_av.py

  • Changed image conditioning encode path to use encode_dtype = batch.latents.dtype (bfloat16) for both input tensor and VAE weights, instead of hardcoded torch.float32
  • Added VAE dtype save/restore: saves original_dtype from vae_precision, converts VAE to encode_dtype before encoding, restores original_dtype after per-channel normalization
  • Autocast remains disabled during encode (since vae_precision stays "fp32", vae_autocast_enabled is False), matching Official's pure bf16 encode without autocast
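
The save/convert/encode/restore flow, written as a hypothetical standalone helper (the names and the simplified vae.encode return value are assumptions from the description above, not the PR's code):

import torch

def encode_image_conditioning(vae, image, latents_dtype, latents_mean,
                              latents_std, original_dtype=torch.float32):
    vae.to(latents_dtype)  # run the encoder in the latents' dtype (bf16)
    with torch.autocast(device_type=image.device.type, enabled=False):
        latents = vae.encode(image.to(latents_dtype))  # return type simplified
    # per-channel normalization, then restore the configured VAE precision
    latents = (latents - latents_mean.to(latents)) / latents_std.to(latents)
    vae.to(original_dtype)
    return latents

The decode stage below applies the same restore step after self.vae.decode(...).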

runtime/pipelines_core/stages/decoding_av.py

  • Added self.vae.to(original_dtype) after decode to restore VAE dtype (previously VAE stayed in bfloat16 after the forced .to(torch.bfloat16) conversion)

Precision Alignment [verified on a single GPU]

To reproduce pixel-perfect alignment with Official LTX-2 (ltx-pipelines / ltx-core), the following deterministic torch settings must be applied to both pipelines before any model execution:

import torch
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_cudnn_sdp(False)

These settings disable FlashAttention and Memory-Efficient SDPA backends, forcing PyTorch to use the deterministic Math SDPA backend. Without them, non-deterministic attention kernels produce different results across runs even with identical inputs and seeds.
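
As an optional sanity check, the backend state can be queried at runtime (the *_sdp_enabled query functions are available in recent PyTorch releases):

import torch

assert not torch.backends.cuda.flash_sdp_enabled()
assert not torch.backends.cuda.mem_efficient_sdp_enabled()
assert not torch.backends.cuda.cudnn_sdp_enabled()
assert torch.backends.cuda.math_sdp_enabled()  # the deterministic fallback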

How to apply

Both Official and SGLang pipelines must load these settings before any torch operation. This is done via a sitecustomize.py + torch_settings.py pair placed on PYTHONPATH:

mkdir -p /path/to/_site

cat > /path/to/_site/torch_settings.py << 'EOF'
import torch
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_cudnn_sdp(False)
EOF

cat > /path/to/_site/sitecustomize.py << 'EOF'
import torch_settings
EOF

Then prepend this directory to PYTHONPATH when launching either pipeline:

PYTHONPATH="/path/to/_site:$PYTHONPATH" python -m ltx_pipelines.ti2vid_one_stage ...
PYTHONPATH="/path/to/_site:$PYTHONPATH" sglang generate ...

Python automatically imports sitecustomize.py at startup, which triggers torch_settings.py and applies the deterministic settings before any model code runs.

Verified alignment

With these settings applied and identical generation parameters (prompt, seed, resolution, num_frames, guidance_scale, num_inference_steps, image conditioning), the two pipelines produce bit-for-bit identical float pixel outputs (max_diff = 0.0).
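
An illustrative way to verify this, with placeholder file names: dump the decoded float frames from both pipelines and compare.

import torch

official = torch.load("official_frames.pt")
sglang = torch.load("sglang_frames.pt")
print((official - sglang).abs().max())  # expected: tensor(0.)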

Checklist

LTX-2

  • Accuracy alignment achieved
  • Support TI2V inference
  • Implement Gemma 3 support using SGLang (not Transformers)
  • Gemma 3: support TP (tensor parallelism)
  • Support RoPE / SP
  • Wire up / validate Cache-DiT
  • Support TP for the primary model
  • Verify correctness of generate and the server
  • Ensure structural and code-style consistency across modules
  • Polish code style; remove redundant comments
  • Clean up debug code
  • Fix lint and update to latest main
  • Follow up on PR reviews

LTX-2 Stack PR

  • Support one-stage generation
  • Support distilled generation
  • Support IC LoRA
  • Support keyframe_interpolation
  • Support two-stage generation

Review Process

  1. Ping Merge Oncalls to start the PR flow. See [PR Merge Process].
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions Bot added the diffusion (SGLang Diffusion) label Feb 22, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @gmixiaojin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing support for the LTX-2 Video & Audio Joint model within SGLang. The primary goal is to achieve pixel-perfect alignment with the official LTX-2 implementation by resolving several numerical precision bugs and aligning various pipeline components. Key changes include refining data type handling in attention mechanisms and VAE operations, updating latent shape calculations, and ensuring consistent scheduler behavior. These modifications are crucial for accurate model reproduction and reliable video and audio generation.

Highlights

  • LTX-2 Model Alignment: Implemented numerous fixes and adjustments across the codebase to achieve pixel-perfect alignment with the Official LTX-2 model, addressing discrepancies in latent shapes, sigma schedules, and numerical precision.
  • Precision Control Enhancements: Refined dtype handling in various components, including Rotary Positional Embeddings (RoPE), AdaLN timestamp embeddings, and VAE encoding/decoding, to ensure consistent bfloat16 computation where appropriate and prevent unintended upcasting.
  • Gemma 3 Integration Refinement: Introduced HuggingFace-exact Rotary Embedding computation for Gemma 3, including a custom _rotate_half function and CPU-precomputed inverse frequencies, along with adjustments to attention mask dtypes and embed_scale application.
  • Scheduler and Latent Processing Updates: Modified the flow-matching scheduler to use torch.float32 for exponential time shifts and updated latent/audio latent shape preparation methods to directly return packed shapes and include early exit conditions.
  • VAE Encoding/Decoding Improvements: Reworked VAE image conditioning encoding to use bfloat16 for input tensors and VAE weights, added VAE dtype save/restore mechanisms, and replaced generic scaling/shifting with per-channel normalization.
  • Deterministic Settings Guidance: Provided explicit instructions and a sitecustomize.py example to enforce deterministic PyTorch settings (disabling FlashAttention, etc.) for reproducible results, crucial for bit-for-bit alignment.


Changelog
  • python/sglang/multimodal_gen/configs/models/dits/ltx_2.py
    • Updated the default maximum position for audio positional embedding.
  • python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py
    • Added prepare_latent_shape method for packed latent dimensions.
    • Modified prepare_audio_latent_shape to return a packed format.
    • Changed prepare_sigmas to use torch.linspace for sigma schedule.
    • Introduced early return logic in maybe_pack_latents and maybe_pack_audio_latents for already packed tensors.
  • python/sglang/multimodal_gen/runtime/models/adapter/ltx_2_connector.py
    • Removed unnecessary float upcasting in apply_split_rotary_emb.
    • Added dtype parameter and casting for cos_freqs and sin_freqs in LTX2RotaryPosEmbed1d.
    • Passed hidden_states.dtype to self.rope for consistent precision.
  • python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py
    • Removed explicit .float() upcasting in apply_split_rotary_emb.
    • Introduced coords_dtype attribute for modality-specific coordinate precision.
    • Explicitly cast coords to coords_dtype before averaging.
    • Ensured RoPE output is cast to torch.bfloat16.
    • Prioritized HuggingFace config for rope_double_precision.
    • Passed hidden_dtype to adaln_single calls.
    • Adjusted timestep multiplication to avoid float64 intermediate.
    • Disabled autocast for final layer normalization in norm_out and audio_norm_out.
  • python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py
    • Added a _rotate_half helper function.
    • Implemented CPU-precomputed _hf_inv_freq for exact HuggingFace RoPE.
    • Introduced a custom rotary_emb method using _hf_inv_freq and _rotate_half.
    • Changed attention mask dtype to query.dtype and fill value to torch.finfo(query.dtype).min.
    • Ensured embed_scale multiplication uses explicit torch.tensor for precision.
    • Commented out a position_ids increment for official LTX-2 position indexing.
  • python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py
    • Introduced a custom scheduler for torch.float32 arithmetic.
    • Added pipeline initialization to replace the default scheduler.
    • Updated prepare_mu to use a fixed token count of 4096.
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py
    • Implemented VAE dtype save/restore logic around decoding.
    • Replaced latent scaling/shifting with explicit mean/std application.
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py
    • Added video codec compression and tensor resizing utilities.
    • Updated image conditioning encoding to use bfloat16 and VAE dtype save/restore.
    • Replaced latent scaling/shifting with per-channel normalization for image latents.
    • Explicitly cast intermediate denoising calculations to latents.dtype.
Activity
  • The author has addressed several bugs in the Sglang LTX-2 pipelines.
  • The pull request includes detailed instructions for achieving pixel-perfect alignment with the Official LTX-2 model by enforcing deterministic PyTorch settings.
  • The author has verified accuracy alignment and support for TI2V inference, Gemma 3 (with TP), RoPE/SP, and Cache-DiT.
  • The author is following up on PR reviews and ensuring CI tests pass.

Contributor

@gemini-code-assist Bot left a comment


Code Review

The pull request introduces LTX-2 model support, primarily focusing on bug fixes in the SGLang LTX-2 pipelines and precision alignment with the official LTX-2 implementation. Key changes include adjustments to audio positional embedding, latent shape preparation, sigma schedules, and rotary embedding calculations to ensure numerical consistency. The PR also adds a custom LTX2FlowMatchScheduler to align with official precision requirements and modifies VAE encoding/decoding to handle bfloat16 correctly. Overall, the changes aim to achieve pixel-perfect alignment with the official LTX-2 model outputs.

  if self.audio_positional_embedding_max_pos is None:
-     self.audio_positional_embedding_max_pos = [2048]
+     self.audio_positional_embedding_max_pos = [20]
Contributor


critical

The change from [2048] to [20] for audio_positional_embedding_max_pos is noted in the PR description as matching the Official LTX-2 audio positional embedding configuration. This is a critical fix for correctness.

Comment on lines +142 to +155

def prepare_latent_shape(self, batch, batch_size, num_frames):
    """Return packed latent shape [B, seq, C] directly."""
    height = batch.height // self.vae_scale_factor
    width = batch.width // self.vae_scale_factor

    post_patch_num_frames = num_frames // self.patch_size_t
    post_patch_height = height // self.patch_size
    post_patch_width = width // self.patch_size
    seq_len = post_patch_num_frames * post_patch_height * post_patch_width

    num_channels = self.in_channels * self.patch_size_t * self.patch_size * self.patch_size

    shape = (batch_size, seq_len, num_channels)
    return shape
Contributor


critical

The prepare_latent_shape method is added to return the packed latent shape [B, seq, C] directly. This is crucial for matching the Official's latent preparation and ensuring correct input to the model.

Comment on lines +228 to +230

# If already packed (3D shape [B, seq, C]), skip packing
if latents.dim() == 3:
    return latents
Contributor


medium

Adding an early return in maybe_pack_latents() if latents are already 3D [B, seq, C] is a good optimization and prevents unnecessary reprocessing. This improves efficiency and robustness.

Suggested change (drops the explanatory comment):
- # If already packed (3D shape [B, seq, C]), skip packing
  if latents.dim() == 3:
      return latents

Comment on lines +360 to +362
# If already packed (3D shape [B, T, C*F]), skip packing
if latents.dim() == 3:
return latents
Contributor


medium

Adding an early return in maybe_pack_audio_latents() if audio latents are already 3D [B, T, C*F] is a good optimization and prevents unnecessary reprocessing. This improves efficiency and robustness.

Suggested change (drops the explanatory comment):
- # If already packed (3D shape [B, T, C*F]), skip packing
  if latents.dim() == 3:
      return latents

Comment on lines +3 to +8
import math
import os

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
Contributor


medium

Importing math, numpy as np, torch, and FlowMatchEulerDiscreteScheduler is necessary for the new LTX2FlowMatchScheduler and other numerical operations introduced in this PR. This is a functional change.

  if not vae_autocast_enabled:
-     latents = latents.to(vae_dtype)
  decode_output = self.vae.decode(latents)
Contributor


medium

Removing the latents = latents.to(vae_dtype) line when vae_autocast_enabled is false is correct because the VAE is already explicitly cast to torch.bfloat16 earlier. This avoids redundant casting and potential precision issues.

Comment on lines +2 to +7
import math
import time
from io import BytesIO

import av
import numpy as np
Contributor


medium

Importing math, BytesIO, av, and numpy as np is necessary for the new video codec compression and tensor resizing utilities. This is a functional change.

@gmixiaojin mentioned this pull request Feb 22, 2026
@gmixiaojin marked this pull request as ready for review February 22, 2026 10:27
@mickqian
Collaborator

/tag-and-rerun-ci

@mickqian
Collaborator

/rerun-failed-ic

@mickqian
Collaborator

/rerun-failed-ci

Comment thread python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py Outdated

  # (..., 2, r)
- split_x = x.reshape(*x.shape[:-1], 2, r).float()  # Explicitly upcast to float
+ split_x = x.reshape(*x.shape[:-1], 2, r)
Collaborator


In my opinion, apply_split_rotary_emb is a bit complicated to understand. Two suggestions to improve its readability:

  1. Use the unbind API to get first_x/second_x, like apply_interleaved_rotary_emb
  2. Use the transpose API instead of swapaxes, as swapaxes is a NumPy-style method

Contributor Author


For the swapaxes → transpose change: agreed, and changed at line 37 above.
For the unbind suggestion: we tried refactoring apply_split_rotary_emb to use unbind + separate multiply/subtract (like apply_interleaved_rotary_emb), but it breaks pixel-level alignment with the official LTX-2 pipeline. The current code uses addcmul, a fused multiply-add operation: under bf16 it rounds only once at the end, whereas the separate a * b - c * d rounds at each intermediate step. The per-element difference is tiny (~1 ULP), but it compounds over 40 denoising steps across many attention layers, resulting in a max_diff of ~2.38 in the final output. So we'll keep the addcmul approach as-is to preserve numerical correctness.
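
A standalone illustration of this rounding effect, using arbitrary tensors rather than the pipeline's:

import torch

a, b, c, d = (torch.randn(4096, dtype=torch.bfloat16) for _ in range(4))

# torch.addcmul(input, t1, t2, value) = input + value * t1 * t2; typical
# eltwise kernels compute the bf16 multiply-add in fp32 and round the result
# once, whereas the separate form rounds c*d to bf16 and then rounds the
# subtraction again.
fused = torch.addcmul(a * b, c, d, value=-1.0)
separate = a * b - c * d
print((fused - separate).abs().max())  # typically a ~1-ULP discrepancy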

  sin_freqs = torch.swapaxes(sin_freq, 1, 2)

- return cos_freqs, sin_freqs
+ return cos_freqs.to(torch.bfloat16), sin_freqs.to(torch.bfloat16)
Collaborator


Could you add a comment for this dtype cast? It's confusing, and why don't you use self.coords_dtype?

Contributor Author

@gmixiaojin commented Feb 24, 2026


We added comments for this dtype cast. Unfortunately, we can't use self.coords_dtype here because it controls the intermediate coordinate precision (fp32 for audio, bf16 for video), while this cast ensures the final RoPE output always matches the model's bf16 compute dtype. Otherwise, it won't align with the official LTX-2 model.

Comment thread python/sglang/multimodal_gen/runtime/models/encoders/gemma_3.py Outdated
Comment thread python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py
@ping1jing2
Collaborator

it seems like there are two code styles in this model

@github-actions Bot added the dependencies (Pull requests that update a dependency file) label Feb 23, 2026
@gmixiaojin
Contributor Author

it seems like there are two code styles in this model

Good catch — the style difference comes from the fact that these changes are precision-alignment fixes on top of the existing SGLang codebase. We kept the surrounding code untouched to minimize the diff and focused on correctness first (ensuring pixel-perfect alignment with the official LTX-2 pipeline). Happy to do a follow-up pass to unify the code style once the functional changes are reviewed and merged.

@github-actions Bot added the npu label Feb 24, 2026
@mickqian
Collaborator

/rerun-failed-ci

@mickqian merged commit fcfd964 into sgl-project:main Feb 24, 2026
190 of 210 checks passed
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Labels

dependencies (Pull requests that update a dependency file), diffusion (SGLang Diffusion), npu, run-ci
