[Diffusion] LTX-2 Support PR2#17496
Conversation
Summary of Changes
Hello @gmixiaojin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the diffusion capabilities by integrating the LTX-2 Video & Audio Joint model. It introduces a new pipeline that orchestrates text and image inputs to generate high-quality video and audio outputs. The changes encompass new model architectures, optimized data handling for distributed environments, and detailed configuration for various multimodal components, laying the groundwork for advanced video and audio synthesis.
Code Review
This pull request introduces comprehensive support for the LTX-2 video and audio joint generation model. The changes are extensive, adding new configurations, model implementations for various components like DiT, VAEs, encoders, and a vocoder, as well as a new pipeline to orchestrate the generation process. The code is generally well-structured. My review focuses on improving code clarity and maintainability. I've identified a few areas where refactoring could enhance readability and consistency, such as replacing magic numbers with named constants and extracting complex logic into helper functions.
```python
(".gate_up_proj", ".gate_proj", 0),  # type: ignore
(".gate_up_proj", ".up_proj", 1),  # type: ignore
```
The type hint for stacked_params_mapping is list[tuple[str, str, str]], but you are using integers 0 and 1 for the shard IDs here, which is inconsistent. While the weight loading logic seems to handle this, it would be better for type safety and to avoid future confusion to use strings "0" and "1" to match the type hint.
```python
(".gate_up_proj", ".gate_proj", "0"),
(".gate_up_proj", ".up_proj", "1"),
```

```python
audio_latents_mean = getattr(audio_vae, "latents_mean", None)
audio_latents_std = getattr(audio_vae, "latents_std", None)
if (
    isinstance(audio_latents_mean, torch.Tensor)
    and isinstance(audio_latents_std, torch.Tensor)
    and audio_latents_mean.numel() == audio_latents_std.numel()
):
    audio_latents_mean = audio_latents_mean.to(
        device=audio_latents.device, dtype=audio_latents.dtype
    )
    audio_latents_std = audio_latents_std.to(
        device=audio_latents.device, dtype=audio_latents.dtype
    )
    if audio_latents.ndim == 3:
        if audio_latents.shape[-1] != audio_latents_mean.numel():
            raise ValueError(
                f"audio_latents last dim {audio_latents.shape[-1]} "
                f"does not match audio_vae stats {audio_latents_mean.numel()}"
            )
        audio_latents = audio_latents * audio_latents_std.view(
            1, 1, -1
        ) + audio_latents_mean.view(1, 1, -1)
    elif audio_latents.ndim == 2:
        if audio_latents.shape[-1] != audio_latents_mean.numel():
            raise ValueError(
                f"audio_latents last dim {audio_latents.shape[-1]} "
                f"does not match audio_vae stats {audio_latents_mean.numel()}"
            )
        audio_latents = audio_latents * audio_latents_std.view(
            1, -1
        ) + audio_latents_mean.view(1, -1)
    else:
        audio_latents = audio_latents * audio_latents_std + audio_latents_mean
```
This block of code for denormalizing audio latents is quite complex and makes the _unpad_and_unpack_latents method very long and hard to read. Consider extracting this logic into a separate helper method to improve readability and modularity. Additionally, there's an unused static method _denormalize_audio_latents with a similar purpose. It would be good to either use it if it's correct or remove it to avoid confusion and code duplication.
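The suggested extraction could look roughly like the sketch below. This is a hypothetical helper, not the PR's code: the name `denormalize_audio_latents` and the standalone-function form (rather than a method) are assumptions, but the broadcasting behavior mirrors the three branches in the diff, collapsing the 2-D and 3-D cases into one view shape.

```python
import torch

# Hypothetical helper extracted from the denormalization block. The 2-D and
# 3-D branches differ only in how many leading singleton dims the stats need,
# so a single computed view shape covers both.
def denormalize_audio_latents(
    audio_latents: torch.Tensor,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
) -> torch.Tensor:
    latents_mean = latents_mean.to(
        device=audio_latents.device, dtype=audio_latents.dtype
    )
    latents_std = latents_std.to(
        device=audio_latents.device, dtype=audio_latents.dtype
    )
    if audio_latents.ndim in (2, 3):
        if audio_latents.shape[-1] != latents_mean.numel():
            raise ValueError(
                f"audio_latents last dim {audio_latents.shape[-1]} "
                f"does not match audio_vae stats {latents_mean.numel()}"
            )
        # e.g. (1, 1, -1) for ndim == 3, (1, -1) for ndim == 2.
        view_shape = (1,) * (audio_latents.ndim - 1) + (-1,)
        latents_mean = latents_mean.view(view_shape)
        latents_std = latents_std.view(view_shape)
    return audio_latents * latents_std + latents_mean
```

With this in place, `_unpad_and_unpack_latents` would shrink to a single call, and it could also replace (or be reconciled with) the unused `_denormalize_audio_latents` static method mentioned above.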
```python
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[
    :4, :
]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]

video_ca_scale_shift_table = (
    video_per_layer_ca_scale_shift[None, None, :, :].to(
        dtype=temb_ca_scale_shift.dtype, device=temb_ca_scale_shift.device
    )
    + temb_ca_scale_shift.reshape(
        batch_size, temb_ca_scale_shift.shape[1], 4, -1
    )
).unbind(dim=2)
video_ca_gate = (
    video_per_layer_ca_gate[None, None, :, :].to(
        dtype=temb_ca_gate.dtype, device=temb_ca_gate.device
    )
    + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
).unbind(dim=2)
```
The magic number 4 is used multiple times in this block for slicing and reshaping tensors related to cross-attention scale and shift. This makes the code harder to understand and maintain. It would be better to define a constant for this value, for example, NUM_CA_SCALE_SHIFT_PARAMS = 4, and use it here. This would improve readability and make future modifications easier.
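A sketch of that refactor, under stated assumptions: `scale_shift_table` and the shapes below are stand-ins for `video_a2v_cross_attn_scale_shift_table` and the tensors in the diff, and the constant name follows the suggestion above.

```python
import torch

# Hypothetical refactor: name the slice width instead of repeating the
# literal 4. The table packs the scale/shift rows first, gate rows after.
NUM_CA_SCALE_SHIFT_PARAMS = 4

hidden = 32  # stand-in for the model's inner dim
scale_shift_table = torch.randn(NUM_CA_SCALE_SHIFT_PARAMS + 1, hidden)
per_layer_ca_scale_shift = scale_shift_table[:NUM_CA_SCALE_SHIFT_PARAMS, :]
per_layer_ca_gate = scale_shift_table[NUM_CA_SCALE_SHIFT_PARAMS:, :]

# The same constant drives the reshape, so the slice and the reshape stay
# in sync if the number of packed parameters ever changes.
batch_size, seq_len = 2, 16
temb_ca_scale_shift = torch.randn(
    batch_size, seq_len, NUM_CA_SCALE_SHIFT_PARAMS * hidden
)
video_ca_scale_shift = (
    per_layer_ca_scale_shift[None, None, :, :]
    + temb_ca_scale_shift.reshape(
        batch_size, seq_len, NUM_CA_SCALE_SHIFT_PARAMS, -1
    )
).unbind(dim=2)
```

Unbinding along dim 2 then yields one `(batch, seq, hidden)` tensor per packed parameter.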
Please fix lint.
fixed.
Add decoding_av.py, denoising_av.py, latent_preparation_av.py, and text_connector.py to this PR, moved from PR1.
/tag-and-rerun-ci
Hi @gmixiaojin, great work! Curious whether the VAE supports SP?
Co-authored-by: Fan Yin <1106310035@qq.com> Co-authored-by: Yuhao Yang <47235274+yhyang201@users.noreply.github.com>
The current commit appears to have an issue with audio synthesis: the background noise is very loud, and the same prompt fails to generate normal audio for the video.
Thank you for your feedback.
Motivation
Support the LTX-2 Video & Audio Joint model.
This PR involves new config files and modeling files only.
How to use
generate
Accuracy Tests
1:1 alignment with Diffusers.
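The alignment claim could be checked with a script along these lines. This is a sketch only: in a real test, `ours` and `reference` would be decoded outputs from the two pipelines run with the same prompt, seed, and scheduler settings, so stand-in tensors are used here.

```python
import torch

# Hypothetical parity check against the Diffusers reference implementation.
# Stand-in tensors replace the actual pipeline outputs.
torch.manual_seed(0)
ours = torch.randn(1, 3, 8, 64, 64)  # (batch, channels, frames, height, width)
reference = ours.clone()

max_abs_diff = (ours - reference).abs().max().item()
matches = torch.allclose(ours, reference, atol=1e-3)
```

Reporting `max_abs_diff` alongside the boolean makes tolerance regressions easier to triage than a bare pass/fail.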
Checklist
- LTX-2 `generate` and the `server`
- LTX-2 Stack PR
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`