
Fix NaN in inpainting sample generation under fp16 autocast #2322

Merged
kohya-ss merged 1 commit into dev from fix/inpainting-vae-dtype-lpw on May 6, 2026

Conversation


@kohya-ss (Owner) commented on May 6, 2026

Summary

Bug fix for inpainting sample generation, following PR #2309. With --mixed_precision fp16 (with or without --no_half_vae), SDXL training-time inpainting sample generation produces NaN latents on the very first sample, leading to `RuntimeWarning: invalid value encountered in cast` and a black image. With bf16, sampling works correctly.
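Where the warning comes from, as a minimal sketch (illustrative only, not code from this PR): once the latents are NaN, the final cast of the decoded image to uint8 is invalid, numpy emits the warning, and the saved image comes out black.

```python
import numpy as np

# Stand-in for NaN latents/decoded pixels; shapes are arbitrary.
pixels = np.full((64, 64, 3), np.nan, dtype=np.float32)

# RuntimeWarning: invalid value encountered in cast
image = (pixels * 255).astype(np.uint8)

print(image.max())  # typically 0 on common platforms -> black image
```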

Two related issues in the lpw inpaint encode path, both surfacing under the accelerator.autocast() block in train_util.sample_image_inference:

  1. Wrong dtype for VAE inputs. Inpaint inputs were cast to dtype = unet.dtype (the UNet dtype), not self.vae.dtype. Under --no_half_vae (fp32 VAE, fp16 UNet) this fed mismatched-precision tensors into self.vae.encode. The fix mirrors the existing pattern in decode_latents(), which already does self.vae.decode(latents.to(self.vae.dtype)).
  2. Autocast forcing fp16 conv kernels. Even after fixing (1), fp16 autocast forces conv kernels to run in fp16 regardless of input/weight dtype, and the SDXL VAE produces NaN in fp16 (see the sketch after this list). The existing decode_latents() happens to be called outside the caller's accelerator.autocast() block, so it was unaffected; the inpaint encode, however, runs inside pipeline() and hit this.
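A minimal sketch of point (2), independent of this repository: under fp16 autocast, autocast-eligible ops such as convolutions run in fp16 even when both the weights and the input are fp32, so keeping the VAE in fp32 is not sufficient on its own.

```python
import torch

conv = torch.nn.Conv2d(3, 8, 3).cuda()        # fp32 weights
x = torch.randn(1, 3, 64, 64, device="cuda")  # fp32 input

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = conv(x)  # autocast runs the conv in fp16 anyway

print(conv.weight.dtype, x.dtype, y.dtype)
# torch.float32 torch.float32 torch.float16
```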

The fix wraps the inpaint encode in torch.autocast(device_type=device.type, enabled=False) so the encode runs in the VAE's actual precision.
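A hedged sketch of the resulting pattern; the tensor name init_image and the surrounding code are illustrative rather than the exact diff, but the two ingredients are the ones described above: cast to self.vae.dtype and disable autocast around the encode.

```python
# Sketch, not the exact diff. Casting to self.vae.dtype fixes (1);
# disabling autocast fixes (2) by letting the encode run in the VAE's
# actual precision even inside the caller's autocast block.
device = self.vae.device
with torch.autocast(device_type=device.type, enabled=False):
    init_image = init_image.to(device=device, dtype=self.vae.dtype)
    init_latents = self.vae.encode(init_image).latent_dist.sample()
```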

The same change is applied to lpw_stable_diffusion.py (SD1.5). The SD1.5 VAE is more fp16-tolerant, so the bug was latent there, but the path is structurally identical and worth keeping consistent.

Test plan

Verified by @kohya-ss:

  • SDXL + --mixed_precision fp16 --no_half_vae — was NaN, now produces valid samples
  • SDXL + --mixed_precision bf16 — no regression
  • SD1.5 — no regression
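For reproduction, an illustrative invocation along the lines the test plan implies; the script name and the sampling flags are assumptions based on the repository's usual CLI, not quoted from this PR:

```
accelerate launch sdxl_train_network.py \
  --mixed_precision fp16 --no_half_vae \
  --sample_prompts prompts.txt --sample_every_n_steps 100 \
  ...  # dataset/model arguments elided
```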

Notes

🤖 Generated with Claude Code

Commit message

The lpw inpaint encode path had two related dtype issues that surfaced
when training-time sample generation hit the VAE inside the
accelerator.autocast() block in train_util.sample_image_inference:

1. Inputs were cast to the UNet dtype (`dtype = unet.dtype`) rather than
   `self.vae.dtype`. With `--no_half_vae` or fp32 VAE this fed the
   wrong-precision tensor into the encode.
2. More importantly, even after fixing (1), fp16 autocast forces conv
   kernels to fp16 regardless of input/weight dtype, and SDXL VAE
   produces NaN in fp16 — so `vae.encode` inside autocast NaN'd out
   even with fp32 weights and fp32 inputs.

Cast inpaint inputs to `self.vae.dtype` and wrap the encode in
`torch.autocast(..., enabled=False)`. Mirrors how the existing
`decode_latents()` runs outside the caller's autocast block.

Verified on SDXL with `--mixed_precision fp16 --no_half_vae` (was NaN,
now produces valid samples) and `--mixed_precision bf16` (no
regression). Same change applied to the SD1.5 lpw pipeline; SD1.5 VAE
is more fp16-tolerant so the bug was latent there.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@kohya-ss kohya-ss merged commit 4e2a3fd into dev May 6, 2026
3 checks passed
@kohya-ss kohya-ss deleted the fix/inpainting-vae-dtype-lpw branch May 6, 2026 02:22