Skip to content

Add inpainting training and sampling support for SD1.5 and SDXL#2309

Merged
kohya-ss merged 8 commits intokohya-ss:devfrom
allanoepping:inpainting
May 5, 2026
Merged

Add inpainting training and sampling support for SD1.5 and SDXL#2309
kohya-ss merged 8 commits intokohya-ss:devfrom
allanoepping:inpainting

Conversation

@allanoepping
Copy link
Copy Markdown
Contributor

Summary

This PR implements inpainting model training support for both SD1.5 and SDXL, based on the approach originally proposed in #173 by @Fannovel16. That PR was never merged; this is a ground-up reimplementation that brings it up to date with the current codebase, extends it to SDXL, and adds inpainting-aware sampling during training.

The core technique: the UNet receives a 9-channel input during training — 4 channels of noisy latents, 1 channel of a downsampled binary mask, and 4 channels of the VAE-encoded masked image. At inference, the same concatenation is applied inside the denoising loop so that sample images generated during training checkpoints reflect the inpainting task.

Changes

UNet — 9-channel input support

  • library/original_unet.py, library/sdxl_original_unet.py: added explicit in_channels parameter so an inpainting checkpoint's 9-channel conv_in weight is accepted without shape mismatch errors.

Model loading — auto-detect in_channels from checkpoint

  • library/model_util.py: reads conv_in.weight.shape[1] after conversion and overrides unet_config["in_channels"] when it differs from 4, enabling transparent loading of existing inpainting checkpoints (e.g. sd-v1-5-inpainting.ckpt).
  • library/sdxl_model_util.py, library/sdxl_train_util.py: same for SDXL (input_blocks.0.0.weight).

