[Diffusion] LTX-2 Support PR3#19151
Conversation
Summary of Changes
Hello @gmixiaojin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request focuses on enhancing support for the LTX-2 Video & Audio Joint model within SGLang. The primary goal is to achieve pixel-perfect alignment with the official LTX-2 implementation by resolving several numerical precision bugs and aligning various pipeline components. Key changes include refining data type handling in attention mechanisms and VAE operations, updating latent shape calculations, and ensuring consistent scheduler behavior. These modifications are crucial for accurate model reproduction and reliable video and audio generation.
Activity
Code Review
The pull request introduces LTX-2 model support, primarily focusing on bug fixes in the SGLang LTX-2 pipelines and precision alignment with the official LTX-2 implementation. Key changes include adjustments to audio positional embedding, latent shape preparation, sigma schedules, and rotary embedding calculations to ensure numerical consistency. The PR also adds a custom LTX2FlowMatchScheduler to align with official precision requirements and modifies VAE encoding/decoding to handle bfloat16 correctly. Overall, the changes aim to achieve pixel-perfect alignment with the official LTX-2 model outputs.
```diff
 )
 if self.audio_positional_embedding_max_pos is None:
-    self.audio_positional_embedding_max_pos = [2048]
+    self.audio_positional_embedding_max_pos = [20]
```
```python
def prepare_latent_shape(self, batch, batch_size, num_frames):
    """Return packed latent shape [B, seq, C] directly."""
    height = batch.height // self.vae_scale_factor
    width = batch.width // self.vae_scale_factor

    post_patch_num_frames = num_frames // self.patch_size_t
    post_patch_height = height // self.patch_size
    post_patch_width = width // self.patch_size
    seq_len = post_patch_num_frames * post_patch_height * post_patch_width

    num_channels = self.in_channels * self.patch_size_t * self.patch_size * self.patch_size

    shape = (batch_size, seq_len, num_channels)
    return shape
```
```python
# If already packed (3D shape [B, seq, C]), skip packing
if latents.dim() == 3:
    return latents
```
Adding an early return in maybe_pack_latents() if latents are already 3D [B, seq, C] is a good optimization and prevents unnecessary reprocessing. This improves efficiency and robustness.
```python
# If already packed (3D shape [B, T, C*F]), skip packing
if latents.dim() == 3:
    return latents
```
Adding an early return in maybe_pack_audio_latents() if audio latents are already 3D [B, T, C*F] is a good optimization and prevents unnecessary reprocessing. This improves efficiency and robustness.
```python
import math
import os

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
```
```python
pass
if not vae_autocast_enabled:
    latents = latents.to(vae_dtype)
decode_output = self.vae.decode(latents)
```
```python
import math
import time
from io import BytesIO

import av
import numpy as np
```
/tag-and-rerun-ci |
/rerun-failed-ic |
/rerun-failed-ci |
```diff
 # (..., 2, r)
-split_x = x.reshape(*x.shape[:-1], 2, r).float()  # Explicitly upcast to float
+split_x = x.reshape(*x.shape[:-1], 2, r)
```
In my opinion, `apply_split_rotary_emb` is a bit complicated to understand. Some suggestions to improve its readability:
- Use the `unbind` API to get `first_x`/`second_x`, like `apply_interleaved_rotary_emb`
- Use the `transpose` API instead of `swapaxes`, as `swapaxes` is a NumPy-style method
For the swapaxes → transpose change: we agreed and changed it in line 37 above.
For the unbind suggestion: we tried refactoring apply_split_rotary_emb to use unbind + separate multiply/subtract (like apply_interleaved_rotary_emb), but it breaks pixel-level alignment with the official LTX-2 pipeline. The current code uses addcmul which is a fused multiply-add operation — under bf16, it rounds only once at the end, whereas separate a * b - c * d rounds at each intermediate step. The per-element difference is tiny (~1 ULP), but it compounds over 40 denoising steps over many attention layers, resulting in a max_diff of ~2.38 in the final output. So we'll keep the addcmul approach as-is to preserve numerical correctness.
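The rounding behavior described above can be demonstrated in isolation. The sketch below is illustrative, not the PR's actual code: it compares a fused path via `torch.addcmul` (which computes `input + value * t1 * t2`) against separate multiply/subtract in bfloat16, where each intermediate product rounds to bf16 before the subtraction:

```python
import torch

# Illustrative sketch (not the PR's code): fused multiply-add vs. separate
# ops in bfloat16. addcmul(input, t1, t2, value) = input + value * t1 * t2,
# so `addcmul(a * b, c, d, value=-1.0)` computes a*b - c*d with fewer
# intermediate bf16 roundings than the separate expression.
torch.manual_seed(0)
a, b, c, d = (torch.randn(4096, dtype=torch.bfloat16) for _ in range(4))

fused = torch.addcmul(a * b, c, d, value=-1.0)   # c*d term not rounded to bf16 first
separate = a * b - c * d                          # c*d rounds to bf16 before the subtract
ref = a.float() * b.float() - c.float() * d.float()

# Per-element differences between the two bf16 paths are on the order of
# 1 ULP, but such errors can compound across many layers and steps.
print((fused.float() - separate.float()).abs().max().item())
```

This is why preserving the exact op sequence (not just the mathematical formula) matters for bit-exact alignment.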
```diff
 sin_freqs = torch.swapaxes(sin_freq, 1, 2)

-return cos_freqs, sin_freqs
+return cos_freqs.to(torch.bfloat16), sin_freqs.to(torch.bfloat16)
```
Could you add a comment for this dtype cast? It's confusing, and why don't you use `self.coords_dtype`?
We added some comments for this dtype cast. Unfortunately, we can't use `self.coords_dtype` here because it controls the intermediate coordinate precision (fp32 for audio, bf16 for video), while this cast ensures the final RoPE output always matches the model's bf16 compute dtype. Otherwise, it won't align with the official LTX-2 model.
It seems like there are two code styles in this model.

Good catch: the style difference comes from the fact that these changes are precision-alignment fixes on top of the existing SGLang codebase. We kept the surrounding code untouched to minimize the diff and focused on correctness first (ensuring pixel-perfect alignment with the official LTX-2 pipeline). Happy to do a follow-up pass to unify the code style once the functional changes are reviewed and merged.
/rerun-failed-ci |
Motivation
Support LTX-2 Video & Audio Joint model.
This PR mainly fixes bugs in the SGLang LTX-2 pipelines.
How to use
generate
```shell
sglang generate \
  --model-path LTX-2_Model \
  --prompt "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot. [Audio Description] Subtle ambient space atmosphere with deep bass hum, gentle cracking and crunching sounds as the eggshell breaks apart, soft whooshing of lunar dust particles floating in low gravity, quiet breathing sounds from inside the spacesuit, a faint heartbeat rhythm building tension, ethereal cinematic synth pads creating a sense of wonder and discovery, all blended with vast empty silence of space." \
  --negative-prompt "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." \
  --width 768 --height 512 --num-frames 121 --fps 24 \
  --num-inference-steps 40 --guidance-scale 4.0 \
  --seed 1024 \
  --output-path ./output/ --output-file-name ltx2_sample.mp4
```

Bug fixes
`configs/models/dits/ltx_2.py`
- Changed the `audio_positional_embedding_max_pos` default from `[2048]` to `[20]` to match the official LTX-2 audio positional embedding configuration

`configs/pipeline_configs/ltx_2.py`
- `prepare_latent_shape()` now returns the packed latent shape `[B, seq, C]` directly, matching the official latent preparation
- Changed the `prepare_audio_latent_shape()` return shape from `(B, C, L, M)` to `(B, L, C*M)` to match the official packed audio latent format
- Changed `prepare_sigmas()` from `np.linspace(1.0, 1.0/steps, steps)` to `torch.linspace(1.0, 0.0, steps+1)[:-1]` to match the official sigma schedule
- Updated `maybe_pack_latents()` to skip packing if latents are already 3D `[B, seq, C]`
- Updated `maybe_pack_audio_latents()` to skip packing if audio latents are already 3D `[B, T, C*F]`

`runtime/models/adapter/ltx_2_connector.py`
- Removed the `.float()` upcast in `apply_split_rotary_emb` to keep computation in the native dtype (bfloat16) instead of upcasting to float32
- Added a `dtype` parameter to `LTX2RotaryPosEmbed1d.forward()` to allow explicit dtype control for the cos/sin freqs; `cos_freqs` and `sin_freqs` are cast when `dtype` is specified
- Updated the `self.rope()` call to pass `dtype=hidden_states.dtype` for consistent precision

`runtime/models/dits/ltx_2.py`
- Removed the `.float()` upcast in `apply_split_rotary_emb` to keep the rotary embedding computation in the native dtype
- Added a `self.coords_dtype` attribute (`bfloat16` for video, `float32` for audio) to control coordinate precision per modality, with `coords = coords.to(self.coords_dtype)` before coordinate averaging
- RoPE output is cast to `torch.bfloat16` instead of being returned in the computed dtype
- `rope_double_precision` now reads from the HuggingFace config first (`hf_config.get(...)`) with the arch config as fallback
- `adaln_single` calls pass `hidden_dtype=hidden_states.dtype` for explicit dtype control in the AdaLN timestep embeddings
- Changed `timestep * ts_ca_mult` to `timestep * self.av_ca_timestep_scale_multiplier` to avoid an intermediate float64 computation
- Wrapped `self.norm_out()` and `self.audio_norm_out()` in `torch.autocast(enabled=False)` to disable autocast during the final layer norm, ensuring deterministic precision

`runtime/models/encoders/gemma_3.py`
- Added a `_rotate_half()` helper function implementing HuggingFace's `rotate_half` formula
- Added a `_hf_inv_freq` buffer in `Gemma3Attention.__init__()` to match HuggingFace's exact inverse frequency computation
- Added a `rotary_emb()` method that uses `_hf_inv_freq` with the `rotate_half` formula, replacing SGLang's shared `RotaryEmbedding`, which uses a slightly different numerical formula
- Changed `torch.float32` to `query.dtype` for consistent precision
- Changed the `float("-inf")` mask fill to `torch.finfo(query.dtype).min` to avoid overflow in bfloat16
- Changed the `embed_scale` multiplication to use an explicit `torch.tensor(self.embed_scale, device, dtype)` to avoid a precision mismatch
- Added `position_ids = position_ids + 1` to match the official LTX-2 position indexing

`runtime/pipelines/ltx_2_pipeline.py`
- Added an `LTX2FlowMatchScheduler` class that overrides `_time_shift_exponential` to use torch float32 arithmetic instead of numpy float64, matching the official LTX-2 scheduler precision
- Updated the `initialize_pipeline()` method to replace the default `FlowMatchEulerDiscreteScheduler` with `LTX2FlowMatchScheduler`
- Updated `prepare_mu()` to use the fixed token count `4096` (matching the official `MAX_SHIFT_ANCHOR` default) instead of computing it from the actual latent dimensions

`runtime/pipelines_core/stages/denoising_av.py`
- Uses `encode_dtype = batch.latents.dtype` (bfloat16) for both the input tensor and the VAE weights, instead of a hardcoded `torch.float32`
- Saves `original_dtype` from `vae_precision`, converts the VAE to `encode_dtype` before encoding, and restores `original_dtype` after per-channel normalization (`vae_precision` stays `"fp32"`, `vae_autocast_enabled` is `False`), matching the official pure-bf16 encode without autocast

`runtime/pipelines_core/stages/decoding_av.py`
- Added `self.vae.to(original_dtype)` after decode to restore the VAE dtype (previously the VAE stayed in bfloat16 after the forced `.to(torch.bfloat16)` conversion)
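The sigma-schedule change can be illustrated in isolation. Analytically, both formulas describe the same grid with spacing `1/steps` (since `(1 - 1/n)/(n-1) = 1/n`), so the difference is purely numerical: the old path is NumPy float64, the new path is torch float32. A minimal sketch:

```python
import numpy as np
import torch

# Sketch of the sigma-schedule change: both formulas yield the grid
# 1.0, 1-1/n, ..., 1/n, but the old path computes in NumPy float64 while
# the new path computes in torch float32, so values can round differently.
steps = 40
old_sigmas = np.linspace(1.0, 1.0 / steps, steps)       # float64, ends at 1/steps
new_sigmas = torch.linspace(1.0, 0.0, steps + 1)[:-1]   # float32, drop the final 0.0

print(old_sigmas[0], float(new_sigmas[0]))    # both 1.0
print(old_sigmas[-1], float(new_sigmas[-1]))  # both ~0.025, up to float32 rounding
```

Tiny per-sigma rounding differences like this are exactly what the precision-alignment fixes in this PR chase down.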
Precision Alignment [verified on a single GPU]

To reproduce pixel-perfect alignment with the official LTX-2 (`ltx-pipelines/ltx-core`), the following deterministic torch settings must be applied to both pipelines before any model execution. These settings disable the FlashAttention and Memory-Efficient SDPA backends, forcing PyTorch to use the deterministic Math SDPA backend. Without them, non-deterministic attention kernels produce different results across runs even with identical inputs and seeds.
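The PR's actual settings file is not reproduced in this thread; a plausible sketch of such settings, using PyTorch's standard SDPA backend toggles, might look like this:

```python
import torch

# Assumed sketch of the deterministic settings described above; the PR's
# actual torch_settings.py is not shown here. Disable the non-deterministic
# SDPA backends so scaled_dot_product_attention falls back to Math SDPA.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
```

These toggles are process-global, which is why they must run before any model code executes.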
How to apply
Both the official and SGLang pipelines must load these settings before any torch operation. This is done via a `sitecustomize.py` + `torch_settings.py` pair placed on `PYTHONPATH`; prepend that directory to `PYTHONPATH` when launching either pipeline. Python automatically imports `sitecustomize.py` at startup, which triggers `torch_settings.py` and applies the deterministic settings before any model code runs.
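The `sitecustomize.py` mechanism can be demonstrated end to end. The file contents below are illustrative stand-ins, not the PR's actual files; the point is that any module named `sitecustomize` found on `PYTHONPATH` is imported automatically by Python's `site` module at interpreter startup:

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical demo of the sitecustomize.py mechanism described above.
# The two files written here are illustrative, not the PR's actual files.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "torch_settings.py"), "w") as f:
        f.write("print('deterministic settings applied')\n")
    with open(os.path.join(tmp, "sitecustomize.py"), "w") as f:
        f.write("import torch_settings\n")

    # Launch a fresh interpreter that runs no explicit imports; the `site`
    # module still imports sitecustomize from PYTHONPATH at startup.
    env = dict(os.environ, PYTHONPATH=tmp)
    result = subprocess.run(
        [sys.executable, "-c", "pass"],
        env=env, capture_output=True, text=True,
    )

print(result.stdout.strip())
```

Because the hook fires before any user code, the settings are guaranteed to be in place before either pipeline touches torch.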
Verified alignment

With these settings applied and identical generation parameters (prompt, seed, resolution, num_frames, guidance_scale, num_inference_steps, image conditioning), the two pipelines produce bit-for-bit identical float pixel outputs (`max_diff = 0.0`).

Checklist
- LTX-2 `generate` and the `server`
- LTX-2 Stack PR
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`