bugfix: fix chrono-edit context parallel #12660
Conversation
|
@sayakpaul @yiyixuxu @DN6 Hi~ can you take a look at this PR? |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Removed unnecessary comments regarding parallelization in cross-attention.
done |
|
@bot /style |
|
Style fix is beginning... View the workflow run here. |
|
@DefTruth could you run |
|
done |
|
Ah. The issue is with the 'Copied from' comment in the attention processor. @DefTruth would you mind also applying the change to the Wan attention processor? It should also be valid, since it would experience the same issue with cross-attention. |
|
@DN6 I haven't fully tested the Wan model yet. I'll hold off on submitting the PR until testing is done, so we can make sure we don't break the existing functionality. |
|
@DefTruth Could you then remove the 'Copied from' comment? It's the reason why the QC checks aren't passing.
Done, rewrote 'Copied from' -> 'modified from'
|
Thank you @DefTruth 🙏🏽 |
I will also test the Wan I2V model. If it has the same problem, I will submit a PR to fix it.
|
diffusers/src/diffusers/pipelines/wan/pipeline_wan_i2v.py Lines 677 to 700 in dde8754
Since only the Wan 2.1 I2V transformer accepts image_embeds (ChronoEdit always accepts image_embeds), I did not come across the same crash while using Wan 2.2 I2V.
Fixes #12661, the crash of ChronoEdit with context parallelism.
We need to disable the splitting of encoder_hidden_states because the image_encoder always generates 257 tokens for image_embed. After concatenation, encoder_hidden_states therefore always has 769 tokens (512 text + 257 image), which is not divisible by the number of devices in the context-parallel (CP) group, hence the crash.
Since the key/value in cross-attention depends solely on encoder_hidden_states (text or image), the (q_chunk * k) * v computation can be parallelized independently per query chunk. Thus, there is no need to pass the parallel_config for cross-attention at all. This change also removes redundant all-to-all communications, specifically (3+1)×2=8 for the two cross-attention operations (text and image), thereby improving ChronoEdit's performance under context parallelism. With this optimization alone, I achieved a nearly 1.85× speedup on L20x2 (two L20 GPUs), without relying on other optimizations such as torch.compile.
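For intuition, here is a minimal PyTorch sketch (not the diffusers attention processor itself; tensor names and shapes are illustrative): because the softmax in cross-attention runs over the key/value tokens only, each rank can attend its local query chunk against the full, replicated encoder_hidden_states, and the concatenated per-rank outputs match the unsharded result with no all-to-all needed.

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch 1, 8 heads, head_dim 64.
# Queries come from the (CP-sharded) video latents; keys/values come
# from encoder_hidden_states, whose 769 tokens (512 text + 257 image)
# are not divisible by 2 ranks -- so we keep k/v replicated instead.
q = torch.randn(1, 8, 1024, 64)  # full query sequence
k = torch.randn(1, 8, 769, 64)   # from encoder_hidden_states (replicated)
v = torch.randn(1, 8, 769, 64)

full = F.scaled_dot_product_attention(q, k, v)

# Each CP rank holds one chunk of q but the *full* k/v. Softmax is over
# the k/v dimension only, so per-chunk results are exact and no
# communication is required (unlike self-attention, where sharding the
# sequence also shards k/v and forces all-to-all exchanges).
world_size = 2
chunks = [
    F.scaled_dot_product_attention(q_chunk, k, v)
    for q_chunk in q.chunk(world_size, dim=2)
]
assert torch.allclose(torch.cat(chunks, dim=2), full, atol=1e-5)
```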
@sayakpaul @yiyixuxu @DN6
Reproduce
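The reproduction script is not inlined here; below is a hedged two-GPU sketch, assuming the ContextParallelConfig / enable_parallelism context-parallel API from recent diffusers releases, a ChronoEditPipeline class, an assumed checkpoint id, and a call signature following the Wan I2V pipelines. Adjust these to the actual script in #12661.

```python
# Launch with: torchrun --nproc_per_node=2 repro.py
# Sketch only: pipeline class, checkpoint id, and call signature are assumptions.
import torch
import torch.distributed as dist
from diffusers import ChronoEditPipeline, ContextParallelConfig
from diffusers.utils import load_image

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = ChronoEditPipeline.from_pretrained(
    "nvidia/ChronoEdit-14B-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to(f"cuda:{rank}")

# Shard self-attention over the 2 ranks. Before this PR, the 769-token
# encoder_hidden_states (512 text + 257 image) was also sharded and
# crashed because 769 is not divisible by 2.
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ring_degree=2)
)

image = load_image("input.png")  # placeholder input image
frames = pipe(image=image, prompt="edit instruction here").frames[0]

if rank == 0:
    print(f"generated {len(frames)} frames")
dist.destroy_process_group()
```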