feat: add fp16 safe patch option for training on older GPUs to prevent NaNs for Anima model #2274
ihatenumbers wants to merge 1 commit into kohya-ss:main
Thank you for this PR and for identifying the fp16 NaN issue on older GPUs! The core problem you've identified (residual stream overflow in fp16) is real and important.

After reviewing the approach, we realized there's a much simpler solution. When training with

The fix can be as simple as adding this to the beginning of

    if x_B_T_H_W_D.dtype == torch.float16:
        x_B_T_H_W_D = x_B_T_H_W_D.float()

This promotes the residual stream to fp32, preventing overflow. The sub-modules still run in fp16 thanks to the existing autocast context, and their outputs are automatically upcast when added back to the fp32 residual. No monkey-patching, no global settings, and no extra flag needed: it activates automatically when the input is fp16.

We'll close this PR and implement this fix on our side, but we really appreciate you bringing this issue to our attention. Your work and the reference to the fp16 patch on HuggingFace were very helpful in understanding the problem.
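The effect of that promotion can be illustrated in isolation. The following is a minimal sketch, not the repository's code: the helper `add_residual` and the tensor values are hypothetical, chosen only so that the sum exceeds the fp16 maximum of 65504. It relies on standard PyTorch type promotion, where adding an fp32 tensor and an fp16 tensor yields fp32.

```python
import torch


def add_residual(x: torch.Tensor, branch_out: torch.Tensor) -> torch.Tensor:
    """Add a sub-module output to the residual stream, keeping the stream in fp32.

    Hypothetical helper for illustration only.
    """
    # Same guard as in the comment above: promote an fp16 residual to fp32.
    if x.dtype == torch.float16:
        x = x.float()
    # fp32 + fp16 promotes to fp32, so the sum is not limited to the fp16 range.
    return x + branch_out


# Hypothetical values: each fits in fp16, but their sum (70000) does not.
residual = torch.full((4,), 60000.0, dtype=torch.float16)
branch_out = torch.full((4,), 10000.0, dtype=torch.float16)

naive = residual + branch_out              # result stays fp16 and overflows
safe = add_residual(residual, branch_out)  # result is fp32 and stays finite

print(naive.dtype, torch.isinf(naive).all().item())
print(safe.dtype, torch.isfinite(safe).all().item())
```

Once a residual value has overflowed to infinity, later operations such as `inf - inf` produce NaN, which is consistent with the NaN loss reported in this PR. Keeping only the residual stream in fp32 avoids that while the sub-module outputs remain fp16.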
I've merged #2277. If you'd like, I'd be happy if you could check that it works.
This fixes NaN loss on GPUs without bf16 support by adding `fp16_safe_patch`. Previously, training produced black images even with `full_fp16 = true` and/or `mixed_precision = "fp16"`. Just add `fp16_safe_patch = true` in config.toml. Confirmed that sampling during training no longer gives black images and the loss no longer becomes NaN.

This was inspired by https://huggingface.co/RicemanT/Loras_Collection/blob/main/anina_fp16_patch.py and written with the help of Gemini.