
Add inpainting training and sampling support for SD1.5 and SDXL #2318

Merged
kohya-ss merged 16 commits into main from dev
May 7, 2026

Conversation

@kohya-ss
Owner

@kohya-ss kohya-ss commented May 5, 2026

  • Add inpainting support based on the original Fannovel16 push, which did not appear to get merged

  • Add inpainting training and sampling support for SD1.5 and SDXL

  • 9-channel UNet input (noisy_latents + mask + masked_image_latents) wired through all training scripts (train_db, train_network, fine_tune, train_textual_inversion, sdxl_train)
  • Auto-detect in_channels from checkpoint conv_in weight shape in model_util.py and sdxl_model_util.py; UNet constructors accept explicit in_channels parameter
  • Inpainting inference added to lpw_stable_diffusion.py and sdxl_lpw_stable_diffusion.py: encodes masked image before denoising loop, prepends 9-ch input each step; latent init uses vae.config.latent_channels (4) not unet.in_channels (9)
  • --train_inpainting CLI flag; cache_latents incompatibility assertion; --img prompt directive for sampling source image; missing image gracefully skips sample; resolution rounded to multiples of 64
  • library/mask_generator.py: procedural cloud (fBm), polygon, shape, and combined random mask generation using numpy/cv2/PIL
  • tests/: synthetic data generator, mask visualizer, HuggingFace training data downloader, SD1.5 and SDXL smoke test scripts and TOML
  • Have tests/visualize_masks.py use downloaded training data

  • Add documentation for the inpainting feature

  • Added inpainting_minimal_inference.py for inpainting SD1.5/SDXL testing; added a wobbly ellipse mask for better sampling

  • Support standard (4-ch) checkpoints for inpainting training; add SD1.5 smoke test
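The 9-channel input described in the bullets above is a plain channel-wise concatenation. A minimal sketch of the shapes involved (illustrative tensors, not the actual training code; SD1.5-style 4-channel latents at 1/8 resolution assumed):

```python
import torch

b, h, w = 2, 64, 64
noisy_latents = torch.randn(b, 4, h, w)
mask = torch.rand(b, 1, h, w)                   # latent-resolution mask
masked_image_latents = torch.randn(b, 4, h, w)  # VAE-encoded masked image

# 4 + 1 + 4 = 9 channels fed to the inpainting UNet's conv_in
unet_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
```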

Add expand_unet_to_inpainting() to model_util.py, which expands conv_in from 4 to 9 channels when --train_inpainting is set on a standard (non-inpainting) checkpoint. Original weights are preserved in channels 0-3; channels 4-8 are zero-initialised. Called automatically in both train_util.load_target_model (SD1.5) and sdxl_train_util.load_target_model (SDXL) when in_channels==4.
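A minimal sketch of what such an expansion could look like (hypothetical helper and toy module names; the real function is `expand_unet_to_inpainting()` in model_util.py and operates on the full UNet):

```python
import torch
import torch.nn as nn

def expand_conv_in_to_inpainting(unet):
    """Widen conv_in from 4 to 9 input channels, keeping the original
    kernel in channels 0-3 and zero-initialising channels 4-8."""
    old = unet.conv_in
    new = nn.Conv2d(9, old.out_channels, kernel_size=old.kernel_size,
                    padding=old.padding)
    with torch.no_grad():
        new.weight.zero_()
        new.weight[:, :4] = old.weight   # preserve the original channels
        new.bias.copy_(old.bias)
    unet.conv_in = new
    return unet

# Toy stand-in for a UNet: only conv_in matters for this sketch
class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(4, 8, 3, padding=1)

unet = TinyUNet()
ref = unet.conv_in(torch.ones(1, 4, 16, 16))
expand_conv_in_to_inpainting(unet)
# Zero-filled extra channels leave the original 4-channel behaviour intact
out = unet.conv_in(torch.cat([torch.ones(1, 4, 16, 16),
                              torch.zeros(1, 5, 16, 16)], dim=1))
```

Because channels 4-8 start at zero, a freshly expanded checkpoint initially behaves exactly like the standard model; training then learns to use the mask and masked-image channels.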

Also fix a FutureWarning from diffusers by setting steps_offset=1 in get_my_scheduler(), matching the expected SD1.5 scheduler configuration.

Add tests/sd15_inpainting_test.toml and tests/run_sd15_inpainting_test.sh, a smoke test equivalent to the SDXL one using train_db.py at 512x512/fp16. Accepts both standard and inpainting SD1.5 checkpoints.

Update docs/inpainting_training.md to reflect that standard checkpoints now work automatically with --train_inpainting.

  • Fix mask/masked_image batch shapes and simplify mask interpolation

In prepare_mask_and_masked_image (train_util.py), image was shaped [1,C,H,W] and mask [1,1,H,W] due to image[None] and mask[None,None]. After torch.stack these became [B,1,C,H,W] and [B,1,1,H,W]. Fixed to image.transpose(2,0,1) → [C,H,W] and mask[None] → [1,H,W], so stacked batches are the correct [B,C,H,W] and [B,1,H,W].
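The shape mismatch is easy to reproduce in a few lines (illustrative tensors; `permute` here plays the role of the numpy `transpose(2,0,1)` in the dataset code):

```python
import torch

h, w, b = 8, 8, 2
image = torch.rand(h, w, 3)      # HWC image as loaded
mask = torch.rand(h, w)

# Old (buggy) per-item shapes: the extra leading dim survives torch.stack
bad_images = torch.stack([image.permute(2, 0, 1)[None]] * b)  # [B,1,C,H,W]
bad_masks = torch.stack([mask[None, None]] * b)               # [B,1,1,H,W]

# Fixed per-item shapes
images = torch.stack([image.permute(2, 0, 1)] * b)            # [B,C,H,W]
masks = torch.stack([mask[None]] * b)                         # [B,1,H,W]
```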

Removed the .reshape(batch["images"].shape) workaround from fine_tune.py, train_db.py, and train_network.py that was compensating for the extra dim.

Replaced the per-item interpolate loop + stack + reshape in fine_tune.py, train_db.py, train_network.py, and sdxl_train.py with a single F.interpolate call on the full [B,1,H,W] batch tensor.
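With the corrected [B,1,H,W] layout, one batched call suffices; a sketch assuming SD1.5's 1/8 latent scale:

```python
import torch
import torch.nn.functional as F

masks = torch.rand(4, 1, 512, 512)    # [B,1,H,W] pixel-space masks
# Single call resizes the whole batch to latent resolution
latent_masks = F.interpolate(masks, size=(64, 64), mode="nearest")
```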

allanoepping and others added 2 commits May 5, 2026 21:32
* Fix ControlNetDataset delegate: forward train_inpainting to DreamBoothDataset

ControlNetDataset.__init__ constructs an internal DreamBoothDataset via
positional arguments. PR #2309 inserted a new train_inpainting parameter
into DreamBoothDataset.__init__ between prior_loss_weight and
debug_dataset, but this delegate call was not updated, so every
subsequent positional argument was shifted by one slot:

  - train_inpainting       <- received debug_dataset (bool, type-correct
                              but logically wrong: enables inpainting whenever
                              debug mode is on)
  - debug_dataset          <- received validation_split (float)
  - validation_split       <- received validation_seed (int)
  - validation_seed        <- received resize_interpolation (str)
  - resize_interpolation   <- received skip_image_resolution (tuple)
  - skip_image_resolution  <- defaulted to None (missing)

This breaks ControlNet training entirely. Forward the dataset's own
train_inpainting attribute to keep the delegate consistent.
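The hazard can be illustrated with a heavily trimmed, hypothetical signature (the real DreamBoothDataset takes many more parameters):

```python
class DreamBoothDataset:
    # train_inpainting was inserted between prior_loss_weight and
    # debug_dataset, shifting every later positional argument
    def __init__(self, prior_loss_weight, train_inpainting=False,
                 debug_dataset=False):
        self.prior_loss_weight = prior_loss_weight
        self.train_inpainting = train_inpainting
        self.debug_dataset = debug_dataset

# Old positional delegate call: debug_dataset's value lands one slot
# early, in train_inpainting -- type-correct but logically wrong
broken = DreamBoothDataset(1.0, True)    # caller meant debug_dataset=True

# Fixed: pass by keyword so the call survives future signature changes
fixed = DreamBoothDataset(1.0, train_inpainting=False, debug_dataset=True)
```

Passing delegate arguments by keyword (or forwarding the dataset's own attribute explicitly, as the fix does) makes such insertions fail loudly instead of silently shifting values.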

* Fix train_textual_inversion mask interpolate to match new batch shape

Commit ff7945d (PR #2309) fixed the masks/masked_images batch shapes from
[B,1,1,H,W]/[B,1,C,H,W] to [B,1,H,W]/[B,C,H,W] and replaced the per-item
F.interpolate loop with a single batched call in fine_tune.py, train_db.py,
train_network.py, and sdxl_train.py — but train_textual_inversion.py was
missed.

After ff7945d batch["masks"] is [B,1,H,W]; iterating over it yields
[1,H,W] slices. F.interpolate on a 3D tensor with a 2-element size argument
performs 1D interpolation and errors on the dimension mismatch, so
textual-inversion inpainting training crashes on the first step.

Replace the loop+stack with the same single-call form used by the other
training scripts, and drop the now-redundant .reshape() workaround.
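The failure mode reproduces directly: F.interpolate treats a 3D tensor as [N,C,L] and attempts 1D interpolation, so a 2-element size raises, while the batched 4D call works:

```python
import torch
import torch.nn.functional as F

mask_slice = torch.rand(1, 512, 512)   # [1,H,W] slice from iterating [B,1,H,W]
try:
    F.interpolate(mask_slice, size=(64, 64))   # 3D input -> 1D interp
    crashed = False
except Exception:
    crashed = True                     # size/dimension mismatch error

batch = torch.rand(4, 1, 512, 512)     # the batched form used elsewhere
resized = F.interpolate(batch, size=(64, 64), mode="nearest")
```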

* doc: avoid typo warning in documentation for mask visualization output
@kohya-ss changed the title from "Add inpainting training and sampling support for SD1.5 and SDXL (#2309)" to "Add inpainting training and sampling support for SD1.5 and SDXL" May 5, 2026
kohya-ss and others added 6 commits May 5, 2026 22:55
Fix masked-image VAE encode dtype in fine_tune and train_textual_inversion

Match the regular latents path: encode in vae_dtype, cast output to
weight_dtype. Prevents torch.cat dtype mismatch and VAE numerical
instability under --no_half_vae (vae_dtype != weight_dtype).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- sdxl_train: drop leftover .reshape(batch["images"].shape) no-op in
  masked-image VAE encode (dead since the ff7945d batched-mask fix).
- model_util.expand_unet_to_inpainting: use unet.register_to_config(in_channels=9)
  so the diffusers FrozenDict stays in sync (the previous isinstance(..., dict)
  guard either silently skipped or raised, depending on diffusers version).
- sample_image_inference: round resolution to mod-32 for SDXL (2 downsamples,
  latent /4) and mod-64 for SD1.5/2.x (3 downsamples, latent /8); fix the
  inaccurate "SDXL requires latents divisible by 8" comment.
- inpainting_minimal_inference: remove unused denoise() — the actual loop is
  inlined in main().
- BaseDataset.random_mask: drop unused ratio / mask_full_image parameters and
  the dead cloud_mask import.
- Rename per-prompt --img directive to --i for consistency with Musubi Tuner
  (which already uses --i alongside the shared w/h/l/d/s directives).
- When --train_inpainting is set but the prompt has no --i, warn and skip the
  sample instead of falling through to the standard pipeline (which would
  crash the 9-channel UNet on 4-channel input). Matches the existing
  missing-file skip behaviour.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
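The rounding rule above follows from the downsampling factors (VAE /8, then the UNet's own downsamples). A sketch with a hypothetical helper, rounding down for illustration:

```python
def round_to_multiple(x: int, m: int) -> int:
    """Round a pixel dimension down to the nearest multiple of m."""
    return (x // m) * m

# SDXL: VAE /8 x two UNet downsamples (/4)  -> pixels must divide by 32
# SD1.5/2.x: VAE /8 x three downsamples (/8) -> pixels must divide by 64
sdxl_h = round_to_multiple(1000, 32)
sd15_h = round_to_multiple(1000, 64)
```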
The lpw inpaint encode path had two related dtype issues that surfaced
when training-time sample generation hit the VAE inside the
accelerator.autocast() block in train_util.sample_image_inference:

1. Inputs were cast to the UNet dtype (`dtype = unet.dtype`) rather than
   `self.vae.dtype`. With `--no_half_vae` or fp32 VAE this fed the
   wrong-precision tensor into the encode.
2. More importantly, even after fixing (1), fp16 autocast forces conv
   kernels to fp16 regardless of input/weight dtype, and SDXL VAE
   produces NaN in fp16 — so `vae.encode` inside autocast NaN'd out
   even with fp32 weights and fp32 inputs.

Cast inpaint inputs to `self.vae.dtype` and wrap the encode in
`torch.autocast(..., enabled=False)`. Mirrors how the existing
`decode_latents()` runs outside the caller's autocast block.

Verified on SDXL with `--mixed_precision fp16 --no_half_vae` (was NaN,
now produces valid samples) and `--mixed_precision bf16` (no
regression). Same change applied to the SD1.5 lpw pipeline; SD1.5 VAE
is more fp16-tolerant so the bug was latent there.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Inpainting cleanup: misc fixes following PR #2309 review
Fix NaN in inpainting sample generation under fp16 autocast
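The second issue above generalises beyond the VAE: inside an autocast region, eligible ops run in the reduced dtype regardless of input/weight dtype, and a nested `enabled=False` block restores full precision. A CPU/bf16 illustration of the same mechanism (the fix itself uses fp16 on GPU):

```python
import torch

lin = torch.nn.Linear(4, 4)            # fp32 weights
x = torch.randn(2, 4)

with torch.autocast("cpu", dtype=torch.bfloat16):
    y_auto = lin(x)                    # autocast runs the matmul in bf16
    with torch.autocast("cpu", enabled=False):
        y_full = lin(x.float())        # full fp32 despite the outer block
```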
@kohya-ss kohya-ss marked this pull request as ready for review May 6, 2026 22:51
@kohya-ss kohya-ss requested a review from Copilot May 6, 2026 22:52
Contributor

Copilot AI left a comment


Pull request overview

This PR adds end-to-end inpainting support (training + sampling/inference) for SD1.5 and SDXL by wiring the 9‑channel UNet input path across datasets, training loops, model loading, and the LPW pipelines, plus adding mask generation utilities and documentation/tests.

Changes:

  • Add --train_inpainting support to datasets/training loops, generating masks + masked images and concatenating [noisy_latents, mask, masked_latents] for 9‑channel UNet input.
  • Improve model loading to detect/handle inpainting checkpoints (9‑ch) and expand standard checkpoints (4‑ch) to 9‑ch when training inpainting.
  • Add inpainting-capable sampling/inference paths, plus new mask utilities, docs, and smoke-test scripts/configs.

Reviewed changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated 6 comments.

Per-file summary:
train_textual_inversion.py Concatenates mask + masked-image latents for 9‑channel inpainting UNet training.
train_network.py Prepares masked_latents and routes 9‑channel UNet inputs during network training.
train_db.py Adds masked-latent encoding + 9‑channel UNet input concatenation for DreamBooth training.
fine_tune.py Adds masked-latent encoding + 9‑channel UNet input concatenation for fine-tuning.
sdxl_train.py Adds SDXL masked-latent encoding + 9‑channel UNet input concatenation.
library/train_util.py Adds --train_inpainting, generates procedural masks in dataset, blocks cache_latents, adds inpainting sampling support.
library/config_util.py Adds train_inpainting to dataset config schema.
library/model_util.py Detects in_channels from checkpoint and adds expand_unet_to_inpainting() helper.
library/original_unet.py Makes SD1.x original UNet conv_in/in_channels configurable.
library/sdxl_train_util.py Expands SDXL UNet conv_in for inpainting training and propagates actual in_channels.
library/sdxl_original_unet.py Makes SDXL original UNet in_channels configurable.
library/sdxl_model_util.py Detects SDXL inpainting in_channels from checkpoint and constructs UNet accordingly.
library/lpw_stable_diffusion.py Adds inpainting inputs to SD1.5 LPW pipeline and 9‑channel UNet conditioning in denoising loop.
library/sdxl_lpw_stable_diffusion.py Adds inpainting inputs to SDXL LPW pipeline and 9‑channel UNet conditioning in denoising loop.
library/mask_generator.py New procedural mask generator (cloud/polygon/shape/combined + random + wobbly ellipse).
inpainting_minimal_inference.py New minimal standalone inpainting inference script for SD1.5/SDXL.
README.md Mentions inpainting training and links new documentation.
README-ja.md Japanese README updates for inpainting training mention/link.
docs/inpainting_training.md New documentation describing inpainting training, requirements, sampling, and minimal inference.
tests/visualize_masks.py New mask visualizer that uses downloaded data when available with synthetic fallback.
tests/generate_inpainting_test_data.py Generates synthetic DreamBooth-style data for inpainting smoke tests.
tests/download_training_data.py Downloads CC-BY images into DreamBooth-style folder structure for testing/visualization.
tests/run_inpainting_test.sh Smoke-test runner for train_network.py inpainting training.
tests/run_sd15_inpainting_test.sh Smoke-test runner for SD1.5 inpainting training via train_db.py.
tests/run_sdxl_inpainting_test.sh Smoke-test runner for SDXL inpainting training via sdxl_train.py.
tests/sd15_inpainting_test.toml SD1.5 inpainting smoke-test config.
tests/sdxl_inpainting_test.toml SDXL inpainting smoke-test config.


Review comment threads: library/model_util.py; outdated: library/train_util.py, tests/sd15_inpainting_test.toml, tests/run_inpainting_test.sh, tests/run_sdxl_inpainting_test.sh, library/mask_generator.py
kohya-ss and others added 5 commits May 7, 2026 08:22
…tional

PR #2321 changed sample_image_inference to return early when
--train_inpainting is set and a prompt line has no --i directive.
The doc still described the previous behavior.
…distinction

- Sampling/inference uses wobbly_ellipse_mask, not the training-time
  random_mask mixture. Note this in the "Sample images" section to
  prevent confusion.
- Add a Notes bullet distinguishing --train_inpainting from the
  unrelated --alpha_mask loss-mask feature.
The dataset-config note only mentioned cache_latents. cache_latents_to_disk
auto-enables cache_latents (train_util.py:4532), so it is equally
incompatible with --train_inpainting; spell that out for readers.
Remove unused import.

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Remove unused import.

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
kohya-ss and others added 2 commits May 7, 2026 11:31
Replace the three model-specific smoke scripts with two consolidated ones
(sd15/sdxl x ft/lora via --mode), add pytest coverage for mask_generator
and expand_unet_to_inpainting (covering the recent dtype/device regression),
and a checkpoint verifier that asserts conv_in is 9-channel after training.
Adds Windows .ps1 equivalents so the suite runs on the primary dev
environment, and switches from xformers/mem_eff_attn to sdpa (xformers
does not yet support Blackwell+CU130).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@kohya-ss kohya-ss merged commit 706a95e into main May 7, 2026
6 checks passed