
flux img2img controlnet channels error #9979

@wen020

Description


Describe the bug

When I run Flux img2img ControlNet inference with FluxControlNetImg2ImgPipeline, the ControlNet fails with a channel-mismatch error (full traceback below).

Reproduction

import numpy as np
import torch
import cv2
from PIL import Image
from diffusers.utils import load_image
from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetPipeline
from diffusers import FluxControlNetModel
from controlnet_aux import HEDdetector

base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model = "Xlabs-AI/flux-controlnet-hed-diffusers"
controlnet = FluxControlNetModel.from_pretrained(
  controlnet_model,
  torch_dtype=torch.bfloat16,
  use_safetensors=True,
)
pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
    base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("./toonystarkKoreanWebtoonFlux_fluxLoraAlpha.safetensors")

pipe.enable_sequential_cpu_offload()

hed = HEDdetector.from_pretrained("lllyasviel/Annotators")

# Build the HED edge map used as the control image
image_source = load_image("./03.jpeg")
control_image = hed(image_source)
control_image = control_image.resize(image_source.size)
if control_image.mode != 'RGB':
    control_image = control_image.convert('RGB')
control_image.save("./hed_03.png")

prompt = "bird, cool, futuristic"
image = pipe(
    prompt,
    image=image_source,
    control_image=control_image,
    control_guidance_start=0.2,
    control_guidance_end=0.8,
    controlnet_conditioning_scale=0.5,
    num_inference_steps=50,
    guidance_scale=6,
).images[0]
image.save("flux.png")

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 2
      1 prompt = "bird, cool, futuristic"
----> 2 image = pipe(
      3     prompt,
      4     image=image_source,
      5     control_image=control_image,
      6     control_guidance_start=0.2,
      7     control_guidance_end=0.8,
      8     controlnet_conditioning_scale=0.5,
      9     num_inference_steps=50,
     10     guidance_scale=6,
     11 ).images[0]
     12 image.save("flux.png")

File /opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py:924, in FluxControlNetImg2ImgPipeline.__call__(self, prompt, prompt_2, image, control_image, height, width, strength, num_inference_steps, timesteps, guidance_scale, control_guidance_start, control_guidance_end, control_mode, controlnet_conditioning_scale, num_images_per_prompt, generator, latents, prompt_embeds, pooled_prompt_embeds, output_type, return_dict, joint_attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    921         controlnet_cond_scale = controlnet_cond_scale[0]
    922     cond_scale = controlnet_cond_scale * controlnet_keep[i]
--> 924 controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
    925     hidden_states=latents,
    926     controlnet_cond=control_image,
    927     controlnet_mode=control_mode,
    928     conditioning_scale=cond_scale,
    929     timestep=timestep / 1000,
    930     guidance=guidance,
    931     pooled_projections=pooled_prompt_embeds,
    932     encoder_hidden_states=prompt_embeds,
    933     txt_ids=text_ids,
    934     img_ids=latent_image_ids,
    935     joint_attention_kwargs=self.joint_attention_kwargs,
    936     return_dict=False,
    937 )
    939 guidance = (
    940     torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
    941 )
    942 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.11/site-packages/diffusers/models/controlnets/controlnet_flux.py:281, in FluxControlNetModel.forward(self, hidden_states, controlnet_cond, controlnet_mode, conditioning_scale, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids, guidance, joint_attention_kwargs, return_dict)
    278 hidden_states = self.x_embedder(hidden_states)
    280 if self.input_hint_block is not None:
--> 281     controlnet_cond = self.input_hint_block(controlnet_cond)
    282     batch_size, channels, height_pw, width_pw = controlnet_cond.shape
    283     height = height_pw // self.config.patch_size

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/diffusers/models/controlnets/controlnet.py:99, in ControlNetConditioningEmbedding.forward(self, conditioning)
     98 def forward(self, conditioning):
---> 99     embedding = self.conv_in(conditioning)
    100     embedding = F.silu(embedding)
    102     for block in self.blocks:

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[1, 1, 4096, 64] to have 3 channels, but got 1 channels instead
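Reading the shapes in the error: `weight of size [16, 3, 3, 3]` looks like the first conv of the ControlNet's `input_hint_block`, which expects a raw 3-channel image, while `input[1, 1, 4096, 64]` looks like the packed image latents that the img2img pipeline passes as `controlnet_cond`. A minimal sketch that reproduces the same mismatch (the conv shape here is an assumption based on the traceback, not the actual model code):

```python
import torch

# Assumed stand-in for input_hint_block's first conv (weight [16, 3, 3, 3]):
# it expects an image laid out as (batch, 3, H, W).
conv_in = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

raw_control = torch.randn(1, 3, 1024, 1024)   # raw RGB control image -> works
print(conv_in(raw_control).shape)              # torch.Size([1, 16, 1024, 1024])

packed_latents = torch.randn(1, 1, 4096, 64)   # what the pipeline passes instead
try:
    conv_in(packed_latents)
except RuntimeError as e:
    print(e)  # "... expected input[1, 1, 4096, 64] to have 3 channels, but got 1 channels instead"
```

If that reading is right, the img2img pipeline presumably needs to skip the VAE-encode/pack step for ControlNets that define an `input_hint_block`, as the text-to-image pipeline appears to do.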

System Info

latest diffusers

Who can help?

@yiyixuxu @sayakpaul
