Add/modify some implementation for anima#2261
Conversation
|
I forgot to change the default value of discrete_flow_shift to 1.0 in the anima_train_network.md document. Would you mind updating it? Thanks! |
Of course, thank you for letting me know! |
|
I'd like to replace Anima's VAE code with Qwen Image VAE, licensed under ASL 2.0 from Diffusers. |
|
I would like to change the LoRA module selection to be regular expression based, as I think it would be more flexible. |
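A minimal sketch of what regex-based module selection could look like (the helper name `select_modules`, the patterns, and the module names below are all hypothetical illustrations, not the actual lora_anima.py API):

```python
import re

def select_modules(module_names, include_patterns, exclude_patterns=()):
    """Return module names matching any include regex and no exclude regex.

    Illustrative sketch of regex-based LoRA target selection; the real
    option names and matching rules in lora_anima.py may differ.
    """
    selected = []
    for name in module_names:
        if not any(re.search(p, name) for p in include_patterns):
            continue
        if any(re.search(p, name) for p in exclude_patterns):
            continue
        selected.append(name)
    return selected

# Hypothetical module names for demonstration
names = [
    "blocks.0.attn.q_proj",
    "blocks.0.attn.k_proj",
    "blocks.0.mlp.fc1",
    "llm_adapter.proj",
]
# Target only attention projections inside DiT blocks, excluding k_proj
print(select_modules(names, [r"blocks\.\d+\.attn\."], [r"k_proj"]))
# → ['blocks.0.attn.q_proj']
```

Regex selection subsumes the old fixed lists: a plain substring still works as a pattern, while anchors and alternation allow per-layer targeting without new options.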
|
I don't see the need to use fp32 for transformer weights for LoRA training, and there are other options for fp8. Also, finetuning defaults to fp32 and can optionally change to bf16/fp16, so I'll remove the transformer_dtype option. |
|
I will change it to use existing options such as […]. Also, it seems that caption dropout is effective in Anima. However, I think it will be confusing if […]. |
|
I'd like to use sd3_train_utils.get_noisy_model_input_and_timesteps. Also, replacing process_batch in AnimaNetworkTrainer involves a lot of overlapping functionality, so I plan to do something about it. |
|
I don't see any issues with making that replacement. Some of the functions in anima_train_utils were copied or adapted from sd3_train_utils in the first place, so using sd3_train_utils.get_noisy_model_input_and_timesteps should be fine as long as it fits your training setup. The overlapping functionality was intentional, to help avoid conflicts with other training models. If you plan to refactor process_batch to reduce duplication, that sounds like a good improvement. If you have any questions about my implementation, feel free to ask; I'm happy to help. |
|
The following processing is performed using LLMAdapter. When the uncond string "" is tokenized, all of Qwen3's attention masks become 0. Therefore, in the uncond case, no gradients flow to the LLMAdapter. It also seems that we do not need to cache the uncond embedding, because it is always all zeros. Please let me know if anything in my understanding is incorrect. |
|
Your understanding is incorrect. Here's why: tokenizing "" with Qwen3 does NOT produce an all-zero attention mask. Qwen3 has eos_token = "<|endoftext|>" (ID 151643) and no BOS, so tokenizing "" with padding="max_length", max_length=512 gives a mask of [1, 0, ..., 0].
Same for T5: eos_token_id=1, so mask is also [1, 0, ..., 0]. After LLM Adapter + zero-out:
Position 0 (EOS) keeps its value from the adapter. Only positions 1-511 become zero. Result: uncond crossattn_emb = [non_zero_EOS_embedding, 0, 0, ..., 0], not all zeros. The LLM Adapter still receives gradients through position 0, and the uncond embedding must be properly cached. |
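The zero-out mechanics described above can be checked with a small synthetic tensor experiment (random values standing in for the encoder output, a hand-built mask standing in for the tokenizer's; this does not run the real Qwen3/T5 models):

```python
import torch

# Synthetic stand-in for encoder output: [batch, seq, dim]
seq_len, dim = 8, 4
emb = torch.randn(1, seq_len, dim)

# Mask with only position 0 (EOS) attended, as claimed above
mask = torch.zeros(1, seq_len, dtype=torch.bool)
mask[:, 0] = True

# Zero out padding positions, mirroring nd_encoded_text[~nd_attn_mask.bool()] = 0
emb[~mask] = 0

print(bool(emb[0, 0].abs().sum() > 0))  # True: EOS position keeps its value
print(emb[0, 1:].abs().sum())           # tensor(0.): all other positions zeroed
```

Under this masking, anything flowing through position 0 still carries gradients, which is the crux of the disagreement about whether the uncond embedding is trainable and cacheable.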
|
Thank you! My simple test shows that T5 has a mask of length 1; that was my misunderstanding. However, all outputs of Qwen3 are masked and are therefore 0. So should we only cache T5's input_ids (…)? |
|
Oh, I misunderstood Qwen3 as well, I was just reading the config and making assumptions based on my knowledge. I tested it, and you're right. Your idea makes sense; I think we should cache only T5's input_ids. |
|
When caption dropout is enabled, llm_adapter seems to return NaN for empty string regardless of whether it is cached or not. This also happens during inference. Do you have any idea? I'm really confused... |
This seems to be a mathematical stability issue with scaled_dot_product_attention when mask is all 0. I will change it to skip SDPA when mask is 0. |
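A sketch of such a guard (the wrapper name `safe_sdpa` is hypothetical; this is not the actual patch): when a boolean mask disables every key, softmax runs over an empty set and yields NaN, so the wrapper short-circuits to zeros instead of calling SDPA.

```python
import torch
import torch.nn.functional as F

def safe_sdpa(q, k, v, attn_mask=None):
    """Fall back to zeros when a boolean mask disables every key.

    Illustrative sketch; the real fix in the Anima code may differ.
    """
    if attn_mask is not None and attn_mask.dtype == torch.bool and not attn_mask.any():
        # No attendable keys: softmax over an empty set would produce NaN
        return torch.zeros_like(q)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)
mask = torch.zeros(1, 1, 1, 16, dtype=torch.bool)  # all positions masked out
out = safe_sdpa(q, k, v, attn_mask=mask)
print(torch.isnan(out).any())  # tensor(False)
```

Returning zeros is consistent with the "zero embedding contributes nothing" design, since downstream layers treat the uncond context as all-zero anyway.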
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
I checked and noticed that this line was missed during the refactor.
Why this matters: the Anima architecture relies on a "zero-padding, no-mask" design:
- Qwen3 encodes the text using an attention mask, then the padding positions are explicitly zeroed out.
- The LLM Adapter processes the cleaned embeddings, and its padding is zeroed out as well.
- DiT cross-attention runs with attn_mask_type="no_mask"; no attention mask is required because zero embeddings contribute nothing to attention (K = 0 → score ≈ 0, V = 0 → contribution = 0).
When I compared my old version with your refactored one, this is what I got after running the encoder:
--- Text: '' ---
Qwen3 mask sum: 0, T5 mask sum: 1
OLD: padding non-zero = 0
NEW: padding non-zero = 524288
--- Text: 'a cat on a mat' ---
Qwen3 mask sum: 5, T5 mask sum: 8
OLD: padding non-zero = 0
NEW: padding non-zero = 518833
Without this, garbage values in the padding positions propagate through the entire pipeline.
Thank you, I'll check it as soon as I have time.
Sorry, I accidentally deleted this sentence: nd_encoded_text[~nd_attn_mask.bool()] = 0.
However, even after adding this line it still returns NaN, as tested below:
print(f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}")
print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
context = self.net.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
return context
->
Source hidden states shape: torch.Size([1, 512, 1024]),sum of attention mask: 0
non zero source_hidden_states before LLM Adapter: 0
LLM Adapter output context: torch.Size([1, 512, 1024]), 1024
Even a simple SDPA test returns NaN (even with float32).
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 32, dtype=torch.bfloat16)
k = torch.randn(1, 4, 16, 32, dtype=torch.bfloat16)
v = torch.randn(1, 4, 16, 32, dtype=torch.bfloat16)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=torch.zeros(1, 1, 1, 16, dtype=torch.bool))
print(torch.isnan(x).sum())
tensor(2048)
I don't understand why the diffusion-pipe is working.
@kohya-ss I think you might be using an older version of PyTorch. Based on my research, PyTorch versions earlier than 2.5 can cause this issue. Here is the reference: pytorch/pytorch#103749, and a proposed fix is available here: pytorch/pytorch#133882.
I also tested this in my environment using PyTorch 2.9.0, and here are the results:

Thank you, that makes sense!
Updating the dependencies is easy, but it might be confusing for users, so I'd like to address this in a way that skips cross attention.
I think we should add a check for the PyTorch version. If users are on a version earlier than 2.5, it affects how cross-attention behaves. I'm not entirely sure whether this applies when cross-attention is not used, but any misunderstanding here could lead to poorer model training.
Thanks for your comment. You're right.
I checked and found that the minimum recommended version of PyTorch was recently changed to 2.6 after merging the sd3 branch into main. (I was testing in an older environment.)
So it seems like it would be enough to just write in the documentation asking users to migrate to 2.6 or later.
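For reference, a minimal sketch of a startup warning (the helper is hypothetical, not code from this PR), assuming the `packaging` library is available for version parsing:

```python
import torch
from packaging import version

MIN_TORCH = "2.6"  # minimum recommended for the all-zero-mask SDPA fix

def check_torch_version():
    # Warn rather than fail, so existing setups keep working;
    # strip any local suffix like "+cu121" before parsing
    installed = version.parse(torch.__version__.split("+")[0])
    if installed < version.parse(MIN_TORCH):
        print(
            f"Warning: PyTorch {torch.__version__} detected. "
            f"Versions before {MIN_TORCH} may return NaN from "
            "scaled_dot_product_attention with an all-zero attention mask; "
            f"please upgrade to PyTorch >= {MIN_TORCH}."
        )

check_torch_version()
```

A documentation note plus a soft warning like this covers both new installs and users who missed the changed minimum version.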
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
flow_shift: float = 3.0,
I think we should change flow_shift to 1.0 to avoid confusion, since the default value for discrete_flow_shift during training is already set to 1.0.
That's certainly true. However, the flow_shift during training and the flow_shift during inference do not necessarily need to match, and especially if there are few inference steps, quality will decrease with flow_shift=1. Considering the user's intended use, I think it would be better to set a somewhat higher value as the default.
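For context, a sketch of the discrete flow shift mapping as used in SD3-style schedules (assuming Anima follows the same formula; `shift_sigma` is an illustrative name):

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """SD3-style discrete flow shift: remaps noise level sigma in [0, 1].

    shift = 1.0 leaves the schedule unchanged; larger values keep sigma
    high for longer, which helps quality when the step count is small.
    """
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)

print(shift_sigma(0.5, 1.0))  # 0.5 (identity schedule)
print(shift_sigma(0.5, 3.0))  # 0.75 (mid-schedule pushed toward the noisy end)
```

This is why a training default of 1.0 and a higher inference default can coexist: training sees all timesteps anyway, while few-step inference benefits from spending more of its budget at high noise.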
It's almost done, so I think I'll merge it tomorrow. |
Pull request overview
This PR updates the Anima training/inference stack to improve compatibility (notably swapping WanVAE_ usage for the Qwen Image VAE), expands safetensors loading utilities for large models, and adjusts Anima-specific caching/attention/training plumbing.
Changes:
- Replace WanVAE_-based latent handling with Qwen Image VAE APIs across Anima training and caching strategies.
- Add Anima metadata support to SAI model spec + update attention path to use unified library.attention.
- Extend memory-efficient safetensors loading with optional numpy memmap disablement and weight transform hooks (split/concat/rename), and integrate into FP8 + LoRA loading.
Reviewed changes
Copilot reviewed 22 out of 23 changed files in this pull request and generated 26 comments.
| File | Description |
|---|---|
| train_network.py | Makes loss reduction dimension-agnostic (mean over all dims except batch). |
| tests/manual_test_anima_real_training.py | Adds a manual end-to-end training smoke test runner for cache TE outputs. |
| tests/manual_test_anima_cache.py | Updates manual caching diagnostics for Qwen Image VAE + formatting changes. |
| networks/lora_anima.py | Refactors Anima LoRA module selection to regex include/exclude + regex LR/dim overrides. |
| library/train_util.py | Threads caption dropout rate into ImageInfo, adjusts TE cacheability checks, adds anima to model spec metadata builder. |
| library/strategy_anima.py | Removes TE-side dropout, adds cached-output dropout based on per-sample rates, switches latent caching to Qwen Image VAE. |
| library/sai_model_spec.py | Adds Anima architecture/implementation identifiers and default resolution behavior. |
| library/safetensors_utils.py | Adds disable_numpy_memmap, split-weight filename helper, and weight transform adapter/hooks. |
| library/lora_utils.py | Integrates split weights helper + weight transform hooks; supports optional numpy memmap disablement. |
| library/fp8_optimization_utils.py | Integrates transform hooks + numpy memmap toggle; adds guard against applying FP8 scaling to already-FP8 weights. |
| library/flux_train_utils.py | Makes timestep/noise input utilities tolerant to latent dimensionality changes. |
| library/custom_offloading_utils.py | Adds synchronization/waits before switching forward-only mode and before forward device preparation. |
| library/attention.py | Adds small capability properties on AttentionParams and docstring parameter name fix. |
| library/anima_vae.py | Removes WanVAE_ implementation file from the repo. |
| library/anima_utils.py | Reworks model loading to use dynamic fp8/LoRA merge pipeline; updates save path metadata handling. |
| library/anima_train_utils.py | Updates sampling/inference flow, attention args, VAE decode path, and metadata emission for Anima. |
| library/anima_models.py | Switches DiT attention impl to unified library.attention, renames main model to Anima, adds block-swap inference/training switching. |
| anima_train_network.py | Updates Anima LoRA training to new model loader, Qwen VAE, caching dropout behavior, and scheduler utils. |
| anima_train.py | Updates full finetune training for Qwen VAE and new noise/timestep utilities; removes blockwise fused optimizer path. |
| anima_minimal_inference.py | Adds a large minimal inference script for Anima with batch/interactive modes, LoRA support, and VAE decode. |
| _typos.toml | Updates typos config entries and excludes configs from typo scanning. |
Comments suppressed due to low confidence (4)
- tests/manual_test_anima_cache.py:90 - `qwen_image_autoencoder_kl.load_vae` does not accept a `dtype` argument (it returns a model you can `.to(dtype)` afterward). This call will raise `TypeError`. Consider removing the `dtype=` kwarg and casting the returned VAE via `.to(vae_dtype)`.
- tests/manual_test_anima_cache.py:89 - Keyword argument 'dtype' is not a supported parameter name of function load_vae.
- tests/manual_test_anima_cache.py:350 - Keyword argument 'dropout_rate' is not a supported parameter name of AnimaTextEncodingStrategy.__init__.
- tests/manual_test_anima_cache.py:396 - Keyword argument 'dropout_rate' is not a supported parameter name of AnimaTextEncodingStrategy.__init__.
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
This expression mutates a default value.
Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…(WIP, not tested yet)
…der loading function
…ule name adjustments
|
Everything looks good so far. I reviewed all your changes and tested them, and I didn’t find any potential bugs. |
I'm glad to hear that! Thank you again for your great contribution! |