Add inpainting training and sampling support for SD1.5 and SDXL #2309

kohya-ss merged 8 commits into kohya-ss:dev
Conversation
…n't appear to get merged
- 9-channel UNet input (noisy_latents + mask + masked_image_latents) wired through all training scripts (train_db, train_network, fine_tune, train_textual_inversion, sdxl_train)
- Auto-detect in_channels from the checkpoint conv_in weight shape in model_util.py and sdxl_model_util.py; UNet constructors accept an explicit in_channels parameter
- Inpainting inference added to lpw_stable_diffusion.py and sdxl_lpw_stable_diffusion.py: encodes the masked image before the denoising loop and prepends the 9-channel input each step; latent init uses vae.config.latent_channels (4), not unet.in_channels (9)
- --train_inpainting CLI flag; cache_latents incompatibility assertion; --img prompt directive for the sampling source image; a missing image gracefully skips the sample; resolution rounded to multiples of 64
- library/mask_generator.py: procedural cloud (fBm), polygon, shape, and combined random mask generation using numpy/cv2/PIL
- tests/: synthetic data generator, mask visualizer, HuggingFace training data downloader, SD1.5 and SDXL smoke test scripts and TOML
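The conv_in-based auto-detection can be sketched as follows. This is a minimal illustration, not the actual sd-scripts code: the state-dict key and the helper name are assumptions, but the principle matches the description above — the conv_in weight has shape [out_ch, in_ch, kH, kW], so its second dimension distinguishes a standard (4-channel) checkpoint from an inpainting (9-channel) one.

```python
import torch

# Assumed SD-style key for the UNet input convolution (illustrative only).
CONV_IN_KEY = "model.diffusion_model.input_blocks.0.0.weight"

def detect_unet_in_channels(state_dict):
    # conv_in weight shape is [out_ch, in_ch, kH, kW]; in_ch tells us
    # whether this is a standard (4) or inpainting (9) checkpoint.
    return state_dict[CONV_IN_KEY].shape[1]

standard_sd = {CONV_IN_KEY: torch.zeros(320, 4, 3, 3)}
inpaint_sd = {CONV_IN_KEY: torch.zeros(320, 9, 3, 3)}
print(detect_unet_in_channels(standard_sd), detect_unet_in_channels(inpaint_sd))  # 4 9
```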
Fix conflicts from dev branch merge
Thank you for this PR. I think it's very well implemented. However, with several lightweight image-editing models available now, I'm skeptical about how much demand there will be for SDXL inpainting models. Furthermore, if other training tools such as Diffusers already provide this functionality, it may not be necessary to implement it in sd-scripts. Could you tell me where SDXL inpainting models are being used, or where there is demand for them?
Thank you for your hard work on this project. Inpainting is still used and embedded in many toolsets. Inpainting is better when you need visual continuity, such as in medical imaging, but also in many other cases. I'm not aware of an easy-to-use open-source tool like yours, aimed at end users, that can train for this. If you search for it, Gemini even recommends Kohya_ss, which hasn't supported it (until now). Since SDXL is still widely used, and I've seen others ask about derivations of models being adapted for inpainting, I decided to implement this. I don't think this will be hard to maintain, and most of the changes are isolated. If you have suggestions for changes I can make to improve maintainability, I'd be happy to make them.
Thank you for the detailed explanation; I understand it well. The code is indeed well-isolated and seems to have low maintenance costs. I will review it and consider merging it. I will handle the consistency of the code with other parts of the repository after the merge. Could you please add documentation for this feature, even if it's only in English?
Certainly. I'll get to work on it.
Thank you!
I've added documentation. I don't have a good way to verify the Japanese translation; I did a check, but I don't have easy access to a native speaker to verify it.
Thank you, I think the documentation, including the Japanese version, is well done. I will continue the review, but it seems that a minimal inference script for testing is not currently included. If possible, could you add one? Also, I would appreciate it if you could provide a URL where I can download the checkpoints (both SD and SDXL) to use for testing.
Added wobbly ellipse mask for better sampling
I've added an inpainting_minimal_inference.py script and a wobbly ellipse mask generator for better sampling. Here is a base inpainting model, although it should also be possible to train an inpainting model from a standard base model: https://huggingface.co/wangqyqq/sd_xl_base_1.0_inpainting_0.1.safetensors/blob/main/sd_xl_base_1.0_inpainting_0.1.safetensors A section for the inference script was added to inpainting_training.md. I can remove that if it's not desired, as it's not really an end-user script.
Thank you again for the update! I will review and merge this soon.
Thank you, I appreciate it!
I'm testing SD1.5 fine-tuning. The documentation says "a standard model checkpoint if you want to train inpainting from scratch," but when I specify the weights of a standard (non-inpainting) model, I get the following error at line 1591 of original_unet.py: Could you please check this? It might be necessary to create a model instance with 9 input channels and transform the shape of the conv_in weights before loading them with load_state_dict.
Will do! Thanks, |
…5 smoke test

Add expand_unet_to_inpainting() to model_util.py, which expands conv_in from 4 to 9 channels when --train_inpainting is set on a standard (non-inpainting) checkpoint. Original weights are preserved in channels 0-3; channels 4-8 are zero-initialised. Called automatically in both train_util.load_target_model (SD1.5) and sdxl_train_util.load_target_model (SDXL) when in_channels==4.

Also fix a FutureWarning from diffusers by setting steps_offset=1 in get_my_scheduler(), matching the expected SD1.5 scheduler configuration.

Add tests/sd15_inpainting_test.toml and tests/run_sd15_inpainting_test.sh, a smoke test equivalent to the SDXL one using train_db.py at 512x512/fp16. Accepts both standard and inpainting SD1.5 checkpoints.

Update docs/inpainting_training.md to reflect that standard checkpoints now work automatically with --train_inpainting.
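The weight transform at the heart of this commit can be sketched as follows. This is a minimal illustration under the assumptions stated in the commit message (the real expand_unet_to_inpainting() operates on a full UNet; only the conv_in weight expansion is shown here): the original 4-channel weights move into channels 0-3 of a new 9-channel kernel, and the mask / masked-image channels 4-8 start at zero so the expanded model initially behaves exactly like the base model.

```python
import torch

def expand_conv_in_weight(weight_4ch: torch.Tensor) -> torch.Tensor:
    # weight_4ch: [out_ch, 4, kH, kW] from a standard checkpoint.
    out_ch, in_ch, kh, kw = weight_4ch.shape
    assert in_ch == 4, "expected a standard 4-channel conv_in"
    # New kernel: channels 0-3 copy the original weights,
    # channels 4-8 (mask + masked-image latents) are zero-initialised.
    weight_9ch = torch.zeros(out_ch, 9, kh, kw, dtype=weight_4ch.dtype)
    weight_9ch[:, :4] = weight_4ch
    return weight_9ch

w4 = torch.randn(320, 4, 3, 3)
w9 = expand_conv_in_weight(w4)
print(w9.shape)  # torch.Size([320, 9, 3, 3])
```

Zero-initialising the new channels means the extra inputs contribute nothing until training updates them, which is why a standard checkpoint can be used as a starting point.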
Added commit: Support standard (4-ch) checkpoints for inpainting training; add SD1.5 smoke test. Thank you!
kohya-ss left a comment
If possible, I would appreciate it if you could update them. If that's difficult, I will fix them after the merge.
```python
@staticmethod
def prepare_mask_and_masked_image(image, mask):
    image = np.array(image.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
```
Here, the shape of image is [1,C,H,W]. It should be [C,H,W], because it will be stacked.
```python
example["images"] = images

example['masks'] = torch.stack(masks) if masks else None
example['masked_images'] = torch.stack(masked_images) if masked_images else None
```
masked_images appears to be [B,1,C,H,W]; [B,C,H,W] would be appropriate. The same applies to masks.
```python
if batch["masks"] is not None:
    masked_latents = vae.encode(
        batch["masked_images"].reshape(batch["images"].shape).to(dtype=weight_dtype)
```
This reshape should be removable by updating train_util.py (the same applies to the other training scripts).
I will fix those. Thank you!
In prepare_mask_and_masked_image (train_util.py), image was shaped [1,C,H,W] and mask [1,1,H,W] due to image[None] and mask[None,None]. After torch.stack these became [B,1,C,H,W] and [B,1,1,H,W]. Fixed to image.transpose(2,0,1) → [C,H,W] and mask[None] → [1,H,W], so stacked batches are the correct [B,C,H,W] and [B,1,H,W].

Removed the .reshape(batch["images"].shape) workaround from fine_tune.py, train_db.py, and train_network.py that was compensating for the extra dim.

Replaced the per-item interpolate loop + stack + reshape in fine_tune.py, train_db.py, train_network.py, and sdxl_train.py with a single F.interpolate call on the full [B,1,H,W] batch tensor.
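The corrected shapes can be demonstrated with a small sketch (tensor sizes are illustrative): per-item tensors are [C,H,W] for images and [1,H,W] for masks, so torch.stack produces the expected [B,C,H,W] / [B,1,H,W] batches, and a single batched F.interpolate resizes all masks to the latent resolution at once.

```python
import torch
import torch.nn.functional as F

B, H, W = 2, 512, 512
# Per-item shapes after the fix: [3,H,W] images, [1,H,W] masks.
images = torch.stack([torch.rand(3, H, W) for _ in range(B)])   # [B,3,H,W]
masks = torch.stack([torch.rand(1, H, W) for _ in range(B)])    # [B,1,H,W]
# One batched call replaces the old per-item interpolate loop;
# SD1.5 latents are 1/8 of the pixel resolution.
masks_latent = F.interpolate(masks, size=(H // 8, W // 8))      # [B,1,64,64]
print(images.shape, masks.shape, masks_latent.shape)
```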
That's fixed now. Thank you!
Thanks so much for this PR, and apologies for the long review delay on my side. I've gone through the changes and the design looks solid — the strategy of expanding conv_in from 4 to 9 channels for standard checkpoints, the per-step procedural mask generation, the wiring through all five training scripts, and the addition of inpainting_minimal_inference.py and the smoke tests are all very well done. The shape/scaling fixes in the recent commits (ff7945d, 77a8bd4) also addressed several issues cleanly. I'd like to merge this into the dev branch as-is and follow up with two small fixes on our side, rather than asking you to make further changes:
A couple of smaller refinements (dtype consistency for the masked-image VAE encode in fine_tune.py / train_textual_inversion.py, leftover .reshape() calls, the --img directive aligning with gen_img.py's --i) will go in a separate follow-up PR. Thank you again for the substantial contribution — really appreciate the care put into the testing infrastructure and documentation as well.
* Fix ControlNetDataset delegate: forward train_inpainting to DreamBoothDataset

ControlNetDataset.__init__ constructs an internal DreamBoothDataset via positional arguments. PR #2309 inserted a new train_inpainting parameter into DreamBoothDataset.__init__ between prior_loss_weight and debug_dataset, but this delegate call was not updated, so every subsequent positional argument was shifted by one slot:

- train_inpainting <- received debug_dataset (bool, type-correct but logically wrong: enables inpainting whenever debug mode is on)
- debug_dataset <- received validation_split (float)
- validation_split <- received validation_seed (int)
- validation_seed <- received resize_interpolation (str)
- resize_interpolation <- received skip_image_resolution (tuple)
- skip_image_resolution <- defaulted to None (missing)

This breaks ControlNet training entirely. Forward the dataset's own train_inpainting attribute to keep the delegate consistent.

* Fix train_textual_inversion mask interpolate to match new batch shape

Commit ff7945d (PR #2309) fixed the masks/masked_images batch shapes from [B,1,1,H,W]/[B,1,C,H,W] to [B,1,H,W]/[B,C,H,W] and replaced the per-item F.interpolate loop with a single batched call in fine_tune.py, train_db.py, train_network.py, and sdxl_train.py — but train_textual_inversion.py was missed. After ff7945d, batch["masks"] is [B,1,H,W]; iterating over it yields [1,H,W] slices. F.interpolate on a 3D tensor with a 2-element size argument performs 1D interpolation and errors on the dimension mismatch, so textual-inversion inpainting training crashes on the first step. Replace the loop+stack with the same single-call form used by the other training scripts, and drop the now-redundant .reshape() workaround.

* doc: avoid typo warning in documentation for mask visualization output
Thank you for all the work you do maintaining this wonderful software!
Inpainting cleanup: misc fixes following PR #2309 review
- sdxl_train: drop leftover .reshape(batch["images"].shape) no-op in the masked-image VAE encode (dead since the ff7945d batched-mask fix).
- model_util.expand_unet_to_inpainting: use unet.register_to_config(in_channels=9) so the diffusers FrozenDict stays in sync (the previous isinstance(..., dict) guard either silently skipped or raised, depending on the diffusers version).
- sample_image_inference: round resolution to mod-32 for SDXL (2 downsamples, latent /4) and mod-64 for SD1.5/2.x (3 downsamples, latent /8); fix the inaccurate "SDXL requires latents divisible by 8" comment.
- inpainting_minimal_inference: remove unused denoise() — the actual loop is inlined in main().
- BaseDataset.random_mask: drop unused ratio / mask_full_image parameters and the dead cloud_mask import.
- Rename the per-prompt --img directive to --i for consistency with Musubi Tuner (which already uses --i alongside the shared w/h/l/d/s directives).
- When --train_inpainting is set but the prompt has no --i, warn and skip the sample instead of falling through to the standard pipeline (which would crash the 9-channel UNet on 4-channel input). Matches the existing missing-file skip behaviour.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
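The resolution-rounding rule from the cleanup commit can be sketched as a hypothetical helper (name and signature are illustrative, not the actual sd-scripts code): pixel sizes are floored to a multiple of 32 for SDXL (2 UNet downsamples, latent /4) and 64 for SD1.5/2.x (3 downsamples, latent /8).

```python
def round_resolution(size: int, is_sdxl: bool) -> int:
    # SDXL: 2 UNet downsamples -> latent must divide by 4 -> pixels by 32.
    # SD1.5/2.x: 3 downsamples -> latent must divide by 8 -> pixels by 64.
    multiple = 32 if is_sdxl else 64
    return (size // multiple) * multiple

print(round_resolution(1000, True), round_resolution(1000, False))  # 992 960
```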
Summary
This PR implements inpainting model training support for both SD1.5 and SDXL, based on the approach originally proposed in #173 by @Fannovel16. That PR was never merged; this is a ground-up reimplementation that brings it up to date with the current codebase, extends it to SDXL, and adds inpainting-aware sampling during training.
The core technique: the UNet receives a 9-channel input during training — 4 channels of noisy latents, 1 channel of a downsampled binary mask, and 4 channels of the VAE-encoded masked image. At inference, the same concatenation is applied inside the denoising loop so that sample images generated during training checkpoints reflect the inpainting task.
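The per-step concatenation described above can be sketched like this (tensor names follow the description; the batch size and spatial dimensions are illustrative, with latents at 1/8 of the pixel resolution for SD1.5):

```python
import torch

B, h, w = 1, 64, 64
noisy_latents = torch.randn(B, 4, h, w)          # 4 ch: noised image latents
mask = torch.rand(B, 1, h, w)                    # 1 ch: downsampled binary mask
masked_image_latents = torch.randn(B, 4, h, w)   # 4 ch: VAE-encoded masked image
# The UNet consumes all nine channels concatenated along the channel dim.
unet_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64])
```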
Changes
UNet — 9-channel input support
Model loading — auto-detect in_channels from checkpoint
Training loop
Dataset / config pipeline
Sampling pipelines
Procedural mask generation
Test utilities
Usage
Train an inpainting model from an existing inpainting checkpoint
```
accelerate launch train_network.py \
  --pretrained_model_name_or_path sd-v1-5-inpainting.ckpt \
  --train_inpainting \
  ...
```
Add to a prompt file to get inpainting sample images during training:
```
--img /path/to/reference.jpg
```
Notes
--train_inpainting is incompatible with --cache_latents / --cache_latents_to_disk because masks are generated randomly per step.
Existing non-inpainting training is unaffected; the inpainting path is only active when batch["masks"] is not None.