
Add/modify some implementation for anima#2261

Merged
kohya-ss merged 49 commits into sd3 from feat-anima-polish on Feb 12, 2026

Conversation

@kohya-ss
Owner

@kohya-ss kohya-ss commented Feb 8, 2026

  • Fix _typos.toml
  • Exclude anima test case from pytest
  • Replace WanVAE_ with Qwen Image VAE for compatibility with Musubi Tuner and clarify license.

@duongve13112002
Contributor

I forgot to change the default value of discrete_flow_shift to 1.0 in the anima_train_network.md document. Would you mind updating it? Thanks!

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

> I forgot to change the default value of discrete_flow_shift to 1.0 in the anima_train_network.md document. Would you mind updating it? Thanks!

Of course, thank you for letting me know!

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

I'd like to replace Anima's VAE code with Qwen Image VAE, licensed under ASL 2.0 from Diffusers.

@kohya-ss kohya-ss marked this pull request as draft February 8, 2026 03:17
@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

I would like to change the LoRA module selection to be regular expression based, as I think it would be more flexible.
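As a rough illustration of the regex-based selection idea (the helper name and patterns here are hypothetical, not the actual lora_anima.py API):

```python
import re

def select_lora_modules(module_names, include_pattern, exclude_pattern=None):
    """Return the module names matched by the include regex and not
    matched by the optional exclude regex."""
    include_re = re.compile(include_pattern)
    exclude_re = re.compile(exclude_pattern) if exclude_pattern else None
    selected = []
    for name in module_names:
        if not include_re.search(name):
            continue
        if exclude_re and exclude_re.search(name):
            continue
        selected.append(name)
    return selected

names = ["blocks.0.attn.q_proj", "blocks.0.mlp.fc1", "llm_adapter.proj"]
print(select_lora_modules(names, r"blocks\.\d+\.", r"mlp"))
# ['blocks.0.attn.q_proj']
```

Compared with fixed prefix lists, a regex lets users target arbitrary subsets (e.g. only attention projections in certain blocks) without code changes.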

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

I don't see the need to use fp32 for transformer weights for LoRA training, and there are other options for fp8. Also, finetuning defaults to fp32 and can optionally change to bf16/fp16, so I'll remove the transformer_dtype option.

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

I will change it to use existing options such as --pretrained_model_name_or_path.

Also, it seems that caption dropout is effective in Anima. However, I think it will be confusing if caption_dropout_rate of the dataset is ignored, so I will consider some way to address this.
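One way to honor the dataset's caption_dropout_rate is to drop captions per sample before (or instead of) encoding; a minimal sketch, with illustrative names only:

```python
import random

def apply_caption_dropout(captions, rates, rng=None):
    """Replace each caption with the empty string with its per-sample
    dropout probability. 'rates' carries each sample's
    caption_dropout_rate; names here are illustrative."""
    rng = rng or random.Random()
    return ["" if rng.random() < rate else caption
            for caption, rate in zip(captions, rates)]

# rate 0.0 never drops, rate 1.0 always drops
print(apply_caption_dropout(["a cat", "a dog"], [0.0, 1.0]))
# ['a cat', '']
```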

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

I'd like to use sd3_train_utils.get_noisy_model_input_and_timesteps instead of anima_train_utils.get_noisy_model_input_and_timesteps to support min/max timesteps etc. Please let me know if there are any issues with replacing it.

Also, replacing process_batch in AnimaNetworkTrainer involves a lot of overlapping functionality, so I plan to do something about that as well.
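The min/max timestep support mentioned above amounts to restricting the sampled range; a sketch of the idea only, not the actual sd3_train_utils signature:

```python
import random

def sample_timesteps(batch_size, num_timesteps=1000,
                     min_timestep=None, max_timestep=None):
    """Uniformly sample integer timesteps restricted to
    [min_timestep, max_timestep); defaults cover the full schedule."""
    lo = 0 if min_timestep is None else min_timestep
    hi = num_timesteps if max_timestep is None else max_timestep
    return [random.randrange(lo, hi) for _ in range(batch_size)]

t = sample_timesteps(8, min_timestep=200, max_timestep=800)
```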

@duongve13112002
Contributor

I don’t see any issues with making that replacement. Some of the functions in anima_train_utils were copied or adapted from sd3_train_utils in the first place, so using sd3_train_utils.get_noisy_model_input_and_timesteps should be fine as long as it fits your training setup.

The overlapping functionality was intentional to help avoid conflicts with other training models. If you plan to refactor process_batch to reduce duplication, that sounds like a good improvement.

If you have any questions about my implementation, feel free to ask. I'm happy to help.

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

The following processing is performed using LLMAdapter.

  crossattn_emb = self.llm_adapter(
      source_hidden_states=crossattn_emb,
      target_input_ids=t5_input_ids,
      target_attention_mask=t5_attn_mask,
      source_attention_mask=source_attention_mask,
  )
  if t5_attn_mask is not None:
      crossattn_emb[~t5_attn_mask.bool()] = 0

When the uncond string "" is tokenized, all of Qwen3's attention masks become 0. So, crossattn_emb becomes all 0.

Therefore, in the uncond case no gradients flow to the LLMAdapter. It also seems we do not need to cache the uncond embedding, because it is always all zeros.

Please let me know if there is anything in my understanding that is incorrect.

@duongve13112002
Contributor

Your understanding is incorrect. Here's why:

Qwen3 tokenizing "" does NOT produce all-zero attention mask.

Qwen3 has eos_token = "<|endoftext|>" (ID 151643), no BOS. So tokenizing "" with padding="max_length", max_length=512 gives:

  • tokens: [151643, 0, 0, ..., 0] (1 EOS + 511 pad)
  • attn_mask: [1, 0, 0, ..., 0] — one position is 1

Same for T5: eos_token_id=1, so mask is also [1, 0, ..., 0].

After LLM Adapter + zero-out:

crossattn_emb[~t5_attn_mask.bool()] = 0

Position 0 (EOS) keeps its value from the adapter. Only positions 1-511 become zero.

Result: uncond crossattn_emb = [non_zero_EOS_embedding, 0, 0, ..., 0], not all zeros. The LLM Adapter still receives gradients through position 0, and the uncond embedding must be properly cached.

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

Thank you!

My simple test returns Text: [''], Qwen3 mask sum: 0, T5 mask sum: 1.

        # Tokenize with Qwen3
        qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
            text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
        )
        qwen3_input_ids = qwen3_encoding["input_ids"]
        qwen3_attn_mask = qwen3_encoding["attention_mask"]

        # Tokenize with T5 (for LLM Adapter target tokens)
        t5_encoding = self.t5_tokenizer.batch_encode_plus(
            text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
        )
        t5_input_ids = t5_encoding["input_ids"]
        t5_attn_mask = t5_encoding["attention_mask"]

        print(f"Text: {text}, Qwen3 mask sum: {qwen3_attn_mask.sum().item()}, T5 mask sum: {t5_attn_mask.sum().item()}")

T5's mask sums to 1, so that part was my misunderstanding.

However, all Qwen3 outputs are masked and therefore become 0. So should we cache only T5's input_ids (1, 0, 0, ...)?

@duongve13112002
Contributor

Oh, I misunderstood Qwen3 as well, I was just reading the config and making assumptions based on my knowledge. I tested it, and you're right. Your idea makes sense; I think we should cache only T5's input_ids.

@kohya-ss
Owner Author

kohya-ss commented Feb 8, 2026

@duongve13112002

When caption dropout is enabled, llm_adapter seems to return NaN for empty string regardless of whether it is cached or not. This also happens during inference.

Do you have any idea? I'm really confused...

@kohya-ss
Owner Author

kohya-ss commented Feb 9, 2026

> When caption dropout is enabled, llm_adapter seems to return NaN for empty string regardless of whether it is cached or not. This also happens during inference.
>
> Do you have any idea? I'm really confused...

This seems to be a mathematical stability issue with scaled_dot_product_attention when mask is all 0. I will change it to skip SDPA when mask is 0.

outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
Contributor

@duongve13112002 duongve13112002 Feb 9, 2026


I checked and noticed that this line was missed during the refactor.

Why this matters: the Anima architecture relies on a "zero-padding, no-mask" design:

  1. Qwen3 encodes the text using an attention mask, then the padding positions are explicitly zeroed out.
  2. The LLM Adapter processes the cleaned embeddings, and its padding is zeroed out as well.
  3. DiT cross-attention runs with attn_mask_type="no_mask" — no attention mask is required because zero embeddings contribute nothing to attention
    (K = 0 → score ≈ 0, V = 0 → contribution = 0).
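The claim in step 3 can be checked numerically. A small NumPy sketch (shapes illustrative): the zeroed "padding" positions still receive softmax weight (their scores are exactly 0), but because V = 0 there, they contribute nothing to the output.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d = 8
q = rng.normal(size=(1, d))
k = rng.normal(size=(4, d))
v = rng.normal(size=(4, d))
k[2:] = 0  # "padding" keys zeroed out
v[2:] = 0  # "padding" values zeroed out

weights = softmax(q @ k.T / np.sqrt(d))
out = weights @ v
# output equals the weighted sum over the valid positions alone
assert np.allclose(out, weights[:, :2] @ v[:2])
```

Note the padding positions still soak up some softmax mass, so this matches properly masked attention only up to that renormalization; that trade-off is inherent to the no-mask design.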

When I compared my old version with your refactored one, this is what I got after running the encoder:

--- Text: '' ---
Qwen3 mask sum: 0, T5 mask sum: 1
OLD: padding non-zero = 0
NEW: padding non-zero = 524288

--- Text: 'a cat on a mat' ---
Qwen3 mask sum: 5, T5 mask sum: 8
OLD: padding non-zero = 0
NEW: padding non-zero = 518833

Without this, garbage values in the padding positions propagate through the entire pipeline.

Owner Author


Thank you, I'll check it as soon as I have time.

Owner Author


Sorry, I accidentally deleted this sentence: nd_encoded_text[~nd_attn_mask.bool()] = 0.

However, even after adding this line it still returns NaN, as tested below:

            print(f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}")
            print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
            context = self.net.llm_adapter(
                source_hidden_states,
                target_input_ids,
                target_attention_mask=target_attention_mask,
                source_attention_mask=source_attention_mask,
            )
            context[~target_attention_mask.bool()] = 0  # zero out padding tokens
            print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
            return context

->

Source hidden states shape: torch.Size([1, 512, 1024]),sum of attention mask: 0
non zero source_hidden_states before LLM Adapter: 0
LLM Adapter output context: torch.Size([1, 512, 1024]), 1024

Even a simple SDPA test returns NaN (even with float32).

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 4, 16, 32, dtype=torch.bfloat16)
    k = torch.randn(1, 4, 16, 32, dtype=torch.bfloat16)
    v = torch.randn(1, 4, 16, 32, dtype=torch.bfloat16)
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=torch.zeros(1, 1, 1, 16, dtype=torch.bool))
    print(torch.isnan(x).sum())
    # tensor(2048)
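The proposed workaround of skipping SDPA when the mask is all zero might look like the following sketch (function name and the zeros shape assumption are illustrative, not the actual repository code):

```python
import torch
import torch.nn.functional as F

def sdpa_skip_empty_mask(q, k, v, attn_mask=None):
    """When a boolean mask is all False, every query row is fully
    masked and SDPA yields NaN on affected PyTorch versions, so
    return zeros instead of calling SDPA."""
    if attn_mask is not None and attn_mask.dtype == torch.bool and not attn_mask.any():
        # assumes q and v share the last (feature) dimension here
        return torch.zeros_like(q)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)
x = sdpa_skip_empty_mask(q, k, v, attn_mask=torch.zeros(1, 1, 1, 16, dtype=torch.bool))
assert not torch.isnan(x).any() and (x == 0).all()
```

Returning zeros is consistent with the zero-padding design above: the fully masked uncond context would be zeroed out afterward anyway.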

I don't understand why diffusion-pipe works, though.

Contributor


@kohya-ss I think you might be using an older version of PyTorch. Based on my research, PyTorch versions earlier than 2.5 can cause this issue. Here is the reference: pytorch/pytorch#103749, and a proposed fix is available here: pytorch/pytorch#133882.

I also tested this in my environment using PyTorch 2.9.0, and here are the results:

[screenshot of test results]

Owner Author


Thank you, that makes sense!

Updating the dependencies is easy, but it might be confusing for users, so I'd like to address this in a way that skips cross attention.

Contributor

@duongve13112002 duongve13112002 Feb 9, 2026


I think we should add a check for the PyTorch version, since whether the user is on >= 2.5 affects how cross-attention behaves. I'm not entirely sure whether this applies when cross-attention is not used, but any misunderstanding here could lead to poorer model training.

Owner Author


Thanks for your comment. You're right.

I checked and found that the minimum recommended version of PyTorch was recently changed to 2.6 after merging the sd3 branch into main. (I was testing in an older environment.)

So it seems like it would be enough to just write in the documentation asking users to migrate to 2.6 or later.

dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
flow_shift: float = 3.0,
Contributor

@duongve13112002 duongve13112002 Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should change flow_shift to 1.0 to avoid confusion, since the default value for discrete_flow_shift during training is already set to 1.0.

Owner Author


That's certainly true. However, the flow_shift during training and the flow_shift during inference do not necessarily need to match, and especially if there are few inference steps, quality will decrease with flow_shift=1. Considering the user's intended use, I think it would be better to set a somewhat higher value as the default.
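For context, the SD3-style discrete flow shift can be written as a sigma remapping; whether Anima's scheduler uses exactly this form is an assumption, but it shows why a higher shift helps at few steps:

```python
def shift_sigma(sigma, shift):
    """SD3-style discrete flow shift: shift = 1.0 is the identity;
    larger shifts push sigmas toward 1.0 (the high-noise region),
    spending more of a short sampling schedule there."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)

assert shift_sigma(0.5, 1.0) == 0.5               # identity at shift = 1
assert abs(shift_sigma(0.5, 3.0) - 0.75) < 1e-9   # pushed toward the noise end
```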

@kohya-ss
Owner Author

> @kohya-ss Hi, I wonder when you will finish refactoring the code. I am looking forward to using it.

It's almost done so I think I'll merge it tomorrow.

@kohya-ss kohya-ss marked this pull request as ready for review February 11, 2026 13:07
@kohya-ss kohya-ss requested a review from Copilot February 11, 2026 13:07
Contributor

Copilot AI left a comment


Pull request overview

This PR updates the Anima training/inference stack to improve compatibility (notably swapping WanVAE_ usage for the Qwen Image VAE), expands safetensors loading utilities for large models, and adjusts Anima-specific caching/attention/training plumbing.

Changes:

  • Replace WanVAE_-based latent handling with Qwen Image VAE APIs across Anima training and caching strategies.
  • Add Anima metadata support to SAI model spec + update attention path to use unified library.attention.
  • Extend memory-efficient safetensors loading with optional numpy memmap disablement and weight transform hooks (split/concat/rename), and integrate into FP8 + LoRA loading.
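The weight-transform hook idea in the last bullet can be sketched roughly as follows; the hook protocol and names here are illustrative, not the actual safetensors_utils API:

```python
# Hypothetical sketch: a hook maps one (key, value) entry to a list of
# entries, so it can rename, split one tensor into several, or drop it.
def rename_hook(key, value):
    return [(key.replace("model.diffusion_model.", ""), value)]

def apply_weight_hooks(state_dict, hooks):
    """Run every entry of a (streamed) state dict through a hook chain."""
    out = {}
    for key, value in state_dict.items():
        entries = [(key, value)]
        for hook in hooks:
            entries = [pair for k, v in entries for pair in hook(k, v)]
        out.update(entries)
    return out

sd = {"model.diffusion_model.blocks.0.weight": 1, "vae.weight": 2}
print(apply_weight_hooks(sd, [rename_hook]))
# {'blocks.0.weight': 1, 'vae.weight': 2}
```

Applying such hooks per entry while streaming keeps peak memory at roughly one tensor, which is the point of the memory-efficient loader.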

Reviewed changes

Copilot reviewed 22 out of 23 changed files in this pull request and generated 26 comments.

Summary per file:

  • train_network.py: Makes loss reduction dimension-agnostic (mean over all dims except batch).
  • tests/manual_test_anima_real_training.py: Adds a manual end-to-end training smoke test runner for cache TE outputs.
  • tests/manual_test_anima_cache.py: Updates manual caching diagnostics for Qwen Image VAE + formatting changes.
  • networks/lora_anima.py: Refactors Anima LoRA module selection to regex include/exclude + regex LR/dim overrides.
  • library/train_util.py: Threads caption dropout rate into ImageInfo, adjusts TE cacheability checks, adds anima to model spec metadata builder.
  • library/strategy_anima.py: Removes TE-side dropout, adds cached-output dropout based on per-sample rates, switches latent caching to Qwen Image VAE.
  • library/sai_model_spec.py: Adds Anima architecture/implementation identifiers and default resolution behavior.
  • library/safetensors_utils.py: Adds disable_numpy_memmap, split-weight filename helper, and weight transform adapter/hooks.
  • library/lora_utils.py: Integrates split weights helper + weight transform hooks; supports optional numpy memmap disablement.
  • library/fp8_optimization_utils.py: Integrates transform hooks + numpy memmap toggle; adds guard against applying FP8 scaling to already-FP8 weights.
  • library/flux_train_utils.py: Makes timestep/noise input utilities tolerant to latent dimensionality changes.
  • library/custom_offloading_utils.py: Adds synchronization/waits before switching forward-only mode and before forward device preparation.
  • library/attention.py: Adds small capability properties on AttentionParams and docstring parameter name fix.
  • library/anima_vae.py: Removes WanVAE_ implementation file from the repo.
  • library/anima_utils.py: Reworks model loading to use dynamic fp8/LoRA merge pipeline; updates save path metadata handling.
  • library/anima_train_utils.py: Updates sampling/inference flow, attention args, VAE decode path, and metadata emission for Anima.
  • library/anima_models.py: Switches DiT attention impl to unified library.attention, renames main model to Anima, adds block-swap inference/training switching.
  • anima_train_network.py: Updates Anima LoRA training to new model loader, Qwen VAE, caching dropout behavior, and scheduler utils.
  • anima_train.py: Updates full finetune training for Qwen VAE and new noise/timestep utilities; removes blockwise fused optimizer path.
  • anima_minimal_inference.py: Adds a large minimal inference script for Anima with batch/interactive modes, LoRA support, and VAE decode.
  • _typos.toml: Updates typos config entries and excludes configs from typo scanning.
Comments suppressed due to low confidence (4)

  • tests/manual_test_anima_cache.py:90: qwen_image_autoencoder_kl.load_vae does not accept a dtype argument (it returns a model you can .to(dtype) afterward). This call will raise TypeError. Consider removing the dtype= kwarg and casting the returned VAE via .to(vae_dtype).
  • tests/manual_test_anima_cache.py:89: Keyword argument 'dtype' is not a supported parameter name of function load_vae.
  • tests/manual_test_anima_cache.py:350: Keyword argument 'dropout_rate' is not a supported parameter name of AnimaTextEncodingStrategy.__init__.
  • tests/manual_test_anima_cache.py:396: Keyword argument 'dropout_rate' is not a supported parameter name of AnimaTextEncodingStrategy.__init__.

cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1

Copilot AI Feb 11, 2026


This expression mutates a default value.
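The pitfall being flagged is the classic shared mutable default: a default argument like `feat_cache=[]` is evaluated once at function definition time, so writing into it leaks state across calls. A minimal illustration, with hypothetical names:

```python
def cache_frame(x, feat_cache=[]):      # buggy: the same list persists
    feat_cache.append(x)
    return feat_cache

def cache_frame_fixed(x, feat_cache=None):
    if feat_cache is None:              # fresh list on every call
        feat_cache = []
    feat_cache.append(x)
    return feat_cache

cache_frame(1)
assert cache_frame(2) == [1, 2]         # state leaked between calls
assert cache_frame_fixed(2) == [2]      # no leak
```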

kohya-ss and others added 12 commits February 11, 2026 22:28
Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@duongve13112002
Contributor

Everything looks good so far. I reviewed all your changes and tested them, and I didn’t find any potential bugs.

@kohya-ss
Owner Author

> Everything looks good so far. I reviewed all your changes and tested them, and I didn't find any potential bugs.

I'm glad to hear that! Thank you again for your great contribution!

@kohya-ss kohya-ss merged commit 34e7138 into sd3 Feb 12, 2026
3 checks passed
@kohya-ss kohya-ss deleted the feat-anima-polish branch February 12, 2026 23:15
@kohya-ss kohya-ss mentioned this pull request Feb 12, 2026
