
fix: improve numerical stability by conditionally using float32 in Anima #2302

Merged

kohya-ss merged 2 commits into sd3 from fix-anima-fp16-nan-issue on Apr 2, 2026

Conversation

@kohya-ss (Owner) commented Apr 2, 2026

re-fix for #2293 and #2297


This pull request refactors how float32 precision is handled during forward passes in anima_models.py. Instead of inferring precision inside each block, a use_fp32 flag is now determined once in the main forward method and passed explicitly through all relevant forward calls and custom forward wrappers. This makes the precision logic clearer and more consistent, especially for float16 inputs, which need selective upcasting for numerical stability.

Precision handling improvements:

  • Added a use_fp32 argument to the forward, _forward, and related methods in the model blocks, allowing explicit control over whether float32 precision is used for computations.
  • Updated all calls to block forward methods and custom forward wrappers to pass the use_fp32 flag, ensuring consistent precision handling throughout the model.
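The flag-passing pattern described above could be sketched roughly as follows. This is a minimal illustration with hypothetical names (AnimaBlock, its _forward helper), not the actual sd-scripts code:

```python
import torch
import torch.nn as nn

class AnimaBlock(nn.Module):
    """Illustrative block: precision is controlled by an explicit flag,
    not inferred inside the block itself."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, use_fp32: bool = False) -> torch.Tensor:
        # The flag is threaded through to the inner forward unchanged.
        return self._forward(x, use_fp32=use_fp32)

    def _forward(self, x: torch.Tensor, use_fp32: bool) -> torch.Tensor:
        if use_fp32:
            # Upcast for the numerically sensitive computation,
            # then cast back to the caller's dtype.
            orig_dtype = x.dtype
            return self.linear(x.float()).to(orig_dtype)
        return self.linear(x)
```

Because the caller decides once, every block sees the same flag and the output dtype stays whatever the caller passed in.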

Model forward pass changes:

  • In forward_mini_train_dit, the use_fp32 flag is now set once based on the input tensor's dtype and passed to all block and final layer calls, improving clarity and reducing duplicated logic.
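That single-decision pattern might look like the following sketch, with hypothetical block and final-layer callables (not the actual forward_mini_train_dit implementation):

```python
import torch

def forward_mini_train_dit(x: torch.Tensor, blocks, final_layer):
    # Decide precision once, based on the input dtype, instead of
    # re-inferring it inside every block.  float16 inputs are the ones
    # prone to overflow/NaN, so only they request float32 compute.
    use_fp32 = x.dtype == torch.float16
    for block in blocks:
        x = block(x, use_fp32=use_fp32)
    # The final layer receives the same flag as the blocks.
    return final_layer(x, use_fp32=use_fp32)
```

Hoisting the check out of the per-block code removes the duplicated dtype inspection and guarantees every call site agrees on the precision for a given input.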

@kohya-ss kohya-ss merged commit fa53f71 into sd3 Apr 2, 2026
3 checks passed
@kohya-ss kohya-ss deleted the fix-anima-fp16-nan-issue branch April 2, 2026 03:36
