
Conversation

@rootonchair
Contributor

@rootonchair rootonchair commented Jan 9, 2026

What does this PR do?

Fixes #12925

Test script (T2I)

import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

pipe = LTX2Pipeline.from_pretrained("rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16)
upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained("rootonchair/LTX-2-19b-distilled", subfolder="upsample_pipeline", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_sample.mp4",
)

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@rootonchair rootonchair marked this pull request as ready for review January 11, 2026 15:39
@rootonchair
Contributor Author

Running results
Original

output.mp4

Converted ckpt

ltx2_sample.mp4

@sayakpaul sayakpaul requested a review from dg845 January 12, 2026 03:41
Member

@sayakpaul sayakpaul left a comment

Thanks for shipping this so quickly! Left some comments, LMK if they make sense.

"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
Member

Where is timestep_conditioning used in the corresponding Video VAE class?

Contributor Author

@rootonchair rootonchair Jan 12, 2026

timestep_conditioning is used here. It depends on the metadata stored in the original checkpoint, and the distilled checkpoint is configured to use timestep_conditioning. You can confirm it with the script below:

from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
model_loader = SafetensorsModelStateDictLoader()
model_loader.metadata("weights/ltx-2-19b-distilled.safetensors")["vae"]["timestep_conditioning"] # True

Hence, I got some unexpected keys if I don't set timestep_conditioning to True when converting the distilled weights.

Collaborator

Hi @rootonchair, there was a recent update to the official Lightricks/LTX-2 repo that changed the VAE for the distilled checkpoint: https://huggingface.co/Lightricks/LTX-2/commit/1931de987c8e265eb64a9123227b903754e3cc68. So I think rootonchair/LTX-2-19b-distilled needs to be updated with the newly converted VAE as well.

It looks like timestep_conditioning should be set to False for the new video VAE, according to the config metadata on ltx-2-19b-distilled.safetensors. So we should probably remove the unexpected keys (last_time_embedder, last_scale_shift_table) rather than processing them in the VAE conversion script.

import json
import safetensors
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("Lightricks/LTX-2", "ltx-2-19b-distilled.safetensors")
with safetensors.safe_open(ckpt_path, framework="pt") as f:
    config = json.loads(f.metadata()["config"])
config["vae"]["timestep_conditioning"]  # Should be False

Contributor Author

I think I will convert the new weights. As for the conversion script, I think we should keep that handling there until the original repo decides to delete it.

latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 5:
Member

Is this needed to support precomputed latents?

Contributor Author

@rootonchair rootonchair Jan 12, 2026

I think it adds another layer of support for precomputed latents. The returned latents have shape [B, C, F, H, W], but the model takes input of shape [B, seq_len, C], so passing the returned latents directly would cause a shape mismatch. The same applies to the audio VAE.
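As a rough illustration of the reshape involved (a sketch only, not the pipeline's actual packing helper; the real code may additionally fold patch sizes into the channel dimension):

import torch

def pack_video_latents(latents: torch.Tensor) -> torch.Tensor:
    # Flatten [B, C, F, H, W] latents into the [B, seq_len, C] layout the
    # transformer expects, with seq_len = F * H * W.
    batch, channels, frames, height, width = latents.shape
    return latents.permute(0, 2, 3, 4, 1).reshape(batch, frames * height * width, channels)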

latent_length = round(duration_s * latents_per_second)

if latents is not None:
if latents.ndim == 4:
Member

Same question as above.


# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
Member

Let's apply similar changes to the I2V module as well?

Contributor Author

Sure, will apply the same change

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@dg845 dg845 left a comment

Thanks for the PR! Left some comments about the distilled sigmas schedule.

If I print out the timesteps for the Stage 1 distilled pipeline, I get (for commit faeccc5):

Distilled timesteps: tensor([1000.0000,  999.6502,  999.2961,  998.9380,  998.5754,  994.4882,
         979.3755,  929.6974,  100.0000], device='cuda:0')

Here the sigmas (and thus the timesteps) are shifted toward a terminal value of 0.1, and use_dynamic_shifting is applied as well. However, I believe the distilled sigmas are used as-is in the original LTX 2.0 code:

https://github.com/Lightricks/LTX-2/blob/391c0a2462f8bc2238fa0a2b7717bc1914969230/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py#L132

So I think when creating the distilled scheduler we need to disable use_dynamic_shifting and shift_terminal so that the distilled sigmas are used without changes.

Can you check whether the final distilled sigmas match up with those of the original implementation?
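For concreteness, a minimal sketch of the scheduler change being suggested, assuming pipe is the distilled LTX2Pipeline (the same construction appears in the scripts further down this thread):

from diffusers import FlowMatchEulerDiscreteScheduler

# Rebuild the scheduler so that custom (distilled) sigmas are consumed as-is,
# without dynamic shifting or terminal-sigma stretching.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)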

@dg845
Collaborator

dg845 commented Jan 13, 2026

The original test script didn't work for me, but I was able to get a working version as follows:

import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    subfolder="upsample_pipeline/latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

The necessary changes were (1) to create LTX2LatentUpsamplePipeline directly from the models, since from_pretrained doesn't work when two pipelines (the distilled pipeline and the latent upsampling pipeline) live in the same repo (a known limitation), and (2) to set Stage 2 inference to use width * 2 and height * 2, since the upsampling pipeline upsamples the video by 2x in both width and height.

rootonchair and others added 2 commits January 13, 2026 12:01
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
@sayakpaul
Member

Not jeopardizing this PR at all but while we're at the two-stage pipeline stuff, it could also be cool to verify it with the distilled LoRA that we have in place (PR already merged: #12933).

So, what we would do is:

  1. Obtain video and audio latents with the regular LTX2 pipeline
  2. Upsample the video latents
  3. Load the distilled LoRA into the regular LTX2 pipeline
  4. Run the pipeline with 4 steps and these sigmas and with a guidance scale of 1.

Once we're close to merging the PR, we could document all of these to inform the community.

@rootonchair
Contributor Author

@dg845 thank you for your detailed reviews. Let me take a closer look at that.

@rootonchair
Contributor Author

Not jeopardizing this PR at all but while we're at the two-stage pipeline stuff, it could also be cool to verify it with the distilled LoRA that we have in place (PR already merged: #12933).

So, what we would do is:

  1. Obtain video and audio latents with the regular LTX2 pipeline
  2. Upsample the video latents
  3. Load the distilled LoRA into the regular LTX2 pipeline
  4. Run the pipeline with 4 steps and these sigmas and with a guidance scale of 1.

Once we're close to merging the PR, we could document all of these to inform the community.

That sounds interesting. Let's have a quick test with the two-stage distilled LoRA too.

@dg845
Collaborator

dg845 commented Jan 13, 2026

For two-stage inference with the Stage 2 distilled LoRA, I think this script should work:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
# This scheduler should use distilled sigmas without any changes
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    subfolder="upsample_pipeline/latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

Sample with LoRA and scheduler fix:

ltx2_distilled_sample_lora_fix.mp4

@rootonchair
Contributor Author

For two-stage inference with the Stage 2 distilled LoRA, I think this script should work: [script and sample quoted above]

I think we should run with the original LTX2 weights and not the distilled checkpoint. WDYT?

@sayakpaul
Member

Yes, the first stage, in this case, should use the non-distilled ckpt.

@dg845
Collaborator

dg845 commented Jan 13, 2026

Fixed script (I think):

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

# Stage 1 default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=40,
    sigmas=None,
    guidance_scale=4.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
# Change scheduler to use Stage 2 distilled sigmas as is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2 inference with distilled LoRA and sigmas
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

@dg845
Copy link
Collaborator

dg845 commented Jan 13, 2026

If I test the distilled pipeline with the prompt "a dog dancing to energetic electronic dance music", I get the following sample:

ltx2_distilled_sample_dog_edm.mp4

I would expect the audio to be music for this prompt, but instead the audio is only noise, so I think there might be something wrong with the way audio is currently being handled in the distilled pipeline. (The video also doesn't follow the prompt closely; I'm not sure if this is a symptom of the audio being messed up or if there are also bugs for video processing.)

@@ -0,0 +1,4 @@
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
Member

Let's also mention the source for these values.

Contributor Author

Sure!

@sayakpaul
Member

@dg845 should the second stage inference with LoRA be run with 4 num_inference_steps? Also, is the following necessary?

width=width * 2,
height=height * 2,

@dg845
Collaborator

dg845 commented Jan 14, 2026

should the second stage inference with LoRA be run with 4 num_inference_steps?

I believe it should be run with 3, as STAGE_2_DISTILLED_SIGMA_VALUES has 3 values excluding the trailing 0.0. Note that if sigmas is set to STAGE_2_DISTILLED_SIGMA_VALUES, this will currently override num_inference_steps.

Also, is the following necessary?

It is currently necessary as otherwise we'd get a shape error, but we could modify the code to infer the latent_height and latent_width from supplied latents, by changing this part:

latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
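One possible shape of that change, assuming supplied latents arrive unpacked as [B, C, F, H, W] (hypothetical sketch, not necessarily what the actual commit does):

# Infer latent dimensions from supplied latents; otherwise derive them from width/height.
if latents is not None:
    latent_num_frames, latent_height, latent_width = latents.shape[2:5]
else:
    latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
    latent_height = height // self.vae_spatial_compression_ratio
    latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width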

@dg845
Collaborator

dg845 commented Jan 14, 2026

Wrote a sample commit f4d47b9 which infers the latent dimensions (e.g. latent_height, latent_width, etc.) if latents or audio_latents is supplied. This should make it easier to run two-stage inference without having to manage the width and height.

@rootonchair feel free to cherry-pick the changes if they work for you.

@rootonchair
Contributor Author

Thanks for the PR! Left some comments about the distilled sigmas schedule.

If I print out the timesteps for the Stage 1 distilled pipeline, I get (for commit faeccc5):

Distilled timesteps: tensor([1000.0000,  999.6502,  999.2961,  998.9380,  998.5754,  994.4882,
         979.3755,  929.6974,  100.0000], device='cuda:0')

Here the sigmas (and thus the timesteps) are shifted toward a terminal value of 0.1, and use_dynamic_shifting is applied as well. However, I believe the distilled sigmas are used as-is in the original LTX 2.0 code:

https://github.com/Lightricks/LTX-2/blob/391c0a2462f8bc2238fa0a2b7717bc1914969230/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py#L132

So I think when creating the distilled scheduler we need to disable use_dynamic_shifting and shift_terminal so that the distilled sigmas are used without changes.

Can you check whether the final distilled sigmas match up with those of the original implementation?

I have checked against the original implementation and the timesteps match each other, as the sigmas are later scaled up to 1000 here.

@rootonchair
Contributor Author

If I test the distilled pipeline with the prompt "a dog dancing to energetic electronic dance music", I get the following sample:

ltx2_distilled_sample_dog_edm.mp4

I would expect the audio to be music for this prompt, but instead the audio is only noise, so I think there might be something wrong with the way audio is currently being handled in the distilled pipeline. (The video also doesn't follow the prompt closely; I'm not sure if this is a symptom of the audio being messed up or if there are also bugs for video processing.)

This is really weird; let's investigate this more before we ship the model.

@rootonchair
Contributor Author

Wrote a sample commit f4d47b9 which infers the latent dimensions (e.g. latent_height, latent_width, etc.) if latents or audio_latents is supplied. This should make it easier to run two-stage inference without having to manage the width and height.

@rootonchair feel free to cherry-pick the changes if they work for you.

Thanks

f4d47b9

@dg845 I think your commit is good. I'll just check a few more things before applying it.

if latents is not None:
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
Collaborator

@dg845 dg845 Jan 14, 2026

When the output_type is set to "latent", LTX2Pipeline will return the denormalized audio latents:

audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
)
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)

Since the DiT expects normalized latents, I think we need to normalize the audio latents here:

                latents = self._pack_audio_latents(latents)
                latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)

where _normalize_audio_latents is something like

    @staticmethod
    def _normalize_audio_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
    ) -> torch.Tensor:
        # Normalize latents across the combined channel and mel bin dimension [B, L, C * M]
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        latents = (latents - latents_mean) / latents_std
        return latents

This should make the Stage 2 audio latents have the expected distribution.

Collaborator

@dg845 dg845 Jan 14, 2026

Note that although LTX2Pipeline also returns the denormalized video latents when output_type="latent", LTX2LatentUpsamplePipeline returns the normalized latents, so the Stage 2 video latents are not affected by this issue.

Collaborator

(However, prepare_latents-type methods usually expect supplied latents to be normalized, since we generally return them as-is without normalizing them, so we might need to think through the design on where we expect latents to be normalized or denormalized. CC @sayakpaul.)

Member

Hmm, this is a bit of a spiraling situation, and I wonder if exposing a flag like normalize_latents (or leveraging the latents_normalized flag as used in the upsampling pipeline) could make sense here; we keep it at a reasonable default and educate users about when to use which value for normalize_latents. I think this also comes with the uniqueness of the LTX2 pipeline a bit.

The situation is likely getting confusing because the video latents returned from the upsampling pipeline are always normalized. If the user wants them unnormalized (with a flag normalize_latents=False), then I guess both upsampled_video_latent and audio_latent could be passed as-is to the Stage 2 pipeline?

(However, prepare_latents-type methods usually expect supplied latents to be normalized, since we generally return them as-is without normalizing them, so we might need to think through the design on where we expect latents to be normalized or denormalized.

I think this is both true and false. Inside Flux Kontext, the image latents are normalized inside prepare_latents():

image_latents = self._encode_vae_image(image=image, generator=generator)

This is what we also follow for LTX2 I2V:

init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)

Contributor Author

One perspective in favor of keeping the returned latents denormalized by default is that there could be a pipeline where each module has a different normalization method. Passing only denormalized latents, with each module normalizing the latents itself, would help maintain plug-and-play ability. Moreover, the computation required for normalizing is acceptable.

@rootonchair
Contributor Author

I just updated the model repository: moved latent_upsampler outside and removed upsample_pipeline.



Development

Successfully merging this pull request may close these issues.

LTX-2 distilled checkpoint support
