Using ZeRO stage 3 to fine-tune SD2 raises a dimension error #1865

@Nipi64310

Describe the bug

An error is raised when fine-tuning with the diffusers/examples/text_to_image/train_text_to_image.py script under DeepSpeed's ZeRO stage 3. My machine has 4 × 2080 Ti GPUs, and because a single GPU cannot hold all of the SD2 parameters, the DeepSpeed ZeRO stage 3 strategy must be used.
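
For context, a back-of-the-envelope estimate of why plain data parallelism fails here (the ~0.87e9 parameter count for the SD2 UNet is my approximation, not a number from this issue):

params = 0.87e9                   # approximate SD2 UNet parameter count (assumption)
bytes_per_param = 4 + 4 + 4 + 4   # fp32 weights + gradients + Adam m and v
print(f"{params * bytes_per_param / 2**30:.1f} GiB")  # ~13 GiB, vs 11 GiB on a 2080 Ti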

Reproduction

accelerate.yaml

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/kas/zero_stage3_offload_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
use_cpu: false
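
Note that zero3_init_flag: true makes accelerate construct transformers models under DeepSpeed's zero.Init context, so weights are partitioned across ranks as soon as they are created. Roughly equivalent to the following sketch (an illustration of the integration, not the actual library code):

import deepspeed
from transformers import CLIPTextModel

# Under zero.Init each parameter tensor becomes an empty placeholder that
# is only materialized when DeepSpeed gathers it for a managed forward.
with deepspeed.zero.Init():
    text_encoder = CLIPTextModel.from_pretrained(
        "stabilityai/stable-diffusion-2", subfolder="text_encoder"
    )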

/home/kas/zero_stage3_offload_config.json

{
  "train_micro_batch_size_per_gpu": 16,
  "gradient_accumulation_steps": 2,
  "train_batch_size": 128,
  "steps_per_print": 2,
  "gradient_clipping": 1,
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": false,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "stage3_max_live_parameters": 2e8,
    "stage3_max_reuse_distance": 2e8,
    "stage3_prefetch_bucket_size": 2e8,
    "stage3_param_persistence_threshold": 2e8,
    "sub_group_size": 2e8,
    "round_robin_gradients": true
  },
  "bf16": {
    "enabled": true
  }
}
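
Two consistency checks on this config. First, DeepSpeed requires train_batch_size = micro_batch * grad_accum * num_gpus, which the values above satisfy:

# DeepSpeed validates this identity when the engine is constructed.
micro_batch, grad_accum, num_gpus = 16, 2, 4
assert micro_batch * grad_accum * num_gpus == 128  # matches "train_batch_size": 128

Second, this JSON enables bf16 while the accelerate config and launch command request fp16; Turing GPUs such as the 2080 Ti have no native bf16 support, so the mismatch is worth resolving, though it is separate from the crash below.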

launch script

pip install deepspeed
export MODEL_NAME="stabilityai/stable-diffusion-2"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --config_file ./accelerate.yaml --mixed_precision="fp16"  train_text_to_image.py  \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=224 --center_crop --random_flip \
  --train_batch_size=16 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --max_train_steps=500 \
  --learning_rate=6e-5 \
  --max_grad_norm=1 \
  --lr_scheduler="constant_with_warmup" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"

Logs

Steps:   0%|          | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_text_to_image.py", line 718, in <module>
    main()
  File "train_text_to_image.py", line 648, in main
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/miniconda3/lib/python3.7/site-packages/transformers/models/clip/modeling_clip.py", line 739, in forward
    return_dict=return_dict,
  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/miniconda3/lib/python3.7/site-packages/transformers/models/clip/modeling_clip.py", line 636, in forward
    hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/miniconda3/lib/python3.7/site-packages/transformers/models/clip/modeling_clip.py", line 165, in forward
    inputs_embeds = self.token_embedding(input_ids)
  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 160, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 2183, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
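
My reading of the failure (a hypothesis, not a confirmed diagnosis): train_text_to_image.py only passes the unet through accelerator.prepare, so the text_encoder is never wrapped in a DeepSpeed engine. Under ZeRO-3 its parameters therefore stay partitioned, and the weight seen by torch.embedding is the empty 1-D placeholder rather than the full [vocab_size, hidden_size] matrix, hence 'weight' must be 2-D. One possible workaround is to gather the text encoder's parameters explicitly around its forward pass, sketched here with DeepSpeed's GatheredParameters context:

import deepspeed

# Sketch of a workaround at the failing call site (line 648): gather the
# partitioned text-encoder weights on every rank for the duration of the
# forward pass; they are re-partitioned when the context exits. Correct
# but wasteful, since the full weights are materialized on each GPU.
with deepspeed.zero.GatheredParameters(list(text_encoder.parameters())):
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

Alternatively the text_encoder could be set up as its own DeepSpeed engine, though accelerate's DeepSpeed integration has historically supported only a single model per prepare call.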

System Info

  • diffusers version: 0.11.1
  • Platform: Linux-4.15.0-29-generic-x86_64-with-debian-buster-sid
  • Python version: 3.7.7
  • PyTorch version (GPU?): 1.11.0+cu113 (True)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.23.1
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes


Labels

bug (Something isn't working) · stale (Issues that haven't received updates)
