Closed
Labels
bug (Something isn't working), stale (Issues that haven't received updates)
Description
Describe the bug
Enabling FSDP when running the provided ControlNet example training script leads to an unexpected error. The Accelerate config used is:
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_min_num_params: 100000000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
The error occurs because the following line cannot unwrap the FSDP-wrapped model back into its original class:
# type(controlnet): <class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel'>
controlnet = accelerator.unwrap_model(controlnet)
# type(controlnet): <class 'torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel'>
As a result, the default diffusers.StableDiffusionControlNetPipeline cannot be used to run inference/validation:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:1127 in <module> │
│ │
│ 1124 │
│ 1125 if __name__ == "__main__": │
│ 1126 │ args = parse_args() │
│ ❱ 1127 │ main(args) │
│ 1128 │
│ │
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:1083 in main │
│ │
│ 1080 │ │ │ │ │ │ │ │ │ shutil.rmtree(removing_checkpoint) │
│ 1081 │ │ │ │ │ │
│ 1082 │ │ │ │ │ if args.validation_prompt is not None and global_step % args.validat │
│ ❱ 1083 │ │ │ │ │ │ image_logs = log_validation( │
│ 1084 │ │ │ │ │ │ │ vae, │
│ 1085 │ │ │ │ │ │ │ text_encoder, │
│ 1086 │ │ │ │ │ │ │ tokenizer, │
│ │
│ /home/tiger/diffusers/examples/controlnet/train_controlnet.py:126 in log_validation │
│ │
│ 123 │ │ │
│ 124 │ │ for _ in range(args.num_validation_images): │
│ 125 │ │ │ with torch.autocast("cuda"): │
│ ❱ 126 │ │ │ │ image = pipeline( │
│ 127 │ │ │ │ │ validation_prompt, validation_image, num_inference_steps=20, generat │
│ 128 │ │ │ │ ).images[0] │
│ 129 │
│ │
│ /usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /home/tiger/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlne │
│ t.py:840 in __call__ │
│ │
│ 837 │ │ │ ] │
│ 838 │ │ │
│ 839 │ │ # 1. Check inputs. Raise error if not correct │
│ ❱ 840 │ │ self.check_inputs( │
│ 841 │ │ │ prompt, │
│ 842 │ │ │ image, │
│ 843 │ │ │ callback_steps, │
│ │
│ /home/tiger/.local/lib/python3.9/site-packages/diffusers/pipelines/controlnet/pipeline_controlne │
│ t.py:570 in check_inputs │
│ │
│ 567 │ │ │ for image_ in image: │
│ 568 │ │ │ │ self.check_image(image_, prompt, prompt_embeds) │
│ 569 │ │ else: │
│ ❱ 570 │ │ │ assert False │
│ 571 │ │ │
│ 572 │ │ # Check `controlnet_conditioning_scale` │
│ 573 │ │ if ( │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError
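To make the failure mode concrete, here is a minimal pure-Python sketch of what appears to be happening. It rests on two assumptions that are not confirmed in this report: that the unwrap helper only strips wrapper types it explicitly recognises, and that check_inputs dispatches on the concrete class of the controlnet with a bare assertion as its fallback (the traceback shows the `assert False` branch, but not the conditions before it). Every class and function below is a hypothetical stand-in, not the real torch, accelerate, or diffusers code.

```python
# All names here are illustrative stand-ins, not real torch/accelerate/diffusers types.


class ToyControlNet:
    """Stands in for the ControlNet model class the pipeline expects."""


class ListedWrapper:
    """Stands in for a parallel wrapper the unwrap helper recognises."""

    def __init__(self, module):
        self.module = module


class UnlistedWrapper:
    """Stands in for a wrapper the helper does not recognise (FSDP's role here)."""

    def __init__(self, module):
        self.module = module


def unwrap(model, known=(ListedWrapper,)):
    # Peel wrappers only while the outer object is one of the known types.
    while isinstance(model, known):
        model = model.module
    return model


def check_inputs(controlnet):
    # Dispatch on the concrete class, with a bare assertion as the fallback.
    if isinstance(controlnet, ToyControlNet):
        return "ok"
    assert False, f"unexpected controlnet type: {type(controlnet).__name__}"


unwrapped = unwrap(ListedWrapper(ToyControlNet()))
still_wrapped = unwrap(UnlistedWrapper(ToyControlNet()))

# The listed wrapper was removed, so the type check passes.
print(check_inputs(unwrapped))

# The unlisted wrapper passed through unchanged, so the fallback fires.
try:
    check_inputs(still_wrapped)
except AssertionError as err:
    print("AssertionError:", err)
```

In this sketch, adding `UnlistedWrapper` to `known` makes the check pass. Whether the real fix is that simple is unclear, because FSDP shards parameters across ranks and the inner module may not be directly usable for inference.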
Reproduction
Run the provided ControlNet example training script (examples/controlnet/train_controlnet.py) with FSDP enabled, using the config above.
Logs
No response
System Info
- diffusers version: 0.18.0.dev0
- Platform: Linux-5.4.56.bsk.10-amd64-x86_64-with-glibc2.31
- Python version: 3.9.2
- PyTorch version (GPU?): 2.0.0+cu117 (True)
- Huggingface_hub version: 0.15.1
- Transformers version: 4.27.4
- Accelerate version: 0.20.3
- xFormers version: 0.0.18
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes
Who can help?
No response