Skip to content

feat: SD1.x/2.x と SDXL 向けの LECO 学習スクリプトを追加#2285

Merged
kohya-ss merged 9 commits intokohya-ss:fix/leco-cleanupfrom
umisetokikaze:feat-leco
Mar 28, 2026
Merged

feat: SD1.x/2.x と SDXL 向けの LECO 学習スクリプトを追加#2285
kohya-ss merged 9 commits intokohya-ss:fix/leco-cleanupfrom
umisetokikaze:feat-leco

Conversation

@umisetokikaze
Copy link
Copy Markdown
Contributor

概要

Stable Diffusion 1.x/2.x および SDXL 向けの LECO 学習機能を追加しました。
モデル自身のノイズ予測を使って LoRA を学習できるようにし、画像データセットを使わない LECO 学習フローを実装しています。

変更内容

  • SD 1.x / 2.x 向けの LECO 学習スクリプト train_leco.py を追加
  • SDXL 向けの LECO 学習スクリプト sdxl_train_leco.py を追加
  • LECO 学習用の共通ユーティリティ library/leco_train_util.py を追加
  • 従来の LECO 形式に加えて、ai-toolkit style の slider target YAML 読み込みに対応
  • SDXL で add_time_ids と size embedding を用いたノイズ予測に対応
  • LECO のノイズ予測経路で gradient checkpointing を使えるようにし、メモリ使用量を抑制
  • 共有の学習引数検証および deepspeed 関連処理を getattr ベースにして、新規スクリプトでも安全に再利用できるよう修正
  • docs/train_leco.md に LECO 学習手順と YAML 例を追加
  • パーサ検証、プロンプト設定の読み込み、SDXL の vector embedding 挙動を確認するテストを追加

- Implemented `sdxl_train_leco.py` for training with LECO prompts, including argument parsing, model setup, training loop, and weight saving functionality.
- Created unit tests for `load_prompt_settings` in `test_leco_train_util.py` to validate loading of prompt configurations in both original and slider formats.
- Added basic syntax tests for `train_leco.py` and `sdxl_train_leco.py` to ensure modules are importable.
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 53e3263253

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Use actual timestep for min-SNR weighting

When --min_snr_gamma is set, this code builds timesteps from current_timestep_index (the position in noise_scheduler.timesteps) instead of the diffusion timestep value used for prediction. apply_snr_weight indexes noise_scheduler.all_snr by the provided tensor, so this applies SNR weights for the wrong noise level (e.g., high-noise states get weighted like low-noise ones), which can materially distort LECO training dynamics.

Useful? React with 👍 / 👎.

loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Pass diffusion timestep value into min-SNR loss

The SDXL path has the same mismatch: timesteps is filled with current_timestep_index rather than the timestep value passed to predict_noise_xl. Because apply_snr_weight expects real timestep IDs for all_snr lookup, enabling --min_snr_gamma will compute incorrect per-step loss weights and can bias or destabilize slider training results.

Useful? React with 👍 / 👎.

@kohya-ss
Copy link
Copy Markdown
Owner

PR、ありがとうございます。

申し訳ありません、Codexのレビュー設定をオフにしてあるのにレビューが投稿されてしまいました。Codexのレビューは無視していただいて構いません。

内容について早急にレビューいたします。

@kohya-ss kohya-ss changed the base branch from main to sd3 March 28, 2026 10:19
@kohya-ss
Copy link
Copy Markdown
Owner

確認が遅くなり申し訳ありません。基本的な機能は問題ないようです。他のスクリプトへの影響を抑えるため、マージ後にいくつか修正をさせていただきますがご了解ください。
改めて素晴らしいPRをありがとうございました。

@kohya-ss kohya-ss changed the base branch from sd3 to fix/leco-cleanup March 28, 2026 10:22
@kohya-ss kohya-ss merged commit 4ea6032 into kohya-ss:fix/leco-cleanup Mar 28, 2026
3 checks passed
kohya-ss added a commit that referenced this pull request Mar 28, 2026
- train_util.py/deepspeed_utils.py の getattr 化を元に戻し、LECO パーサーにダミー引数を追加
- sdxl_train_util のモジュールレベルインポートをローカルインポートに変更
- PromptEmbedsCache.__getitem__ でキャッシュミス時に KeyError を送出するよう修正
- 設定ファイル形式を YAML から TOML に変更(リポジトリの規約に統一)
- 重複コード (build_network_kwargs, get_save_extension, save_weights) を leco_train_util.py に統合
- _expand_slider_target の冗長な PromptSettings 構築を簡素化
- add_time_ids 用に専用の batch_add_time_ids 関数を追加

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
kohya-ss added a commit that referenced this pull request Mar 29, 2026
* feat: SD1.x/2.x と SDXL 向けの LECO 学習スクリプトを追加 (#2285)

* Add LECO training script and associated tests

- Implemented `sdxl_train_leco.py` for training with LECO prompts, including argument parsing, model setup, training loop, and weight saving functionality.
- Created unit tests for `load_prompt_settings` in `test_leco_train_util.py` to validate loading of prompt configurations in both original and slider formats.
- Added basic syntax tests for `train_leco.py` and `sdxl_train_leco.py` to ensure modules are importable.

* fix: use getattr for safe attribute access in argument verification

* feat: add CUDA device compatibility validation and corresponding tests

* Revert "feat: add CUDA device compatibility validation and corresponding tests"

This reverts commit 6d3e514.

* feat: update predict_noise_xl to use vector embedding from add_time_ids

* feat: implement checkpointing in predict_noise and predict_noise_xl functions

* feat: remove unused submodules and update .gitignore to exclude .codex-tmp

---------

Co-authored-by: Kohya S. <52813779+kohya-ss@users.noreply.github.com>

* fix: format

* fix: LECO PR #2285 のレビュー指摘事項を修正

- train_util.py/deepspeed_utils.py の getattr 化を元に戻し、LECO パーサーにダミー引数を追加
- sdxl_train_util のモジュールレベルインポートをローカルインポートに変更
- PromptEmbedsCache.__getitem__ でキャッシュミス時に KeyError を送出するよう修正
- 設定ファイル形式を YAML から TOML に変更(リポジトリの規約に統一)
- 重複コード (build_network_kwargs, get_save_extension, save_weights) を leco_train_util.py に統合
- _expand_slider_target の冗長な PromptSettings 構築を簡素化
- add_time_ids 用に専用の batch_add_time_ids 関数を追加

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: LECO 学習ガイドを大幅に拡充

コマンドライン引数の全カテゴリ別解説、プロンプト TOML の全フィールド説明、
2つの guidance_scale の違い、推奨設定表、YAML からの変換ガイド等を追加。
英語本文と日本語折り畳みの二言語構成。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: apply_noise_offset の dtype 不一致を修正

torch.randn のデフォルト float32 により latents が暗黙的にアップキャストされる問題を修正。
float32/CPU で生成後に latents の dtype/device へ変換する安全なパターンを採用。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Umisetokikaze <52318966+umisetokikaze@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants