NucleusMoE-Image #13317

Merged
dg845 merged 46 commits into huggingface:main from sippycoder:main
Apr 3, 2026

Conversation

@sippycoder
Contributor

What does this PR do?

This PR introduces the NucleusMoE-Image series into the diffusers library.

NucleusMoE-Image is a 17B-parameter model with 2B active parameters per token, trained with efficiency at its core. The architecture demonstrates the scalability of sparse mixture-of-experts (MoE) designs for image generation. The technical report will be released very soon.
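
For reference, here is a minimal sketch of how the pipeline could be loaded once this PR is in, assuming the standard diffusers auto-pipeline flow; the NucleusAI/NucleusMoE-Image repo id appears later in this thread, while the prompt, dtype, and device are illustrative only:

import torch
from diffusers import DiffusionPipeline

# DiffusionPipeline resolves the concrete pipeline class from the repo's model_index.json.
pipe = DiffusionPipeline.from_pretrained("NucleusAI/NucleusMoE-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Illustrative prompt; a sparse-MoE transformer activates only ~2B of the 17B parameters per token.
image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")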

@sippycoder
Contributor Author

cc: @sayakpaul @IlyasMoutawwakil

sayakpaul requested review from dg845 and yiyixuxu on March 24, 2026
Several resolved comment threads on src/diffusers/models/transformers/transformer_nucleusmoe_image.py (outdated)
Comment on lines +545 to +546
gate1 = gate1.clamp(min=-2.0, max=2.0)
gate2 = gate2.clamp(min=-2.0, max=2.0)
Collaborator

It seems weird to me that we first clamp the gates to [-2.0, 2.0] and then essentially clamp again by squashing with the tanh function below. Is this intended?

Contributor Author

I agree it's weird. :) I used it to stabilize the gradients when the tanh gates saturate during training. I will evaluate the model performance without it and get back to you!
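
For illustration, a minimal sketch of the clamp-then-tanh gating under discussion; the tensor shape and variable name are hypothetical, not the PR's actual code:

import torch

# Hypothetical gate pre-activations; in the model these come from a learned projection.
gate1 = torch.randn(2, 16, requires_grad=True)

# Clamp to [-2, 2] before the tanh squash. Inside the clamp range the composed
# gradient is at least 1 - tanh(2)^2 ~= 0.07; outside it, clamp zeroes the
# gradient entirely. This is the double squashing the review comment asks about.
gate1_act = torch.tanh(gate1.clamp(min=-2.0, max=2.0))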

Several resolved comment threads on src/diffusers/models/transformers/transformer_nucleusmoe_image.py (outdated)
Several resolved comment threads on src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py (outdated)
sippycoder and others added 4 commits March 31, 2026 00:37
@dg845
Collaborator

dg845 commented Mar 31, 2026

Hi @sippycoder, it doesn't look like my HF account (also dg845) has access to NucleusAI/NucleusMoE-Image, not sure if I am missing something. I get a 404 error if I try to access it.

@sippycoder
Contributor Author

> Hi @sippycoder, it doesn't look like my HF account (also dg845) has access to NucleusAI/NucleusMoE-Image, not sure if I am missing something. I get a 404 error if I try to access it.

Looks like I can't give you private repo access unless you are in my org, so I just made the repo public! I haven't updated the model page yet.

@dg845
Collaborator

dg845 commented Mar 31, 2026

@bot /style

@github-actions
Contributor

github-actions Bot commented Mar 31, 2026

Style bot fixed some files and pushed the changes.

logger = logging.get_logger(__name__)


# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus
Collaborator

Can you run make fix-copies to sync the implementation here with the QwenImage implementation (assuming the implementations are intended to be the same, which I believe is the case)?

Contributor Author

Done!
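
For readers unfamiliar with the mechanism: the # Copied from marker tells utils/check_copies.py to keep the function below it in sync with the referenced implementation, applying the listed rename. A schematic sketch, with an illustrative signature rather than the PR's actual code:

# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus
def apply_rotary_emb_nucleus(x, freqs_cis):
    # `make fix-copies` (which runs utils/check_copies.py --fix_and_overwrite)
    # mirrors the body from the QwenImage implementation, renaming every
    # occurrence of "qwen" to "nucleus".
    ...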

sippycoder and others added 2 commits March 31, 2026 18:30
@sippycoder
Contributor Author

@yiyixuxu Any comments from you on the text_kv_cache hook?
@dg845 Do you think we are good to merge this PR?

@dg845
Collaborator

dg845 commented Apr 2, 2026

@bot /style

@github-actions
Contributor

github-actions Bot commented Apr 2, 2026

Style bot fixed some files and pushed the changes.

@dg845
Collaborator

dg845 commented Apr 2, 2026

Hi @sippycoder, I think this PR is close to merge, the remaining items should be:

  1. Confirm that the text KV cache design looks good
  2. Ensure that the MoE weights are correctly supported (e.g. NucleusMoE-Image #13317 (comment)). I think with the new changes they are probably good, CC @IlyasMoutawwakil to confirm.

Additionally, having docs would be nice, but this is not a hard blocker (we can add them in a follow-up PR if necessary).

@sippycoder
Contributor Author

Hey @dg845 @yiyixuxu, we are releasing the model report tomorrow! Thanks for all the reviews! I think Expert Parallelism can be a separate PR. Any chance we can merge this PR today?

Comment thread src/diffusers/hooks/text_kv_cache.py Outdated
def __init__(self, state_manager: StateManager):
    super().__init__()
    self.state_manager = state_manager
    self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
Collaborator

After looking at the existing cache design more closely, I think kv_cache should not be owned by TextKVCacheBlockHook but rather refactored into its own BaseState subclass, which is what MagCache and TaylorSeer do (both store cached tensors in their state classes rather than on the hook).

This would mean having two state classes: one which holds the shared encoder_hidden_states tensor and one which holds the KV cache dict for each block:

# Same as before
class TextKVCacheState(BaseState):
    def __init__(self):
        self.key: int | None = None
    ...

# Holds the block-level KV cache
class TextKVCacheBlockState(BaseState):
    def __init__(self):
        self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
    ...

# Same as before
class TextKVCacheTransformerHook(ModelHook):
    ...

class TextKVCacheBlockHook(ModelHook):
    # One state manager for shared transformer-level state, one for block-specific state
    def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
        super().__init__()
        self.state_manager = state_manager
        self.block_state_manager = block_state_manager
    ...

This would allow us to manage each block-level KV Cache with a StateManager, which I think more cleanly follows the current design.
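
To make the data flow concrete, a minimal sketch of how the block hook might consult both managers, assuming StateManager exposes a get_state() accessor as the existing diffusers cache hooks do; the control flow and names inside the body are illustrative:

class TextKVCacheBlockHook(ModelHook):
    def new_forward(self, module, *args, **kwargs):
        # Transformer-level state: holds the cache key for the current prompt.
        shared_state = self.state_manager.get_state()
        # Block-level state: owns this block's KV cache dict.
        block_state = self.block_state_manager.get_state()
        if shared_state.key in block_state.kv_cache:
            # Reuse the cached text key/value tensors for this block.
            key, value = block_state.kv_cache[shared_state.key]
        else:
            key, value = ...  # compute from encoder_hidden_states on first use
            block_state.kv_cache[shared_state.key] = (key, value)
        ...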

Contributor Author

Yeah, that makes sense! I added a commit for this.

Collaborator

@dg845 left a comment

Thanks again for the PR! I think the items in #13317 (comment) should be resolved. We can add docs and handle any remaining issues in follow-up PRs.

@dg845
Collaborator

dg845 commented Apr 3, 2026

Merging as the CI is green.

dg845 merged commit 447e571 into huggingface:main Apr 3, 2026
11 checks passed
@q5sys

q5sys commented Apr 13, 2026

> Hey @dg845 @yiyixuxu, we are releasing the model report tomorrow! Thanks for all the reviews! I think Expert Parallelism can be a separate PR. Any chance we can merge this PR today?

@sippycoder When is this actually getting released? You said 'tomorrow' two weeks ago to get this PR merged, but your model still isn't released. I'm eager to try this.

Edit: it was posted the day after my comment

terarachang pushed a commit to terarachang/diffusers that referenced this pull request Apr 30, 2026
* adding NucleusMoE-Image model

* update system prompt

* Add text kv caching

* Class/function name changes

* add missing imports

* add RoPE credits

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* update defaults

* Update src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* review updates

* fix the tests

* clean up

* update apply_text_kv_cache

* SwiGLUExperts addition

* fuse SwiGLUExperts up and gate proj

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* _SharedCacheKey -> TextKVCacheState

* Apply style fixes

* Run python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite

* Apply style fixes

* run `make fix-copies`

* fix import

* refactor text KV cache to be managed by StateManager

---------

Co-authored-by: Murali Nandan Nagarapu <nmn@withnucleus.ai>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>