Validation Errors When Training ControlNet with FSDP #4037

@liming-ai

Describe the bug

Enabling FSDP for the provided ControlNet example training script causes an unexpected error during validation. The Accelerate config used is:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_min_num_params: 100000000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The error occurs because the following line does not unwrap the FSDP-wrapped model back into its original class:

controlnet = accelerator.unwrap_model(controlnet)

# Before: type(controlnet) is <class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel'>
controlnet = accelerator.unwrap_model(controlnet)
# After: type(controlnet) is still <class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel'>

As a result, the default diffusers.StableDiffusionControlNetPipeline cannot be used for inference/validation: when it receives the still-wrapped model, its check_inputs falls through to the `assert False` branch:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:1127 in <module>                   │
│                                                                                                  │
│   1124                                                                                           │
│   1125 if __name__ == "__main__":                                                                │
│   1126 │   args = parse_args()                                                                   │
│ ❱ 1127 │   main(args)                                                                            │
│   1128                                                                                           │
│                                                                                                  │
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:1083 in main                       │
│                                                                                                  │
│   1080 │   │   │   │   │   │   │   │   │   shutil.rmtree(removing_checkpoint)                    │
│   1081 │   │   │   │   │                                                                         │
│   1082 │   │   │   │   │   if args.validation_prompt is not None and global_step % args.validat  │
│ ❱ 1083 │   │   │   │   │   │   image_logs = log_validation(                                      │
│   1084 │   │   │   │   │   │   │   vae,                                                          │
│   1085 │   │   │   │   │   │   │   text_encoder,                                                 │
│   1086 │   │   │   │   │   │   │   tokenizer,                                                    │
│                                                                                                  │
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:126 in log_validation              │
│                                                                                                  │
│    123 │   │                                                                                     │
│    124 │   │   for _ in range(args.num_validation_images):                                       │
│    125 │   │   │   with torch.autocast("cuda"):                                                  │
│ ❱  126 │   │   │   │   image = pipeline(                                                         │
│    127 │   │   │   │   │   validation_prompt, validation_image, num_inference_steps=20, generat  │
│    128 │   │   │   │   ).images[0]                                                               │
│    129                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py:115 in decorate_context        │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/tiger/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlne │
│ t.py:840 in __call__                                                                             │
│                                                                                                  │
│    837 │   │   │   ]                                                                             │
│    838 │   │                                                                                     │
│    839 │   │   # 1. Check inputs. Raise error if not correct                                     │
│ ❱  840 │   │   self.check_inputs(                                                                │
│    841 │   │   │   prompt,                                                                       │
│    842 │   │   │   image,                                                                        │
│    843 │   │   │   callback_steps,                                                               │
│                                                                                                  │
│ /home/tiger/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlne │
│ t.py:570 in check_inputs                                                                         │
│                                                                                                  │
│    567 │   │   │   for image_ in image:                                                          │
│    568 │   │   │   │   self.check_image(image_, prompt, prompt_embeds)                           │
│    569 │   │   else:                                                                             │
│ ❱  570 │   │   │   assert False                                                                  │
│    571 │   │                                                                                     │
│    572 │   │   # Check `controlnet_conditioning_scale`                                           │
│    573 │   │   if (                                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError
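
A possible way to get past the type check is to peel the wrapper manually before building the pipeline. The following is only a minimal sketch of that idea, not a tested fix: `ControlNetStandIn`, `Wrapper`, and `unwrap_to_type` are hypothetical names used for illustration, and the sketch assumes the wrapper exposes the wrapped model through a `.module` attribute (as DDP and FSDP do). With real FSDP, submodules may also be wrapped and parameters remain sharded, so gathering full parameters first is probably still necessary.

```python
class ControlNetStandIn:
    """Hypothetical stand-in for the real ControlNet model class."""


class Wrapper:
    """Hypothetical stand-in for a parallel wrapper that exposes `.module`."""

    def __init__(self, module):
        self.module = module


def unwrap_to_type(model, expected_type, max_depth=16):
    """Follow `.module` attributes until an instance of `expected_type` is found.

    Raises TypeError if no such instance is reachable within `max_depth` hops.
    """
    current = model
    for _ in range(max_depth + 1):
        if isinstance(current, expected_type):
            return current
        # Step one level inward; stop when there is nothing left to unwrap.
        current = getattr(current, "module", None)
        if current is None:
            break
    raise TypeError(
        f"no {expected_type.__name__} found inside {type(model).__name__}"
    )


# Usage: a doubly wrapped model fails the isinstance check until unwrapped.
wrapped = Wrapper(Wrapper(ControlNetStandIn()))
inner = unwrap_to_type(wrapped, ControlNetStandIn)
```

In the training script, the analogous call would pass the FSDP-wrapped `controlnet` and the real model class. This only addresses the class check that fails in `check_inputs`; whether the unwrapped module is usable for inference under FSDP sharding is not verified here.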

Reproduction

Run the provided ControlNet example training script (examples/controlnet/train_controlnet.py) with the FSDP config above and validation enabled.

Logs

No response

System Info

  • diffusers version: 0.18.0.dev0
  • Platform: Linux-5.4.56.bsk.10-amd64-x86_64-with-glibc2.31
  • Python version: 3.9.2
  • PyTorch version (GPU?): 2.0.0+cu117 (True)
  • Huggingface_hub version: 0.15.1
  • Transformers version: 4.27.4
  • Accelerate version: 0.20.3
  • xFormers version: 0.0.18
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

No response

    Labels

    bug (Something isn't working), stale (Issues that haven't received updates)