Training loop

  • fine_tune.py, train_db.py, train_network.py, train_textual_inversion.py, sdxl_train.py: when batch["masks"] is present, encodes the masked image via VAE and concatenates [noisy_latents, mask, masked_image_latents] as the UNet input. Uses latents.shape[2:] for mask interpolation (fixes tuple // int error when args.resolution is a tuple) and casts mask to weight_dtype before concatenation.

Dataset / config pipeline

  • library/train_util.py: train_inpainting: bool propagated through BaseDataset, DreamBoothDataset, FineTuningDataset, ControlNetDataset; --train_inpainting CLI flag; assertion that cache_latents and train_inpainting cannot be used together (masks are generated randomly per step from the source image); --img /path directive in prompt files for sampling; graceful skip-on-missing-image; sampling resolution rounded to multiples of 64.

Sampling pipelines

  • library/lpw_stable_diffusion.py, library/sdxl_lpw_stable_diffusion.py: added inpaint_image/inpaint_mask parameters; encodes masked image before the denoising loop; prepends 9-channel input to latent_model_input each step with proper dtype casting; prepare_latents uses vae.config.latent_channels (4) instead of unet.in_channels (9) so the initial noise tensor is always 4-channel.

Procedural mask generation

  • library/mask_generator.py (new): generates inpainting masks procedurally — fractional Brownian motion cloud masks (layered via cv2), convex polygon masks, and basic shape masks (rect/ellipse); combined and fully random modes. Used both during training and for sampling previews.

Test utilities

  • tests/generate_inpainting_test_data.py: generates synthetic training images for smoke tests.
  • tests/download_training_data.py: streams real images from common-canvas/commoncatalog-cc-by via HuggingFace datasets with metadata and post-download size filtering.
  • tests/visualize_masks.py: renders a gallery PNG for each mask type for visual inspection of mask quality.
  • tests/run_inpainting_test.sh, tests/run_sdxl_inpainting_test.sh: SD1.5 and SDXL smoke test scripts.
  • tests/sdxl_inpainting_test.toml: memory-efficient SDXL training config (Adafactor, bf16, gradient checkpointing, no cache_latents).

Usage

Train an inpainting model from an existing inpainting checkpoint

accelerate launch train_network.py
--pretrained_model_name_or_path sd-v1-5-inpainting.ckpt
--train_inpainting
...

Add to a prompt file to get inpainting sample images during training:

--img /path/to/reference.jpg

Notes
--train_inpainting is incompatible with --cache_latents / --cache_latents_to_disk because masks are generated randomly per step.
Existing non-inpainting training is unaffected; the inpainting path is only active when batch["masks"] is not None.

- 9-channel UNet input (noisy_latents + mask + masked_image_latents)
  wired through all training scripts (train_db, train_network,
  fine_tune, train_textual_inversion, sdxl_train)
- Auto-detect in_channels from checkpoint conv_in weight shape in
  model_util.py and sdxl_model_util.py; UNet constructors accept
  explicit in_channels parameter
- Inpainting inference added to lpw_stable_diffusion.py and
  sdxl_lpw_stable_diffusion.py: encodes masked image before denoising
  loop, prepends 9-ch input each step; latent init uses
  vae.config.latent_channels (4) not unet.in_channels (9)
- --train_inpainting CLI flag; cache_latents incompatibility assertion;
  --img prompt directive for sampling source image; missing image
  gracefully skips sample; resolution rounded to multiples of 64
- library/mask_generator.py: procedural cloud (fBm), polygon, shape,
  and combined random mask generation using numpy/cv2/PIL
- tests/: synthetic data generator, mask visualizer, HuggingFace
  training data downloader, SD1.5 and SDXL smoke test scripts and TOML
Fix conflicts from dev branch merge
@kohya-ss
Copy link
Copy Markdown
Owner

kohya-ss commented Apr 5, 2026

Thank you for this PR. I think it's very well implemented.

However, with several lightweight image editing models available now, I'm skeptical about how much demand there will be for SDXL inpainting models.

Furthermore, if other training tools such as Diffusers have this functionality, it may not necessarily be necessary to implement it in sd-scripts.

Could you tell me where SDXL inpainting models are being used or where there is demand for them?

@allanoepping
Copy link
Copy Markdown
Contributor Author

Thank you for your hard work on this project. Inpainting is still used and embedded into may toolsets. Inpainting is better when you need visual continuity, such as in medical imaging, but also in many other cases.

I'm not aware of an easy to use open-source tool, such as yours - for end-users - that can train for this. If you search for it Gemini even recommends Kohya_ss, which hasn't supported it (until now). Since SDXL is still widely used, and I've seen others ask about derivations of models being adapted for inpainting, I decided to implement this.

I don't think this will be hard to maintain, and most of the changes are isolated. If you have some suggestions on changes I can make to improve maintainability I'd be happy to make them.

@kohya-ss
Copy link
Copy Markdown
Owner

kohya-ss commented Apr 5, 2026

Thank you for the detailed explanation; I understand it well.

The code is indeed well-isolated and seems to have low maintenance costs. I will review it and consider merging it.

I will handle the consistency of the code with other parts of the repository after the merge. Could you please add documentation for this feature, even if it's only in English?

@allanoepping
Copy link
Copy Markdown
Contributor Author

allanoepping commented Apr 5, 2026 via email

@allanoepping
Copy link
Copy Markdown
Contributor Author

I've added documentation. I don't have a good way to verify the Japanese translation. I did a check but I don't have easy access to a native speaker to verify.

@kohya-ss
Copy link
Copy Markdown
Owner

kohya-ss commented Apr 6, 2026

Thank you, I think the documentation, including the Japanese version, is well done.

I will continue the review, but it seems that a minimal inference script for testing is not currently included. If possible, a script like flux_minimal_inference.py would be helpful. Interactive mode is not necessarily required; command-line arguments alone is sufficient.

Also, I would appreciate it if you could provide a URL where I can download the checkpoints (both SD and SDXL) to use for testing.

@allanoepping
Copy link
Copy Markdown
Contributor Author

I've added a inpainting_minimal_inference.py script and a wobbly ellipsoid mask generator for better sampling.

Here are the base inpainting models, although it should be able to train an inpainting model from a base model:

https://huggingface.co/wangqyqq/sd_xl_base_1.0_inpainting_0.1.safetensors/blob/main/sd_xl_base_1.0_inpainting_0.1.safetensors
https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt

A section was added to the inpainting_training.md for the inference script. I can remove that if not desired as not really an end-user script.

@kohya-ss
Copy link
Copy Markdown
Owner

kohya-ss commented Apr 7, 2026

Thank you again for update! I will review and merge this sooner.

@allanoepping
Copy link
Copy Markdown
Contributor Author

Thank you, I appreciate it!

@kohya-ss
Copy link
Copy Markdown
Owner

kohya-ss commented Apr 9, 2026

I'm testing from SD1.5 finetuning.

The documentation says "a standard model checkpoint if you want to train inpainting from scratch," but when I specify the weights of a standard model (not inpainting), I get the following error at line 1591 of original_unet.py:

return F.conv2d(
RuntimeError: Given groups=1, weight of size [320, 4, 3, 3], expected input[2, 9, 80, 56] to have 4 channels, but got 9 channels instead.

Could you please check this?

It might be necessary to create a model instance with 9 input channels and transform the shape of the weights of conv_in, before loading them using load_state_dict.

@allanoepping
Copy link
Copy Markdown
Contributor Author

Will do!

Thanks,
Allan

…5 smoke test

Add expand_unet_to_inpainting() to model_util.py, which expands conv_in from
4 to 9 channels when --train_inpainting is set on a standard (non-inpainting)
checkpoint. Original weights are preserved in channels 0-3; channels 4-8 are
zero-initialised. Called automatically in both train_util.load_target_model
(SD1.5) and sdxl_train_util.load_target_model (SDXL) when in_channels==4.

Also fix a FutureWarning from diffusers by setting steps_offset=1 in
get_my_scheduler(), matching the expected SD1.5 scheduler configuration.

Add tests/sd15_inpainting_test.toml and tests/run_sd15_inpainting_test.sh,
a smoke test equivalent to the SDXL one using train_db.py at 512x512/fp16.
Accepts both standard and inpainting SD1.5 checkpoints.

Update docs/inpainting_training.md to reflect that standard checkpoints now
work automatically with --train_inpainting.
@allanoepping
Copy link
Copy Markdown
Contributor Author

Added commit:

Support standard (4-ch) checkpoints for inpainting training; add SD1.5 smoke test

Add expand_unet_to_inpainting() to model_util.py, which expands conv_in from
4 to 9 channels when --train_inpainting is set on a standard (non-inpainting)
checkpoint. Original weights are preserved in channels 0-3; channels 4-8 are
zero-initialised. Called automatically in both train_util.load_target_model
(SD1.5) and sdxl_train_util.load_target_model (SDXL) when in_channels==4.

Also fix a FutureWarning from diffusers by setting steps_offset=1 in
get_my_scheduler(), matching the expected SD1.5 scheduler configuration.

Add tests/sd15_inpainting_test.toml and tests/run_sd15_inpainting_test.sh,
a smoke test equivalent to the SDXL one using train_db.py at 512x512/fp16.
Accepts both standard and inpainting SD1.5 checkpoints.

Update docs/inpainting_training.md to reflect that standard checkpoints now
work automatically with --train_inpainting.

Thank you!

Copy link
Copy Markdown
Owner

@kohya-ss kohya-ss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible, I would appreciate it if you could update them. If that's difficult, I will fix them after the merge.

Comment thread library/train_util.py Outdated
@staticmethod
def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the shape of image is 1,C,H,W. This should be C,H,W, because this will be stacked.

Comment thread library/train_util.py
example["images"] = images

example['masks'] = torch.stack(masks) if masks else None
example['masked_images'] = torch.stack(masked_images) if masked_images else None
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

masked_images appears to be B,1,C,H,W. B,C,H,W would be appropriate. The same applies to masks.

Comment thread fine_tune.py Outdated

if batch["masks"] is not None:
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["images"].shape).to(dtype=weight_dtype)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reshape should be removable by updating train_util.py (same for other training scripts.)

@allanoepping
Copy link
Copy Markdown
Contributor Author

If possible, I would appreciate it if you could update them. If that's difficult, I will fix them after the merge.

I will fix those.

Thank you!

In prepare_mask_and_masked_image (train_util.py), image was shaped
[1,C,H,W] and mask [1,1,H,W] due to image[None] and mask[None,None].
After torch.stack these became [B,1,C,H,W] and [B,1,1,H,W]. Fixed to
image.transpose(2,0,1) → [C,H,W] and mask[None] → [1,H,W], so stacked
batches are the correct [B,C,H,W] and [B,1,H,W].

Removed the .reshape(batch["images"].shape) workaround from fine_tune.py,
train_db.py, and train_network.py that was compensating for the extra dim.

Replaced the per-item interpolate loop + stack + reshape in fine_tune.py,
train_db.py, train_network.py, and sdxl_train.py with a single
F.interpolate call on the full [B,1,H,W] batch tensor.
@allanoepping
Copy link
Copy Markdown
Contributor Author

That's fixed now.

Thank you!

@kohya-ss
Copy link
Copy Markdown
Owner

kohya-ss commented May 5, 2026

Thanks so much for this PR, and apologies for the long review delay on my side.

I've gone through the changes and the design looks solid — the strategy of expanding conv_in from 4 to 9 channels for standard checkpoints, the per-step procedural mask generation, the wiring through all five training scripts, and the addition of inpainting_minimal_inference.py and the smoke tests are all very well done. The shape/scaling fixes in the recent commits (ff7945d, 77a8bd4) also addressed several issues cleanly.

I'd like to merge this into the dev branch as-is and follow up with two small fixes on our side, rather than asking you to make further changes:

  1. ControlNetDataset delegate — the internal DreamBoothDataset(...) call in ControlNetDataset.init passes positional arguments, and the new train_inpainting parameter wasn't threaded through, which shifts every subsequent argument by one slot and breaks ControlNet training. One-line fix.
  2. train_textual_inversion.py mask interpolate — this script was missed by ff7945d's batched-interpolate refactor and still uses the old per-item loop, which now mismatches the new [B,1,H,W] mask shape. Replacing it with the same single F.interpolate call used in the other scripts resolves it.

A couple of smaller refinements (dtype consistency for the masked-image VAE encode in fine_tune.py / train_textual_inversion.py, leftover .reshape() calls, the --img directive aligning with gen_img.py's --i) will go in a separate follow-up PR.

Thank you again for the substantial contribution — really appreciate the care put into the testing infrastructure and documentation as well.

@kohya-ss kohya-ss merged commit 5a822a4 into kohya-ss:dev May 5, 2026
2 of 3 checks passed
kohya-ss added a commit that referenced this pull request May 5, 2026
* Fix ControlNetDataset delegate: forward train_inpainting to DreamBoothDataset

ControlNetDataset.__init__ constructs an internal DreamBoothDataset via
positional arguments. PR #2309 inserted a new train_inpainting parameter
into DreamBoothDataset.__init__ between prior_loss_weight and
debug_dataset, but this delegate call was not updated, so every
subsequent positional argument was shifted by one slot:

  - train_inpainting       <- received debug_dataset (bool, type-correct
                              but logically wrong: enables inpainting whenever
                              debug mode is on)
  - debug_dataset          <- received validation_split (float)
  - validation_split       <- received validation_seed (int)
  - validation_seed        <- received resize_interpolation (str)
  - resize_interpolation   <- received skip_image_resolution (tuple)
  - skip_image_resolution  <- defaulted to None (missing)

This breaks ControlNet training entirely. Forward the dataset's own
train_inpainting attribute to keep the delegate consistent.

* Fix train_textual_inversion mask interpolate to match new batch shape

Commit ff7945d (PR #2309) fixed the masks/masked_images batch shapes from
[B,1,1,H,W]/[B,1,C,H,W] to [B,1,H,W]/[B,C,H,W] and replaced the per-item
F.interpolate loop with a single batched call in fine_tune.py, train_db.py,
train_network.py, and sdxl_train.py — but train_textual_inversion.py was
missed.

After ff7945d batch["masks"] is [B,1,H,W]; iterating over it yields
[1,H,W] slices. F.interpolate on a 3D tensor with a 2-element size argument
performs 1D interpolation and errors on the dimension mismatch, so
textual-inversion inpainting training crashes on the first step.

Replace the loop+stack with the same single-call form used by the other
training scripts, and drop the now-redundant .reshape() workaround.

* doc: avoid typo warning in documentation for mask visualization output
@allanoepping
Copy link
Copy Markdown
Contributor Author

Thank you and all the work you do maintaining this wonderful software!

kohya-ss added a commit that referenced this pull request May 6, 2026
Inpainting cleanup: misc fixes following PR #2309 review
pull Bot pushed a commit to CrazyForks/sd-scripts that referenced this pull request May 7, 2026
- sdxl_train: drop leftover .reshape(batch["images"].shape) no-op in
  masked-image VAE encode (dead since the ff7945d batched-mask fix).
- model_util.expand_unet_to_inpainting: use unet.register_to_config(in_channels=9)
  so the diffusers FrozenDict stays in sync (the previous isinstance(..., dict)
  guard either silently skipped or raised, depending on diffusers version).
- sample_image_inference: round resolution to mod-32 for SDXL (2 downsamples,
  latent /4) and mod-64 for SD1.5/2.x (3 downsamples, latent /8); fix the
  inaccurate "SDXL requires latents divisible by 8" comment.
- inpainting_minimal_inference: remove unused denoise() — the actual loop is
  inlined in main().
- BaseDataset.random_mask: drop unused ratio / mask_full_image parameters and
  the dead cloud_mask import.
- Rename per-prompt --img directive to --i for consistency with Musubi Tuner
  (which already uses --i alongside the shared w/h/l/d/s directives).
- When --train_inpainting is set but the prompt has no --i, warn and skip the
  sample instead of falling through to the standard pipeline (which would
  crash the 9-channel UNet on 4-channel input). Matches the existing
  missing-file skip behaviour.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants