fix Qwen-Image series context parallel #12970
Conversation
sayakpaul left a comment:
Thanks!
```python
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t

# Make CP plan compatible with zero_cond_t
```
I don't think this is a good design.
We could either modify `for module_id, cp_model_plan in plan.items():` and define this at the top of the model in `_cp_plan` directly, or just instruct users to pass an appropriate CP plan in the parallel config.
@DN6 what are your thoughts?
Moved the modulate_index plan to the top level; it seems to work as expected for both Qwen-Edit-2509 and Qwen-Edit-2511.
I see. If possible, could you test other QwenImage family checkpoints as well?
Already tested on Qwen-Image, Qwen-Image-2512, Qwen-Image-Edit-2509, and Qwen-Image-Edit-2511 (the Qwen-Image-Lightning series shares the same pipelines and transformers as Qwen-Image and Qwen-Image-Edit, so it should work). I can run some additional tests for Qwen-Image-Layered.
@sayakpaul @kashif @DN6 Hi~ I've added more details and explanations, as well as many test results. PTAL~
#12974 should fix the import tests.
@sayakpaul Additional tests for Qwen-Image-Layered, PTAL~

Qwen-Image-Layered test script:

```python
import torch
from diffusers import QwenImageLayeredPipeline
import torch.distributed as dist
from diffusers import ContextParallelConfig
from diffusers.utils import load_image

if dist.is_available():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = "Qwen/Qwen-Image-Layered"
pipe = QwenImageLayeredPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(device=device)
pipe.transformer.set_attention_backend("native")

if world_size > 1:
    from diffusers import QwenImageTransformer2DModel

    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )
pipe.set_progress_bar_config(disable=rank != 0)

image = load_image(
    "https://github.com/vipshop/cache-dit/raw/main/examples/data/yarn-art-pikachu.png"
).convert("RGBA")
inputs = {
    "image": image,
    "prompt": "",
    "num_inference_steps": 50,
    "true_cfg_scale": 4.0,
    "layers": 4,
    "resolution": 640,
    "cfg_normalize": False,
    "use_en_prompt": True,
    "generator": torch.Generator(device="cpu").manual_seed(0),
}

with torch.inference_mode():
    output = pipe(**inputs)
images = output.images[0]

if world_size > 1:
    save_prefix = f"output_image_layered_ulysses{world_size}"
else:
    save_prefix = "output_image_layered"

if rank == 0:
    for idx, output_image in enumerate(images):
        save_path = f"{save_prefix}_layer{idx}.png"
        output_image.save(save_path)
        print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()
```

test cmds:

```bash
torchrun --nproc_per_node=1 --local-ranks-filter=0 test_qwen_layered.py  # baseline
torchrun --nproc_per_node=2 --local-ranks-filter=0 test_qwen_layered.py  # cp2
```

before this pr:

```
torchrun --nproc_per_node=2 test_qwen_layered.py
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803]
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803] *****************************************
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803] *****************************************
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 32.52it/s]
Loading pipeline components...: 33%|████████████████████████████████████████████████████████▋ | 2/6 [00:00<00:00, 5.27it/s]The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 33.27it/s]
Loading pipeline components...: 50%|█████████████████████████████████████████████████████████████████████████████████████ | 3/6
0%| | 0/50 [00:00<?, ?it/s][rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/dev/vipshop/cache-dit/examples/tmp/test_qwen_layered.py", line 52, in <module>
[rank1]: output = pipe(**inputs)
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py", line 801, in __call__
[rank1]: noise_pred = self.transformer(
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]: output = function_reference.forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
[rank1]: output = module._old_forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 933, in forward
[rank1]: image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank1]: return function_reference.post_forward(module, output)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 199, in post_forward
[rank1]: current_output = self._prepare_cp_input(current_output, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

after this pr:
Thanks, will test it out and review today.
Hi~ Sorry to bother you~ Have you had a chance to test this PR? I think it would be good to get this CP crash issue fixed.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Is the issue caused by inferring shapes like `batch_size, text_seq_len = encoder_hidden_states.shape[:2]` not working with CP? We have this pattern everywhere in Diffusers though.
Yes. Therefore, many models may need some additional work to support CP, or a block-level CP plan, because this kind of length inference usually happens outside the block forward loop.
If that's the case, we could check if parallelism is enabled and derive the shapes accordingly?
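For illustration, a minimal sketch of what that could look like; the helper name and the `cp_world_size` argument are hypothetical, not the actual diffusers implementation:

```python
import torch

# Hypothetical helper: when context parallelism is enabled, encoder_hidden_states
# may already be sharded along the sequence dim, so shape[1] is only the local
# shard length. Multiplying by the CP degree recovers the global length
# (assuming an even split); with CP disabled, cp_world_size == 1 and this
# reduces to the usual shape inference.
def infer_text_seq_len(encoder_hidden_states: torch.Tensor, cp_world_size: int = 1) -> int:
    local_txt_len = encoder_hidden_states.shape[1]
    return local_txt_len * cp_world_size
```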
However, beyond this, standard Ulysses still requires the sequence length to be divisible by the number of devices. For example, the CP-2 cases for Qwen-Image will fail if we increase the number of devices to 4. Our Ulysses Anything Attention can solve this problem (it supports any sequence length and any number of heads). When I have free time, I can submit a PR to support Ulysses Anything Attention in Diffusers.
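As a toy illustration of that divisibility constraint (the numbers below are made up, not from the actual Qwen-Image runs):

```python
# The sharded dimension must divide evenly across CP ranks, mirroring the
# assert in diffusers/hooks/context_parallel.py shown in the traceback above.
def can_shard(seq_len: int, world_size: int) -> bool:
    return seq_len % world_size == 0

seq_len = 4102  # hypothetical joint text+image sequence length
for ws in (2, 4):
    print(f"world_size={ws}: shardable={can_shard(seq_len, ws)}")
# world_size=2: shardable=True
# world_size=4: shardable=False -> the failure mode described above
```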
Oh it would be great to have it as a CP backend. I think you mentioned it before! |
Yes, and we have already verified its feasibility on many models. |
DN6 left a comment:
Changes look good to me 👍🏽. Thanks @DefTruth
@DefTruth we want to grant you Diffusers MVP status because of your thoughtful discussions and PRs. Please generate a certificate from here: https://huggingface.co/spaces/diffusers/generate-mvp-certificate (mention this PR number). Could you also let us know your HF Hub username so that we can grant some credits?
cool~
@sayakpaul my HF Hub username is DefTruth at https://huggingface.co/DefTruth
Could you try again?
done, hhhhh, the certificate looks really cool~ |









Follows up #12702; makes the Qwen-Image series cp_plan compatible with txt_seq_lens, attn_mask, and modulate_index (zero_cond_t).
@sayakpaul @kashif
(test cmds and before/after results for Qwen-Image && Qwen-Image-2512 and Qwen-Image-Edit-2509 && Qwen-Image-Edit-2511 were attached as images)
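For readers unfamiliar with the mechanism, here is a rough sketch of the shape a top-level CP plan takes, modeled on existing plans in diffusers; the import path, module names, dims, and field values below are illustrative assumptions, not the actual diff in this PR:

```python
# Illustrative only — the real plan for QwenImageTransformer2DModel in this PR
# may differ in module names, dims, and the entries added for txt_seq_lens,
# attn_mask, and modulate_index.
from diffusers.models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelOutput,
)

_cp_plan = {
    # Shard the main inputs along the sequence dimension before the blocks run.
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
    },
    # Gather the final projection back to the full sequence length.
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
```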