fix: improve numerical stability by conditionally using float32 in Anima #2302
Merged
re-fix for #2293 and #2297
This pull request refactors how float32 precision is handled during forward passes in the `anima_models.py` model code. Instead of inferring precision inside each block, a `use_fp32` flag is now determined once in the main forward method and then explicitly passed down through all relevant forward calls and custom forward wrappers. This makes the precision logic clearer and more consistent, especially when dealing with float16 inputs for numerical stability.

Precision handling improvements:
- Added a `use_fp32` argument to the `forward`, `_forward`, and related methods in model blocks, allowing explicit control over whether float32 precision is used for computations (see the sketch after this list). [1] [2] [3]
- Block calls and custom forward wrappers now receive and pass through the `use_fp32` flag, ensuring consistent precision handling throughout the model. [1] [2] [3] [4]
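A minimal sketch of what the block-level change might look like; the `AttentionBlock` class and the choice of a float32 softmax as the upcast point are assumptions for illustration, not the actual Anima code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionBlock(nn.Module):
    """Illustrative only; the real Anima blocks differ in structure."""

    def forward(self, q, k, v, use_fp32: bool = False):
        # The caller decides the precision once and passes it in explicitly;
        # the block no longer inspects dtypes itself.
        scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
        if use_fp32:
            # Softmax in float32 avoids overflow/underflow with fp16 logits,
            # then cast back so downstream layers see the original dtype.
            attn = F.softmax(scores.float(), dim=-1).to(q.dtype)
        else:
            attn = F.softmax(scores, dim=-1)
        return attn @ v
```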
Model forward pass changes:

- In `forward_mini_train_dit`, the `use_fp32` flag is now set once based on the input tensor's dtype and passed to all block and final-layer calls, improving clarity and reducing duplicated logic.
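A hedged sketch of that flow, reusing the `AttentionBlock` from the previous snippet; `forward_mini_train_dit` is named in the PR, but `MiniDiT`, `FinalLayer`, and the overall structure are assumed here:

```python
import torch
import torch.nn as nn


class FinalLayer(nn.Module):
    """Illustrative final layer that upcasts when asked, like the blocks."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, use_fp32: bool = False) -> torch.Tensor:
        if use_fp32:
            # Normalize in float32, then return to the original dtype.
            return self.norm(x.float()).to(x.dtype)
        return self.norm(x)


class MiniDiT(nn.Module):
    """Toy stand-in for the model in anima_models.py; structure is assumed."""

    def __init__(self, dim: int = 64, depth: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([AttentionBlock() for _ in range(depth)])
        self.final_layer = FinalLayer(dim)

    def forward_mini_train_dit(self, x: torch.Tensor) -> torch.Tensor:
        # Decide precision once from the input dtype instead of
        # re-inferring it inside every block.
        use_fp32 = x.dtype == torch.float16

        for block in self.blocks:
            # Self-attention-style call; every block receives the same flag.
            x = block(x, x, x, use_fp32=use_fp32)

        # The final layer gets the same flag so the whole pass is consistent.
        return self.final_layer(x, use_fp32=use_fp32)
```

With this shape, calling `model.forward_mini_train_dit(x.half())` routes every block and the final layer through the float32 upcast path, while float32 inputs skip the extra casts entirely.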