Conversation

@DefTruth DefTruth (Contributor) commented Jan 13, 2026

Follows up #12702: makes the Qwen-Image cp_plan compatible with txt_seq_lens, attn_mask, and modulate_index (zero_cond_t).

    1. Change the CP plan from transformer-level to block-level to avoid sequence-length mismatches in RoPE and compute_text_seq_len_from_mask (a toy sketch of the failing divisibility assert follows this list).
    2. Relax the attn_mask restriction in context_parallel. The checks are already done in the attention backend, so there is no need to repeat them in context_parallel. This opens up more opportunities for performance optimization, especially for cases with an attn_mask.
    3. For Qwen, the attn_mask is processed before context_parallel and is not affected by sequence splitting, so no additional processing of the attn_mask is needed.
    4. This change is universal across the Qwen series, with or without an attn_mask.
    5. Make the CP plan compatible with zero_cond_t (introduced by Qwen-Image-Edit-2511).
    6. Tested on Qwen-Image, Qwen-Image-2512, Qwen-Image-Edit-2509, Qwen-Image-Edit-2511, and Qwen-Image-Layered; confirmed working as expected.
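
As a toy illustration of the first point, here is a small self-contained sketch (made-up token counts; only the assert message is taken from the real traceback below, the rest is not diffusers code):

import torch

# Sketch of the equipartition constraint the transformer-level plan tripped:
# sharding the full joint (text + image) sequence requires its length to be
# divisible by the mesh size, which generally does not hold.
def shard(tensor, dim, mesh_size):
    assert tensor.size(dim) % mesh_size == 0, (
        "Tensor size along dimension to be sharded must be divisible by mesh size"
    )
    return list(tensor.chunk(mesh_size, dim=dim))

joint = torch.zeros(1, 77 + 4096, 64)  # 4173 joint tokens (toy numbers)
try:
    shard(joint, dim=1, mesh_size=4)   # 4173 % 4 != 0
except AssertionError as e:
    print(e)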

@sayakpaul @kashif

Qwen-Image-Edit-2509 && Qwen-Image-Edit-2511

import os
import torch
import argparse
from diffusers.utils import load_image
from diffusers import QwenImageEditPlusPipeline
import torch.distributed as dist
from diffusers import ContextParallelConfig

def parse_args():
    parser = argparse.ArgumentParser(description="Test Qwen-Image-Edit with Context Parallelism")
    parser.add_argument(
        "--use_2511",
        action="store_true",
        help="Use Qwen-Image-Edit-2511 model if set, otherwise use 2509 model.",
    )
    return parser.parse_args()

args = parse_args()

if dist.is_available() and "RANK" in os.environ:  # torchrun sets RANK; plain python falls through to the else branch
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

if args.use_2511:
    model_id = "Qwen/Qwen-Image-Edit-2511"
else:
    model_id = "Qwen/Qwen-Image-Edit-2509"

pipe = QwenImageEditPlusPipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
)

pipe.enable_model_cpu_offload(device=device)

pipe.transformer.set_attention_backend("native")
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

pipe.set_progress_bar_config(disable=rank != 0)

image1 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_1.jpg")
image2 = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/edit2509_2.jpg")
prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square"
inputs = {
    "image": [image1, image2],
    "prompt": prompt,
    "generator": torch.Generator(device="cpu").manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "height": 1024,
    "width": 1024,
}

with torch.inference_mode():
    output = pipe(**inputs)
    output_image = output.images[0]
    model_version = "2511" if args.use_2511 else "2509"
    if world_size > 1:
        save_path = f"output_image_edit_{model_version}_ulysses{world_size}.png"
    else:
        save_path = f"output_image_edit_{model_version}.png"
    if rank == 0:
        output_image.save(save_path)
        print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 --local-ranks-filter=0 test_qwen_edit.py # baseline edit_2509
torchrun --nproc_per_node=4 --local-ranks-filter=0 test_qwen_edit.py # cp4 edit_2509
torchrun --nproc_per_node=1 --local-ranks-filter=0 test_qwen_edit.py --use_2511 # baseline edit_2511
torchrun --nproc_per_node=4 --local-ranks-filter=0 test_qwen_edit.py --use_2511 # cp4 edit_2511

before this PR:

torchrun --nproc_per_node=4 --local-ranks-filter=0 test_qwen_edit.py
W0114 03:16:28.131000 505681 torch/distributed/run.py:803]
W0114 03:16:28.131000 505681 torch/distributed/run.py:803] *****************************************
W0114 03:16:28.131000 505681 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0114 03:16:28.131000 505681 torch/distributed/run.py:803] *****************************************
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  9.85it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 31.09it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.81it/s]
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
  0%|                                                                                                                                                                                                                 | 0/50 [00:21<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/dev/vipshop/cache-dit/examples/tmp/test_qwen_edit.py", line 69, in <module>
[rank0]:     output = pipe(**inputs)
[rank0]:              ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py", line 803, in __call__
[rank0]:     noise_pred = self.transformer(
[rank0]:                  ^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 933, in forward
[rank0]:     image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank0]:     return function_reference.post_forward(module, output)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 199, in post_forward
[rank0]:     current_output = self._prepare_cp_input(current_output, cpm)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank0]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank0]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

after this PR:

Result images: output_image_edit_2509 and output_image_edit_2509_ulysses4 (Edit-2509, baseline vs. Ulysses-4); output_image_edit_2511 and output_image_edit_2511_ulysses4 (Edit-2511, baseline vs. Ulysses-4).

Qwen-Image && Qwen-Image-2512

import os
import torch
import argparse
from diffusers import QwenImagePipeline
import torch.distributed as dist
from diffusers import ContextParallelConfig

def parse_args():
    parser = argparse.ArgumentParser(description="Test Qwen-Image with Context Parallelism")
    parser.add_argument(
        "--use_2512",
        action="store_true",
        help="Use Qwen-Image-2512 model if set, otherwise use old model.",
    )
    return parser.parse_args()

args = parse_args()

if dist.is_available() and "RANK" in os.environ:  # torchrun sets RANK; plain python falls through to the else branch
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

if args.use_2512:
    model_id = "Qwen/Qwen-Image-2512"
else:
    model_id = "Qwen/Qwen-Image"

pipe = QwenImagePipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
)

pipe.enable_model_cpu_offload(device=device)

pipe.transformer.set_attention_backend("native")
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for English prompts
    "zh": ", 超清,4K,电影级构图.",  # for Chinese prompts
}
prompt = (
    "A coffee shop entrance features a chalkboard sign reading "
    '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
    'displaying "通义千问". Next to it hangs a poster showing a '
    "beautiful Chinese woman, and beneath the poster is written "
    '"π≈3.1415926-53589793-23846264-33832795-02384197". '
    "Ultra HD, 4K, cinematic composition"
)
inputs = {
    "prompt": prompt + positive_magic["en"],
    "generator": torch.Generator(device="cpu").manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "height": 1024,
    "width": 1024,
}

with torch.inference_mode():
    output = pipe(**inputs)
    output_image = output.images[0]
    model_version = "2512" if args.use_2512 else None
    if world_size > 1:
        if model_version is not None:
            save_path = f"output_image_{model_version}_ulysses{world_size}.png"
        else:
            save_path = f"output_image_ulysses{world_size}.png"
    else:
        if model_version is not None:
            save_path = f"output_image_{model_version}.png"
        else:
            save_path = f"output_image.png"
    if rank == 0:
        output_image.save(save_path)
        print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 --local-ranks-filter=0 test_qwen_image.py # baseline
torchrun --nproc_per_node=2 --local-ranks-filter=0 test_qwen_image.py # cp2
torchrun --nproc_per_node=1 --local-ranks-filter=0 test_qwen_image.py --use_2512 # baseline 2512
torchrun --nproc_per_node=2 --local-ranks-filter=0 test_qwen_image.py --use_2512 # cp2 2512

before this PR:

torchrun --nproc_per_node=2 --local-ranks-filter=0 test_qwen_image.py

W0114 03:14:57.253000 498979 torch/distributed/run.py:803]
W0114 03:14:57.253000 498979 torch/distributed/run.py:803] *****************************************
W0114 03:14:57.253000 498979 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0114 03:14:57.253000 498979 torch/distributed/run.py:803] *****************************************
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 32.75it/s]
Loading pipeline components...:  60%|██████████████████████████████████████████████████████████████████████████████████████████████████████                                                                    | 3/5 [00:00<00:00,  5.55it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 47.18it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  4.18it/s]
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
  0%|                                                                                                                                                                                                                 | 0/50 [00:20<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/dev/vipshop/cache-dit/examples/tmp/test_qwen_image.py", line 76, in <module>
[rank0]:     output = pipe(**inputs)
[rank0]:              ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 686, in __call__
[rank0]:     noise_pred = self.transformer(
[rank0]:                  ^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 933, in forward
[rank0]:     image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank0]:     return function_reference.post_forward(module, output)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 199, in post_forward
[rank0]:     current_output = self._prepare_cp_input(current_output, cpm)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank0]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank0]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
W0114 03:15:38.068000 498979 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 499171 closing signal SIGTERM
E0114 03:15:38.183000 498979 torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 1 (pid: 499172) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
test_qwen_image.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-01-14_03:15:38
  host      : 10.189.108.254
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 499172)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

after this PR:

Result images: output_image and output_image_ulysses2 (Qwen-Image, baseline vs. Ulysses-2); output_image_2512 and output_image_2512_ulysses2 (Qwen-Image-2512, baseline vs. Ulysses-2).

@sayakpaul sayakpaul requested a review from kashif January 13, 2026 13:33
@sayakpaul sayakpaul (Member) left a comment

Thanks!

self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t

# Make CP plan compatible with zero_cond_t
Member:

I don't think this is a good design.

We could either modify

for module_id, cp_model_plan in plan.items():

and define this at the top of the model in _cp_plan directly, or instruct users to pass an appropriate CP plan in the parallel config.

@DN6 what are your thoughts?

@DefTruth (Contributor Author):

Moved the modulate_index plan to the top level; it works as expected for both Qwen-Edit-2509 and Qwen-Edit-2511.

Member:

I see. If possible, could you test other QwenImage family checkpoints as well?

@DefTruth (Contributor Author):

I see. If possible, could you test other QwenImage family checkpoints as well?

Already tested on Qwen-Image, Qwen-Image-2512, Qwen-Image-Edit-2509, and Qwen-Image-Edit-2511 (the Qwen-Image-Lightning series shares the same pipelines and transformers as Qwen-Image and Qwen-Image-Edit, so it should work). I can run some additional tests for Qwen-Image-Layered.

@DefTruth DefTruth changed the title make qwen-image cp_plan compatible with txt_seq_lens fix Qwen-Image series context parallel Jan 14, 2026
@DefTruth DefTruth marked this pull request as ready for review January 14, 2026 03:27
@DefTruth (Contributor Author):

@sayakpaul @kashif @DN6 Hi~ I provided more details and explanations, as well as many test results. PTAL~

@sayakpaul (Member):

#12974 should fix the import tests.

@DefTruth DefTruth (Contributor Author) commented Jan 14, 2026

@sayakpaul Additional tests for Qwen-Image-Layered, PTAL~

Qwen-Image-Layered

import os
import torch
from diffusers import QwenImageLayeredPipeline
import torch.distributed as dist
from diffusers import ContextParallelConfig
from diffusers.utils import load_image


if dist.is_available() and "RANK" in os.environ:  # torchrun sets RANK; plain python falls through to the else branch
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = "Qwen/Qwen-Image-Layered"

pipe = QwenImageLayeredPipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
)

pipe.enable_model_cpu_offload(device=device)

pipe.transformer.set_attention_backend("native")
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

pipe.set_progress_bar_config(disable=rank != 0)
image = load_image("https://github.com/vipshop/cache-dit/raw/main/examples/data/yarn-art-pikachu.png").convert("RGBA")

inputs = {
    "image": image,
    "prompt": "",
    "num_inference_steps": 50,
    "true_cfg_scale": 4.0,
    "layers": 4,
    "resolution": 640,
    "cfg_normalize": False,
    "use_en_prompt": True,
    "generator": torch.Generator(device="cpu").manual_seed(0),
}

with torch.inference_mode():
    output = pipe(**inputs)
    images = output.images[0]
    if world_size > 1:
        save_prefix = f"output_image_layered_ulysses{world_size}"
    else:
        save_prefix = f"output_image_layered"
    if rank == 0:
        for idx, output_image in enumerate(images):
            save_path = f"{save_prefix}_layer{idx}.png"
            output_image.save(save_path)
            print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 --local-ranks-filter=0 test_qwen_layered.py # baseline
torchrun --nproc_per_node=2 --local-ranks-filter=0 test_qwen_layered.py # cp2

before this PR:

torchrun --nproc_per_node=2 test_qwen_layered.py
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803]
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803] *****************************************
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0114 05:27:20.787000 1079300 torch/distributed/run.py:803] *****************************************
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 32.52it/s]
Loading pipeline components...:  33%|████████████████████████████████████████████████████████▋                                                                                                                 | 2/6 [00:00<00:00,  5.27it/s]The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 33.27it/s]
Loading pipeline components...:  50%|█████████████████████████████████████████████████████████████████████████████████████                                                                                     | 3/6 
  0%|                                                                                                                                                                                                                 | 0/50 [00:00<?, ?it/s][rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/dev/vipshop/cache-dit/examples/tmp/test_qwen_layered.py", line 52, in <module>
[rank1]:     output = pipe(**inputs)
[rank1]:              ^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py", line 801, in __call__
[rank1]:     noise_pred = self.transformer(
[rank1]:                  ^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 933, in forward
[rank1]:     image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
[rank1]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank1]:     return function_reference.post_forward(module, output)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 199, in post_forward
[rank1]:     current_output = self._prepare_cp_input(current_output, cpm)
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

after this PR:

Result images: output_image_layered_layer0 through output_image_layered_layer3 (baseline) and output_image_layered_ulysses2_layer0 through output_image_layered_ulysses2_layer3 (CP2).

@kashif kashif (Contributor) commented Jan 14, 2026

thanks, will test it out and review today

@DefTruth DefTruth (Contributor Author) commented Jan 15, 2026

thanks, will test it out and review today

Hi~ Sorry to bother you. May I ask whether you have had a chance to test this PR? It would be good to get this CP crash fixed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator):

Is the issue caused by inferring shape like this not working with CP?

batch_size, text_seq_len = encoder_hidden_states.shape[:2]

We have this pattern everywhere in Diffusers though

@DefTruth (Contributor Author):

Is the issue caused by inferring shape like this not working with CP?

batch_size, text_seq_len = encoder_hidden_states.shape[:2]

We have this pattern everywhere in Diffusers though

Yes. Therefore, many models may need some additional work to support CP, or a block-level CP plan, because sequence-length inference usually happens outside the block forward loop. A small example follows.
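
To make this concrete, a small self-contained example (plain PyTorch with made-up sizes, not the diffusers hook code) of why shape inference breaks when the tensor is sharded before the transformer body runs:

import torch

# The transformer infers the text length from the tensor shape:
#     batch_size, text_seq_len = encoder_hidden_states.shape[:2]
# Under a transformer-level CP plan each rank only receives its local
# shard, so the inferred length is wrong (and uneven, since 77 % 4 != 0).
world_size, rank = 4, 0
encoder_hidden_states = torch.zeros(1, 77, 3584)  # full text sequence

local = encoder_hidden_states.chunk(world_size, dim=1)[rank]
batch_size, text_seq_len = local.shape[:2]
print(text_seq_len)  # 20, not 77: RoPE built from this no longer matches

# A block-level plan runs RoPE and compute_text_seq_len_from_mask on the
# full sequence first and defers the split to the per-block attention.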

@sayakpaul (Member):

If that's the case, we could check if parallelism is enabled and derive the shapes accordingly?
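
A hedged sketch of that suggestion (hypothetical helper, not actual diffusers code; txt_seq_lens here stands for the caller-provided metadata the PR already threads through):

def get_text_seq_len(encoder_hidden_states, txt_seq_lens=None, cp_world_size=1):
    # When CP has sharded the tensor, its shape no longer reflects the full
    # sequence; fall back to caller-provided lengths instead.
    if cp_world_size > 1 and txt_seq_lens is not None:
        return max(txt_seq_lens)
    return encoder_hidden_states.shape[1]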

@DefTruth DefTruth (Contributor Author) commented Jan 15, 2026

Beyond this, standard Ulysses still requires the sequence length to be divisible by the number of devices. For example, the CP-2 cases for Qwen-Image above fail if the number of devices is increased to 4.

Our Ulysses Anything Attention solves this problem (it supports any sequence length and any number of heads). When I have free time, I can submit a PR to support Ulysses Anything Attention in Diffusers.

https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py#L398
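
For reference, a minimal sketch of the splitting idea only (illustrative; the linked cache-dit implementation also handles the variable-size all-to-all communication, which is omitted here):

import torch

# Drop the seq_len % world_size == 0 requirement by splitting unevenly;
# the communication layer must then use variable-size all-to-all.
def uneven_split_sizes(seq_len, world_size):
    base, rem = divmod(seq_len, world_size)
    # The first `rem` ranks take one extra token; sizes sum to seq_len.
    return [base + (1 if r < rem else 0) for r in range(world_size)]

sizes = uneven_split_sizes(4173, 4)
print(sizes)  # [1044, 1043, 1043, 1043]

x = torch.zeros(1, 4173, 64)
shards = x.split(sizes, dim=1)  # works for any length, no divisibility assert
assert sum(s.shape[1] for s in shards) == 4173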

@sayakpaul (Member):

Our Ulysses Anything Attention solves this problem (it supports any sequence length and any number of heads). When I have free time, I can submit a PR to support Ulysses Anything Attention in Diffusers.

Oh it would be great to have it as a CP backend. I think you mentioned it before!

@DefTruth (Contributor Author):

Our Ulysses Anything Attention solves this problem (it supports any sequence length and any number of heads). When I have free time, I can submit a PR to support Ulysses Anything Attention in Diffusers.

Oh it would be great to have it as a CP backend. I think you mentioned it before!

Yes, and we have already verified its feasibility on many models.

@DN6 DN6 (Collaborator) left a comment

Changes look good to me 👍🏽. Thanks @DefTruth

@sayakpaul sayakpaul merged commit 7f43cb1 into huggingface:main Jan 15, 2026
9 of 11 checks passed
@sayakpaul (Member):

@DefTruth we want to grant you Diffusers MVP status because of your thoughtful discussions and PRs. Please generate a certificate from https://huggingface.co/spaces/diffusers/generate-mvp-certificate (mention this PR number).

Could you also let us know the HF Hub username so that we can grant some credits?

@DefTruth (Contributor Author):

@DefTruth we want to grant you Diffusers MVP status because of your thoughtful discussions and PRs. Please generate a certificate from https://huggingface.co/spaces/diffusers/generate-mvp-certificate (mention this PR number).

Could you also let us know the HF Hub username so that we can grant some credits?

cool~

@DefTruth (Contributor Author):

@sayakpaul my HF Hub username is DefTruth at https://huggingface.co/DefTruth

@DN6 DN6 added the roadmap Add to current release roadmap label Jan 15, 2026
@DefTruth (Contributor Author):

(screenshot: error when generating the certificate)

I encountered some errors when I tried to generate the certificate

@sayakpaul (Member):

Could you try again?

@DefTruth (Contributor Author):

Could you try again?

done, hhhhh, the certificate looks really cool~

Labels: close-to-merge, roadmap (Add to current release roadmap)