
[examples/controlnet/train_controlnet_sd3.py] prompt_embeds and pooled_prompt_embeds not cast to weight_dtype in bf16/fp16 training #11050

@andjoer

Description


Describe the bug

When training with --mixed_precision bf16 or fp16, the prompt_embeds and pooled_prompt_embeds tensors returned by the compute_text_embeddings function are not cast to weight_dtype (the dtype used for the rest of the model inputs and parameters). This causes a dtype mismatch error during training.

Specifically, the tensors are produced in the default float32 and are never cast to bf16/fp16, so operations inside the transformer/ControlNet forward pass that combine them with bf16/fp16 weights or activations raise a dtype mismatch error.
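The failure mode and the fix can be illustrated with a small, self-contained sketch. It assumes a plain nn.Linear layer cast to bf16 as a stand-in for a transformer projection, and it runs outside any autocast region. cast_text_embeddings is a hypothetical helper written for illustration only; it is not the function in train_controlnet_sd3.py, and the actual patch to the script may look different.

```python
import torch


def cast_text_embeddings(prompt_embeds, pooled_prompt_embeds, weight_dtype):
    # Hypothetical helper: cast both text-encoder outputs to the training weight dtype
    # so they match parameters that were already cast to bf16/fp16.
    return (
        prompt_embeds.to(dtype=weight_dtype),
        pooled_prompt_embeds.to(dtype=weight_dtype),
    )


weight_dtype = torch.bfloat16

# Stand-in for a transformer/ControlNet projection whose weights were cast to bf16.
proj = torch.nn.Linear(16, 8).to(dtype=weight_dtype)

# Stand-ins for text-encoder outputs, left in the default float32.
prompt_embeds = torch.randn(1, 4, 16)
pooled_prompt_embeds = torch.randn(1, 16)

# Without autocast, feeding float32 activations into bf16 weights raises a RuntimeError.
try:
    proj(prompt_embeds)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# After casting, the same forward call succeeds and produces bf16 output.
prompt_embeds, pooled_prompt_embeds = cast_text_embeddings(
    prompt_embeds, pooled_prompt_embeds, weight_dtype
)
out = proj(prompt_embeds)
print(mismatch_raised, out.dtype)
```

Casting once, right where the embeddings are computed, keeps every downstream forward pass consistent and avoids scattering .to() calls through the training loop.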

Reproduction

accelerate launch train_controlnet_sd3.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--train_data_dir=$DATASET \
--resolution=512 \
--caption_column=$CAPTION_COLUMN \
--learning_rate=1e-5 \
--max_train_steps=15000 \
--validation_steps=100 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--mixed_precision=bf16
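The --mixed_precision=bf16 flag above is what determines weight_dtype. Diffusers training scripts commonly map Accelerate's mixed-precision setting to a dtype roughly as in this minimal sketch; resolve_weight_dtype is a hypothetical helper name used for illustration, not a function defined in the script.

```python
import torch


def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    # Hypothetical helper mirroring the usual pattern: float32 by default,
    # float16 for "fp16", bfloat16 for "bf16".
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32


print(resolve_weight_dtype("bf16"))
```

Any tensor fed into the bf16/fp16 transformer or ControlNet, including prompt_embeds and pooled_prompt_embeds, needs to match this dtype.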

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.10
  • PyTorch version (GPU?): 2.4.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.3
  • Transformers version: 4.49.0
  • Accelerate version: 1.5.1
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A100 80GB PCIe, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul

Metadata


Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
