
[feat] JoyAI-JoyImage-Edit support#13444

Merged
yiyixuxu merged 46 commits into huggingface:main from Moran232:joyimage_edit
May 7, 2026

Conversation

@Moran232
Contributor

@Moran232 Moran232 commented Apr 10, 2026

Description

We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.

GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430

Model Overview

JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).

Key Features

  • Advanced Text Rendering Showcase: JoyAI-Image is optimized for challenging text-heavy scenarios, including multi-panel comics, dense multi-line text, multilingual typography, long-form layouts, real-world scene text, and handwritten styles.
  • Multi-view Generation and Spatial Editing Showcase: JoyAI-Image showcases a spatially grounded generation and editing pipeline that supports multi-view generation, geometry-aware transformations, camera control, object rotation, and precise location-specific object editing. Across these settings, it preserves scene content, structure, and visual consistency while following viewpoint-sensitive instructions more accurately.
  • Spatial Editing for Spatial Reasoning Showcase: JoyAI-Image provides high-fidelity spatial editing, serving as a powerful catalyst for enhancing spatial reasoning. Compared with Qwen-Image-Edit and Nano Banana Pro, JoyAI-Image-Edit synthesizes the most diagnostic viewpoints by faithfully executing camera motions. These high-fidelity novel views effectively disambiguate complex spatial relations, providing clearer visual evidence for downstream reasoning.

Image edit examples

spatial-editing-showcase

@github-actions github-actions Bot added the models, pipelines, and size/L (PR with diff > 200 LOC) labels Apr 10, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


thanks for the PR! I left some initial feedback

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Collaborator


ohh what's going on here? is this some legacy code? can we remove?

Contributor Author


We first developed JoyImage, and then trained JoyImage-Edit on top of it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit inherits from it. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
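The relationship described above can be sketched as follows. This is a hypothetical illustration (class and attribute names here are illustrative, not the PR's actual code): the edit model subclasses the base transformer, reusing the shared backbone while adding edit-specific input handling such as flattening an extra "num_items" axis.

```python
import torch
import torch.nn as nn


class BaseTransformerSketch(nn.Module):
    # stand-in for the shared JoyImage DiT backbone
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class EditTransformerSketch(BaseTransformerSketch):
    # edit variant: same backbone, extra handling for multi-item inputs
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_items, seq_len, dim) -> flatten items into one sequence
        b, n, l, d = x.shape
        out = super().forward(x.reshape(b, n * l, d))
        return out.reshape(b, n, l, d)
```

Each pipeline can then require its specific class while the weights and forward logic stay shared through inheritance.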

Comment on lines +371 to +391
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Collaborator


Suggested change
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
attn_output, text_attn_output = self.attn(...)

can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)

also create a JoyImageAttnProcessor (see FluxAttnProcessor as example, I think it is same) https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75 )
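The refactor the reviewer describes can be sketched like this. This is a hypothetical, self-contained illustration of the pattern (layer names such as norm_added_q follow diffusers conventions and are assumptions, not the PR's actual code; the real implementation would also apply RoPE and use dispatch_attention_fn):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNormSketch(nn.Module):
    # minimal RMSNorm stand-in for the QK norms
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class JoyImageAttnProcessorSketch:
    # the processor owns only the computation; the module owns the layers
    def __call__(self, attn, hidden_states, encoder_hidden_states):
        b = hidden_states.shape[0]

        def heads(x):  # (B, L, H*D) -> (B, H, L, D)
            return x.view(b, -1, attn.heads, attn.head_dim).transpose(1, 2)

        # per-stream projections + QK norms (RoPE would be applied here)
        img_q = attn.norm_q(heads(attn.to_q(hidden_states)))
        img_k = attn.norm_k(heads(attn.to_k(hidden_states)))
        img_v = heads(attn.to_v(hidden_states))
        txt_q = attn.norm_added_q(heads(attn.add_q_proj(encoder_hidden_states)))
        txt_k = attn.norm_added_k(heads(attn.add_k_proj(encoder_hidden_states)))
        txt_v = heads(attn.add_v_proj(encoder_hidden_states))

        # joint attention over the concatenated image/text sequence
        q = torch.cat([img_q, txt_q], dim=2)
        k = torch.cat([img_k, txt_k], dim=2)
        v = torch.cat([img_v, txt_v], dim=2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, -1, attn.heads * attn.head_dim)

        img_len = hidden_states.shape[1]
        return out[:, :img_len], out[:, img_len:]


class JoyImageAttentionSketch(nn.Module):
    def __init__(self, dim: int = 32, heads: int = 4):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.to_q, self.to_k, self.to_v = (nn.Linear(dim, dim) for _ in range(3))
        self.add_q_proj, self.add_k_proj, self.add_v_proj = (nn.Linear(dim, dim) for _ in range(3))
        self.norm_q, self.norm_k = RMSNormSketch(self.head_dim), RMSNormSketch(self.head_dim)
        self.norm_added_q, self.norm_added_k = RMSNormSketch(self.head_dim), RMSNormSketch(self.head_dim)
        self.processor = JoyImageAttnProcessorSketch()

    def forward(self, hidden_states, encoder_hidden_states):
        return self.processor(self, hidden_states, encoder_hidden_states)
```

With this split, the block calls `self.attn(img, txt)` and swapping attention backends only requires swapping the processor, which is the point of the diffusers pattern.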

Contributor Author


Thanks for the reminder. I'll clean up this messy code.

Contributor Author


Fixed in d397b68

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +242 to +250
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Collaborator


Suggested change
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor
    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]

Comment on lines +214 to +225
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Collaborator


Suggested change
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)

is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX

Contributor Author


Fixed in f557113

head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Collaborator


Suggested change
self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
self.img_mod = JoyImageModulate(...)

let's remove the load_modulation function and use the layer directly; better to rename it to JoyImageModulate too

Contributor Author


Ok, I will refactor modulation and use ModulateWan
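For reference, the modulation layer under discussion can be sketched in a few lines. This is a hypothetical stand-in mirroring the ModulateDiT code quoted above (SiLU + zero-initialized linear producing `factor` shift/scale/gate chunks), not the final JoyImageModulate/ModulateWan implementation:

```python
import torch
import torch.nn as nn


class JoyImageModulateSketch(nn.Module):
    # SiLU + linear on the time embedding, split into `factor` modulation chunks
    def __init__(self, hidden_size: int, factor: int = 6):
        super().__init__()
        self.factor = factor
        self.act = nn.SiLU()
        self.linear = nn.Linear(hidden_size, factor * hidden_size)
        # zero init so modulation starts as identity (shift=0, scale=0, gate=0)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, temb: torch.Tensor):
        return self.linear(self.act(temb)).chunk(self.factor, dim=-1)
```

A double-stream block would instantiate one such layer per stream (img_mod, txt_mod) with factor=6 for the two norm/MLP modulation triplets.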

tacos8me added a commit to tacos8me/taco-desktop-backend that referenced this pull request Apr 11, 2026
New `model="joyai-edit"` on /v1/image-edit and /v2/image-edit, routed to a
separate FastAPI sidecar on 127.0.0.1:8092 that runs JoyImageEditPipeline
from the Moran232/diffusers fork + transformers 4.57.1. Process isolation
needed because the fork's diffusers core registry patches cannot be
vendored (PR huggingface/diffusers#13444 pending) and transformers 4.57.x
is incompatible with our 5.3.0 stack.

Phase 0 VRAM measurement: 50.3 GB resident, 65.5 GB peak reserved at
1024² / 30 steps (well under the 80 GB gate). Passed.

- `joyai_client.py` (NEW, 167 lines): thin httpx wrapper with per-call
  short-lived AsyncClient, split timeouts (180s edit / 60s mgmt),
  HTTPStatus→JoyAIError mapping. Singleton `joyai` exported.
- `config.py`: `JOYAI_SIDECAR_URL` (default http://127.0.0.1:8092) and
  `LOAD_JOYAI` env flag. Off by default.
- `server.py`: three-tenant swap protocol replaces the two-tenant v1.1.4
  helpers. New `_last_gpu_tenant` tracker + `_evict_other_tenants(new)`
  helper. All three `_ensure_*_ready()` helpers are now `async def` —
  13 call sites updated across _dispatch_job and v1 sync handlers.
  IMAGE_EDIT dispatch arm routes `model=="joyai-edit"` to joyai_client;
  validates len(image_paths)==1 (422 otherwise). Lifespan health-probes
  the sidecar when LOAD_JOYAI=1 (non-blocking — joyai-edit returns 503
  if unreachable).
- `flux_manager.py`: pre-existing bug fix — _edit() hardcoded
  ensure_model("flux2-klein"), silently ignoring the dispatcher's
  `model` kwarg. Now accepts and respects `model`. Guidance_scale
  is now conditional on model != "flux2-klein" (Klein strips CFG,
  Dev uses it).
- `tests/test_joyai_client.py` (NEW, 7 tests) + `tests/test_validation.py`
  (+3 tests): 89 tests passing (was 79).
- Docs: API.md, QUICKSTART.md, README.md, CLAUDE.md, AGENTS.md all
  updated with joyai-edit model entry, three-tenant swap diagram,
  latency table, sidecar location/port, LOAD_JOYAI env var, v1.1.8
  changelog entry.

Out-of-tree (not committed here, installed separately):
  /mnt/nvme-1/servers/joyai-sidecar/     (sidecar venv + sidecar.py + run.sh)
  ~/.config/systemd/user/joyai-sidecar.service

Smoke-tested end-to-end: upload → /v2/image-edit joyai-edit →
SSE stream (phase denoising → encoding → None) → fetch WEBP result
(352 KB, 91 s wall clock for 20 steps at 1024²). Three-tenant swap
evicted LTX and reloaded it cleanly via _evict_other_tenants.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +454 to +459
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)

Collaborator


Suggested change
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)

I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
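The mechanism behind this suggestion can be mocked in a few lines. This is an illustrative mock, not the real diffusers ConfigMixin/register_to_config: once __init__ arguments are captured onto self.config, a separate SimpleNamespace like self.args is redundant.

```python
import functools
import inspect
from types import SimpleNamespace


def register_to_config_mock(init):
    # capture the __init__ arguments (with defaults applied) onto self.config
    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        params = dict(bound.arguments)
        params.pop("self")
        self.config = SimpleNamespace(**params)
        init(self, *args, **kwargs)
    return wrapper


class ToyTransformer:
    @register_to_config_mock
    def __init__(self, is_repa: bool = False, repa_layer: int = 8):
        # no separate self.args needed; read self.config.is_repa instead
        pass
```

In diffusers the real decorator additionally serializes the config for save/load, which is why reviewers prefer self.config over ad hoc namespaces.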

Contributor Author

@Moran232 Moran232 Apr 14, 2026


I deleted the repa code, see f557113

Collaborator


Was the repa logic removed because it is not used in inference?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment on lines +900 to +901
timesteps: List[int] = None,
sigmas: List[float] = None,
Collaborator


Suggested change
timesteps: List[int] = None,
sigmas: List[float] = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,

nit: could we switch to modern built-in type hints (`list[int] | None` style) here and elsewhere?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left an initial design review :).

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026
@Moran232
Contributor Author

@yiyixuxu @dg845
Thank you very much for your valuable feedback. I've made some modifications. See my latest commits.

Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; since it is quite complex, I directly copied it over from the training code.

If you have any further suggestions, please feel free to share. Thank you so much!

# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
Collaborator


If I remember correctly, the attention sublayer used to use the custom RMSNorm module, which upcasted to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
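The difference the reviewer is pointing at can be sketched directly. This is a minimal illustration, assuming the described behavior of the diffusers RMSNorm: the reduction is done in FP32 and the result is cast back to the input dtype, whereas a plain half-precision RMSNorm does the reduction in FP16/BF16.

```python
import torch
import torch.nn as nn


class UpcastRMSNorm(nn.Module):
    # RMSNorm that upcasts to FP32 for the mean-of-squares reduction,
    # then casts back to the input dtype (the behavior described above)
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x32 = x.float()
        var = x32.pow(2).mean(-1, keepdim=True)
        x32 = x32 * torch.rsqrt(var + self.eps)
        return (x32 * self.weight.float()).to(dtype)
```

If the original training used the FP32-upcasting variant for the QK norms, swapping in a non-upcasting norm at inference can introduce small numerical drift in low precision, which is the concern raised here.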

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread setup.py Outdated
@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels May 1, 2026
@yiyixuxu
Collaborator

yiyixuxu commented May 1, 2026

@claude can you do a final review too?

@github-actions
Contributor

github-actions Bot commented May 1, 2026

Claude finished @yiyixuxu's task in 3m 24s


Final Review of JoyAI-JoyImage-Edit PR

  • Read repository guidelines (CLAUDE.md, models.md, pipelines.md)
  • Review transformer model (transformer_joyimage.py)
  • Review image processor (image_processor.py)
  • Review pipeline (pipeline_joyimage_edit.py)
  • Review pipeline output and init/export files
  • Review tests
  • Check outstanding reviewer comments
  • Post final review

The PR has improved significantly since the last round. Many of the earlier issues (einops removal, attention refactor, gradient checkpointing, assert→ValueError, etc.) have been addressed. Below are the remaining issues I found, organized by severity. I've also noted which items align with outstanding reviewer comments from @yiyixuxu and @dg845 that still need resolution.


Must Fix

1. _keep_in_fp32_modules references non-existent modulestransformer_joyimage.py:381-387

_keep_in_fp32_modules = [
    "time_embedder",
    "scale_shift_table",  # ← does not exist in this model
    "norm1",
    "norm2",
    "norm3",              # ← does not exist in this model
]

scale_shift_table and norm3 are WAN concepts copied over but not present in JoyImageEditTransformer3DModel. These entries are harmless (they just match nothing) but are misleading. Update to match the actual module names in this model.

2. __init__.py export orderingsrc/diffusers/__init__.py:584-586

"JoyImageEditPipeline",
"Kandinsky3Img2ImgPipeline",
"JoyImageEditPipelineOutput",  # ← out of alphabetical order

JoyImageEditPipelineOutput should come immediately after JoyImageEditPipeline, before Kandinsky3Img2ImgPipeline. This will likely be caught by make quality.

3. enable_tiling still in docstringpipeline_joyimage_edit.py:698-699

            enable_tiling (`bool`, *optional*, defaults to `False`):
                Enable tiled VAE decoding to reduce peak memory usage.

This parameter does not exist in __call__. Remove from docstring.

4. **kwargs in __call__ silently swallows argumentspipeline_joyimage_edit.py:649

yiyixuxu flagged this. **kwargs is accepted but never used — it silently discards any mistyped keyword arguments. Remove it.


Should Fix (Outstanding reviewer comments)

5. encode_prompt_multiple_images should be called from __call__ directly — yiyixuxu's comment

Currently encode_prompt internally routes to encode_prompt_multiple_images when images is not None. yiyixuxu asked to move the routing to __call__:

# In __call__:
if processed_image is None:
    ... = self.encode_prompt(...)
else:
    ... = self.encode_prompt_multiple_images(...)

Additionally, encode_prompt_multiple_images should accept pre-computed prompt_embeds/prompt_embeds_mask and handle num_images_per_prompt expansion, similar to encode_prompt.
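A complete version of that dispatch can be sketched as a small helper. This is hypothetical glue code (the helper name and simplified signatures are illustrative; the real pipeline methods take many more arguments):

```python
# Dispatch sketch: __call__ picks the encoder based on whether a
# processed image is present, instead of branching inside encode_prompt.
def encode_for_call(pipe, prompt, processed_image=None, num_images_per_prompt=1):
    if processed_image is None:
        return pipe.encode_prompt(prompt, num_images_per_prompt=num_images_per_prompt)
    return pipe.encode_prompt_multiple_images(
        prompt, images=processed_image, num_images_per_prompt=num_images_per_prompt
    )
```

Keeping the branch at the __call__ level makes the two encoding paths explicit to readers of the pipeline, which is the motivation behind the review comment.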

6. Timestep dtype inference via weight introspectiontransformer_joyimage.py:356-358, yiyixuxu's comment

time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)

Per models.md gotcha #4: "don't cast activations by reading a weight's dtype — the stored weight dtype isn't the compute dtype under gguf / quantized loading." yiyixuxu suggested simplifying to:

timestep = timestep.to(encoder_hidden_states.dtype)
temb = self.time_embedder(timestep)

7. nn.RMSNorm vs diffusers RMSNormtransformer_joyimage.py:207-208, 212-213

The attention sublayer QK norms use nn.RMSNorm which does not upcast to FP32 during computation. The diffusers RMSNorm (from normalization.py) does upcast. dg845 flagged this earlier. If the original model's training used FP32 upcast for QK norms, this could cause numerical drift. Please verify whether the original training code upcasts, and use the appropriate norm class.


Nice to Fix

8. Two forward passes instead of batched CFGpipeline_joyimage_edit.py:823-836

The denoising loop makes two separate transformer forward passes (conditional + unconditional) instead of batching them together. This is ~2x slower than the standard diffusers pattern:

latent_model_input = torch.cat([latents] * 2)
prompt_concat = torch.cat([negative_prompt_embeds, prompt_embeds])
noise_pred = self.transformer(latent_model_input, ..., prompt_concat, ...)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

I understand this may be intentional due to the multi-item latent structure, but it's worth considering for performance.

9. Always-on guidance rescalingpipeline_joyimage_edit.py:839-842

cond_norm = torch.norm(noise_pred, dim=2, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=2, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6))

This norm-based guidance rescaling is always applied during CFG. Most diffusers pipelines make this optional via a guidance_rescale parameter. If this is an intentional design choice for JoyImage, a brief comment explaining why would help readers.
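The optional-rescale pattern suggested here can be sketched as follows. This is an illustration of making the rescaling opt-in via a guidance_rescale-style knob (the interpolation form follows the common diffusers convention; it is not the PR's code):

```python
import torch


def rescaled_cfg(noise_pred_text, noise_pred_uncond, guidance_scale, guidance_rescale=0.0):
    # standard CFG combine
    comb = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    if guidance_rescale > 0.0:
        # norm-based rescaling toward the conditional prediction's magnitude,
        # blended in by guidance_rescale instead of always applied
        cond_norm = noise_pred_text.norm(dim=2, keepdim=True)
        comb_norm = comb.norm(dim=2, keepdim=True)
        rescaled = comb * (cond_norm / comb_norm.clamp_min(1e-6))
        comb = guidance_rescale * rescaled + (1.0 - guidance_rescale) * comb
    return comb
```

With guidance_rescale=0.0 this reduces to plain CFG; JoyImage's current behavior corresponds to the fully rescaled end of the blend.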

10. Mixed typing stylepipeline_joyimage_edit.py:2, 620-649

The pipeline mixes old-style typing (Dict, List, Optional, Union) and modern style (str | list[str], int | None). Similarly, the transformer uses from typing import Tuple instead of tuple. Should be consistent — prefer Python 3.9+ builtins per dg845's earlier suggestion.

11. Tuple in transformertransformer_joyimage.py:17

from typing import Tuple — replace all Tuple[...] with tuple[...] for Python 3.9+ consistency.

12. image_processor.py bucket validationimage_processor.py:57-63

find_best_bucket will KeyError if basesize != 1024. Add a check:

if basesize not in BUCKETS:
    raise ValueError(f"Unsupported basesize {basesize}. Supported: {list(BUCKETS.keys())}")
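A self-contained sketch of the suggested guard in context; the BUCKETS layout and the aspect-ratio matching rule here are assumed for illustration, not copied from the PR's image_processor.py:

```python
# Assumed bucket table: basesize -> candidate (height, width) pairs
BUCKETS = {1024: [(1024, 1024), (768, 1344), (1344, 768)]}


def find_best_bucket(height, width, basesize=1024):
    # validate up front instead of letting BUCKETS[basesize] raise KeyError
    if basesize not in BUCKETS:
        raise ValueError(f"Unsupported basesize {basesize}. Supported: {list(BUCKETS)}")
    target = height / width
    # pick the bucket whose aspect ratio is closest to the input's
    return min(BUCKETS[basesize], key=lambda hw: abs(hw[0] / hw[1] - target))
```

The explicit ValueError gives callers an actionable message listing the supported base sizes, rather than an opaque KeyError.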

13. @register_to_config on VaeImageProcessor subclassimage_processor.py:91

VaeImageProcessor.__init__ already uses register_to_config. Double-decorating in the subclass may cause issues with config serialization. Verify this works correctly with from_config.


Style / CI

14. make style and make quality needed — dg845 asked for this. The style bot doesn't seem to be working. The authors should run make style && make quality && make fix-copies locally and push the result.


Tests

15. test_from_save_pretrained_dtype_inference unconditionally skippedtest_models_transformer_joyimage.py:90-91

pytest.skip("Tolerance requirements too high for meaningful test")

This unconditionally skips a core test. If there are tolerance issues, investigate the root cause or use @pytest.mark.xfail with a reason so it's tracked. pytest.skip hides the problem entirely.
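The suggested alternative looks like this. A minimal sketch (the reason string is illustrative): mark the tolerance problem as a tracked expected failure instead of unconditionally skipping, so the test still runs and an unexpected pass is visible.

```python
import pytest


@pytest.mark.xfail(reason="tolerance requirements too high; tracked upstream", strict=False)
def test_from_save_pretrained_dtype_inference():
    # the real test body would compare dtypes after a save/load round-trip;
    # a trivial body stands in here
    assert True
```

With strict=False an unexpected pass is reported as XPASS rather than a failure, so the suite keeps surfacing the issue without blocking CI.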

16. Offloading test failures — dg845 identified that test_group_offloading_inference, test_pipeline_level_group_offloading_inference, test_sequential_cpu_offload_forward_pass, and test_sequential_offload_forward_pass_twice fail due to Qwen3VL incompatibility with leaf-level offloading. These need to be either overridden (as dg845 suggested) or skipped with clear documentation.


Summary

The architecture is clean — JoyImageAttention + JoyImageAttnProcessor with dispatch_attention_fn follows the recommended pattern well. Gradient checkpointing, _repeated_blocks, and _no_split_modules are all properly set. The _get_last_decoder_hidden_states hook for transformers 5.x compat is a pragmatic solution.

The main outstanding items are yiyixuxu's latest review comments (#5, #6, #4), the stale _keep_in_fp32_modules entries (#1), and running make style && make quality (#14). Once those are addressed, this should be ready to merge.


@yiyixuxu
Collaborator

yiyixuxu commented May 1, 2026

ohh claude did a good job here, can you look into fixing #1 and #3 listed in the claude review?

@yiyixuxu
Collaborator

yiyixuxu commented May 1, 2026

@claude are you able to do a dead code analysis since it's a new model here?
(also, can you let me know why you did not include this in the review already? we listed it in review-rules.md)

@github-actions
Contributor

github-actions Bot commented May 1, 2026

Claude finished @yiyixuxu's task in 2m 9s


Dead Code Analysis for JoyImage-Edit Model

  • Read review rules to understand the dead code analysis requirements
  • Read the pipeline __call__ and trace all calls into the transformer
  • Read the transformer model __init__ and forward
  • Read the published model config (jdopensource/JoyAI-Image-Edit-Diffusers)
  • Cross-reference: identify unused parameters, methods, and unreachable code paths
  • Explain why previous reviews missed this analysis
  • Post final dead code analysis report

Why this wasn't in the previous reviews

The review rules in .ai/review-rules.md explicitly require a dead code analysis section for PRs that add new models ("Trace the call path... Check the default model config... Flag unused parameters and methods"). My two earlier reviews focused on correctness, API design, and convention issues but failed to follow the dead-code-analysis checklist. That was an oversight on my part — the rule was there, I just didn't execute it. Sorry about that.


Call path trace

Pipeline __call__ (lines 619–880) calls into the transformer at line 823:

noise_pred = self.transformer(
    hidden_states=latent_model_input,   # (B, num_items, C, T, H, W) — 6D when editing
    timestep=t_expand,
    encoder_hidden_states=prompt_embeds,
    return_dict=False,
)[0]

Only three arguments are ever passed: hidden_states, timestep, encoder_hidden_states, and return_dict. The transformer's forward signature is:

def forward(self, hidden_states, timestep, encoder_hidden_states=None, return_dict=True):

All four parameters are used. Good — no unused forward parameters.

Published config (jdopensource/JoyAI-Image-Edit-Diffusers/transformer/config.json):

{
  "hidden_size": 4096,
  "num_attention_heads": 32,
  "rope_type": "rope",
  "rope_dim_list": [16, 56, 56],
  "patch_size": [1, 2, 2],
  "theta": 10000
}

Findings (advisory — not blocking)

Under the published config and the pipeline's call path, the following code appears unreachable or unused:

1. Text RoPE branch is dead codetransformer_joyimage.py:561

txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,

The published config has rope_type = "rope", so txt_rope_size is always None. This means:

  • get_rotary_pos_embed always takes the early return at line 490–491 (if txt_rope_size is None: return vis_freqs, None)
  • The entire text-RoPE computation (lines 494–501) is never executed
  • In JoyImageAttnProcessor.__call__, the text RoPE branch at line 143–144 (if txt_freqs is not None) is never taken — text tokens get no RoPE

This is likely intentional (the model was trained with rope_type="rope"), but the "mrope" code path is untested dead code.

2. rope_dim_list is None fallback is unreachabletransformer_joyimage.py:466-467

if rope_dim_list is None:
    rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]

The published config sets rope_dim_list = [16, 56, 56], and the __init__ default is also [16, 56, 56]. There is no config where rope_dim_list would be None. This fallback is defensive but unreachable.

3. _keep_in_fp32_modules references non-existent modulestransformer_joyimage.py:381-387

_keep_in_fp32_modules = [
    "time_embedder",       # ✓ exists (inside condition_embedder.time_embedder)
    "scale_shift_table",   # ✗ does NOT exist anywhere in the model
    "norm1",               # ✓ exists (img_norm1, txt_norm1 in blocks)
    "norm2",               # ✓ exists (img_norm2, txt_norm2 in blocks)
    "norm3",               # ✗ does NOT exist anywhere in the model
]

scale_shift_table and norm3 are WAN model concepts that were copied but don't exist in JoyImageEditTransformer3DModel. These entries silently match nothing — harmless but misleading.

4. num_items = 0 initialization — semantically misleadingtransformer_joyimage.py:532

num_items = 0
if is_multi_item:
    num_items = hidden_states.shape[1]

When is_multi_item is False, num_items remains 0. The un-multi-item block at line 581 (if is_multi_item) won't execute in that case, so num_items = 0 is never used meaningfully. However, if is_multi_item were accidentally set while num_items = 0, it would cause a silent reshape bug. Consider leaving num_items unset until needed.

5. Pipeline: template_type="image" passed to encode_prompt is overridden when images are presentpipeline_joyimage_edit.py:757

prompt_embeds, prompt_embeds_mask = self.encode_prompt(
    ...
    images=processed_image,
    ...
    template_type="image",   # <-- this value is ignored when images is not None
)

When images is not None, encode_prompt immediately delegates to encode_prompt_multiple_images (line 386-392), which hard-codes template_type="multiple_images". The template_type="image" argument passed from __call__ is silently discarded in the editing path.

6. _skip_layerwise_casting_patterns includes "condition_embedder" indirectly via "norm"transformer_joyimage.py:378

_skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"]

The "norm" pattern matches norm_out, img_norm1, img_norm2, txt_norm1, txt_norm2, and also the RMSNorm QK norms inside attention (img_attn_q_norm, etc.). This is a very broad match. It also matches condition_embedder which contains the time embedder — this is likely intentional but means most of the model is excluded from layerwise casting.

7. vec.shape[-1] > self.hidden_size check in forwardtransformer_joyimage.py:553-554

if vec.shape[-1] > self.hidden_size:
    vec = vec.unflatten(1, (6, -1))

JoyImageTimeTextImageEmbedding.forward returns temb from time_proj which has output dim hidden_size * 6 (set at init: time_proj_dim=hidden_size * 6). So vec.shape[-1] is always hidden_size * 6, which is always > self.hidden_size. This branch is always taken — the else path (where vec stays 1D) is dead code. The check could be removed and replaced with an unconditional unflatten.
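The simplification suggested in finding 7 can be demonstrated in isolation. A small sketch (hidden_size and the chunk names are illustrative): since the time projection always emits hidden_size * 6 features, the guard can be replaced with an unconditional unflatten into the six modulation chunks.

```python
import torch

hidden_size = 8
vec = torch.randn(2, hidden_size * 6)  # stand-in for the time_proj output

# unconditional unflatten replaces the always-true shape check
shift1, scale1, gate1, shift2, scale2, gate2 = vec.unflatten(1, (6, -1)).unbind(1)
```

Removing the dead else path also makes the 6-way modulation layout explicit to readers of the forward pass.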


Methods and layers that ARE used (confirming no false positives)

Component Called from Status
img_in (Conv3d) forward:549 ✓ Used
condition_embedder forward:552 ✓ Used
double_blocks forward:565-574 ✓ Used
norm_out + proj_out forward:577 ✓ Used
JoyImageModulate (img_mod, txt_mod) JoyImageTransformerBlock.forward:295,302 ✓ Used
JoyImageAttention (attn) JoyImageTransformerBlock.forward:311 ✓ Used
FeedForward (img_mlp, txt_mlp) JoyImageTransformerBlock.forward:325-326 ✓ Used
get_rotary_pos_embed forward:559 ✓ Used (vis path)
unpatchify forward:578 ✓ Used
_apply_rotary_emb JoyImageAttnProcessor:142-144 ✓ Used (vis path)

All initialized layers and helper methods are reachable from the pipeline call path. No dead layers or methods found.


Summary

The model is clean — all initialized layers and methods are used. The dead code is limited to the unreachable branches and stale entries listed in findings 1–7 above.

None of these are bugs — they're defensive/copied code that happens to be unreachable under the published config. The authors may have configs or future models that exercise these paths.


@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels May 2, 2026
… support prompt_embeds/prompt_embeds_mask/num_images_per_prompt in edit mode
@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels May 2, 2026
@feice-huang
Contributor

feice-huang commented May 2, 2026

@dg845 @yiyixuxu
Thanks for the suggestions! Updates:

Refactor

  • Separate encode_prompt_multiple_images from encode_prompt. Dispatching is now handled in __call__ based on whether processed_image is provided, instead of branching inside encode_prompt.
  • Add prompt_embeds, prompt_embeds_mask, and num_images_per_prompt support to encode_prompt_multiple_images so it matches the encode_prompt interface.

Test updates

  • Skip test_group_offloading_inference, test_pipeline_level_group_offloading_inference, test_sequential_cpu_offload_forward_pass, and test_sequential_offload_forward_pass_twice.

Not changed

  • We keep the current time_embedder dtype inference pattern (next(iter(...)).dtype). We tested simplifying it, but that causes dtype mismatch errors.
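For context, a minimal sketch of this dtype-inference pattern on toy modules (not the actual JoyImage time embedder): the timestep is cast to the embedder's parameter dtype, with a guard so quantized int8 parameters are not used as a cast target.

```python
import torch
import torch.nn as nn

# Toy stand-in for the time embedder, held in bfloat16.
time_embedder = nn.Linear(256, 1024).to(torch.bfloat16)
timestep = torch.rand(2, 256)  # float32 by default

# Infer the cast target from the embedder's own parameters, not from the
# text embeddings; skip the cast when parameters are int8 (quantized).
time_embedder_dtype = next(iter(time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)

temb = time_embedder(timestep)
print(temb.dtype)  # torch.bfloat16
```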

@feice-huang
Contributor

@dg845 Fixed these in e8c4db7; hoping this resolves the CI failures.

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks, looking good to me
i left one question and one piece of feedback; we will merge soon

Comment on lines +356 to +359
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
Collaborator


thanks @feice-huang
can you share your testing script? we will merge now but want to see if we can refactor this pattern out separately (not just in JoyAI but Wan & others too)

@feice-huang
Contributor

feice-huang commented May 7, 2026

@yiyixuxu
We directly used the testing script from #13444 (comment). If you apply the changes below, you will hit the error mentioned in #13444 (comment).

Let us know if you need more details!

`transformer_joyimage.py` to cause the error (directly applying your suggested change; I haven't found a better way yet to avoid it)
    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ):
        timestep = self.timesteps_proj(timestep)

        if timestep.dtype != encoder_hidden_states.dtype:
            timestep = timestep.to(encoder_hidden_states.dtype)
        temb = self.time_embedder(timestep)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)

        return temb, timestep_proj, encoder_hidden_states

@yiyixuxu yiyixuxu merged commit 1030249 into huggingface:main May 7, 2026
13 of 15 checks passed
@yiyixuxu
Collaborator

yiyixuxu commented May 7, 2026

thanks a lot for the PR!

@feice-huang
Contributor

thank you so much for your suggestions!

@sayakpaul
Member

Hey folks,

@Moran232 @feice-huang docs are missing for JoyImage. Could you please add docs in a separate PR?
