batched and grouped experts implementations #42697

Merged — IlyasMoutawwakil merged 74 commits into main from moe-imp on Jan 5, 2026
Conversation

@IlyasMoutawwakil (Member) commented Dec 8, 2025

What does this PR do?

I have started experimenting with pure PyTorch MoE implementations, following the HF exporters PR, while trying to find a traceable/exportable variant for ONNX/OpenVINO.

In this PR I mirror the `attn_implementation` API with a similar `experts_implementation` API, and add two new implementations:

  • `batched_mm` (the exportable one), which uses `torch.bmm` and is fastest for single-batch / small inputs.
  • `grouped_mm` (the PyTorch custom-kernel one), inspired by torchtitan's MoE implementation (using `torch._grouped_mm`), which is generally fastest.
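To illustrate the difference between the two strategies, here is a minimal sketch (toy shapes and function names are my own; this is not the actual transformers code): the eager path loops over experts with data-dependent masking, while the batched path runs every expert on every token with `torch.bmm` and masks afterwards, keeping all shapes static.

```python
import torch

def eager_experts(x, w1, w2, top_idx):
    # x: (tokens, d), w1: (E, d, h), w2: (E, h, d), top_idx: (tokens,) chosen expert per token
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        mask = top_idx == e          # data-dependent masking: hard to trace/export
        if mask.any():
            h = torch.relu(x[mask] @ w1[e])
            out[mask] = h @ w2[e]
    return out

def batched_mm_experts(x, w1, w2, top_idx):
    E = w1.shape[0]
    xb = x.unsqueeze(0).expand(E, -1, -1)   # (E, tokens, d): all experts see all tokens
    h = torch.relu(torch.bmm(xb, w1))       # static shapes -> exportable, but O(E) extra compute
    y = torch.bmm(h, w2)                    # (E, tokens, d)
    onehot = torch.nn.functional.one_hot(top_idx, E).T.to(x.dtype)  # (E, tokens)
    return (y * onehot.unsqueeze(-1)).sum(0)
```

The static-shape batched variant pays for exportability with redundant compute, which is consistent with the benchmark below: it wins on tiny inputs but loses badly at larger batch/sequence sizes.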

Benchmark

An initial benchmark on an A100 shows promising results. Note that `torch._grouped_mm` appears to use bfloat16 (or something similar) under the hood, so these numbers might not be apples to apples (I'm still looking for more references on this function and how to use it "equivalently").

MoE Implementations Benchmark

Benchmark script: bench.py

It uses qwen2_moe (`Qwen/Qwen1.5-MoE-A2.7B`, bfloat16); latency and memory are measured for the forward pass / prefill.

TL;DR: for very small inputs `batched_mm` can be extremely fast, and even faster with compilation; for bigger inputs `grouped_mm` is unbeatable, but it doesn't seem to get much faster with torch compilation.

| Batch Size | Seq Length | Torch Compile | Implementation | Mean Latency (ms) | Median Latency (ms) | P90 Latency (ms) | Peak Mem (MB) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 16 | False | eager | 271.80 | 272.94 | 295.34 | 27324.65 |
| 1 | 16 | True | eager | 351.86 | 351.64 | 384.64 | 27329.29 |
| 1 | 16 | max-autotune-no-cudagraphs | eager | 352.52 | 352.15 | 382.79 | 27329.29 |
| 1 | 16 | False | batched_mm | 52.03 | 52.07 | 52.67 | 28382.50 |
| 1 | 16 | True | batched_mm | 53.04 | 53.04 | 53.11 | 28029.63 |
| 1 | 16 | max-autotune-no-cudagraphs | batched_mm | 23.87 | 23.86 | 24.02 | 27329.29 |
| 1 | 16 | False | grouped_mm | 64.27 | 64.09 | 65.49 | 27329.29 |
| 1 | 16 | True | grouped_mm | 59.45 | 59.52 | 60.99 | 27329.29 |
| 1 | 16 | max-autotune-no-cudagraphs | grouped_mm | 59.61 | 59.55 | 60.89 | 27329.29 |
| 1 | 128 | False | eager | 471.73 | 472.65 | 487.97 | 27396.46 |
| 1 | 128 | True | eager | 637.32 | 613.70 | 845.01 | 27429.82 |
| 1 | 128 | max-autotune-no-cudagraphs | eager | 620.21 | 619.35 | 657.74 | 27429.82 |
| 1 | 128 | False | batched_mm | 316.67 | 316.94 | 317.92 | 35854.56 |
| 1 | 128 | True | batched_mm | 370.29 | 370.29 | 370.57 | 33031.64 |
| 1 | 128 | max-autotune-no-cudagraphs | batched_mm | 151.87 | 150.38 | 158.01 | 27429.82 |
| 1 | 128 | False | grouped_mm | 78.50 | 78.53 | 80.00 | 27429.82 |
| 1 | 128 | True | grouped_mm | 72.95 | 72.99 | 74.60 | 27429.82 |
| 1 | 128 | max-autotune-no-cudagraphs | grouped_mm | 72.71 | 72.89 | 73.55 | 27429.82 |
| 4 | 16 | False | eager | 431.87 | 433.38 | 448.01 | 27391.57 |
| 4 | 16 | True | eager | 566.63 | 569.74 | 598.98 | 27372.12 |
| 4 | 16 | max-autotune-no-cudagraphs | eager | 563.13 | 567.79 | 588.25 | 27372.12 |
| 4 | 16 | False | batched_mm | 163.41 | 163.38 | 164.84 | 31585.54 |
| 4 | 16 | True | batched_mm | 189.18 | 189.08 | 189.79 | 30173.45 |
| 4 | 16 | max-autotune-no-cudagraphs | batched_mm | 79.15 | 79.10 | 79.74 | 27372.11 |
| 4 | 16 | False | grouped_mm | 75.23 | 75.18 | 76.74 | 27372.11 |
| 4 | 16 | True | grouped_mm | 70.35 | 70.40 | 71.71 | 27372.12 |
| 4 | 16 | max-autotune-no-cudagraphs | grouped_mm | 70.26 | 70.43 | 71.32 | 27372.12 |
| 4 | 128 | False | eager | 526.88 | 522.75 | 570.01 | 27632.62 |
| 4 | 128 | True | eager | 678.18 | 677.54 | 690.97 | 27762.46 |
| 4 | 128 | max-autotune-no-cudagraphs | eager | 676.22 | 677.07 | 681.91 | 27762.45 |
| 4 | 128 | False | batched_mm | 1235.25 | 1235.33 | 1237.90 | 61465.85 |
| 4 | 128 | True | batched_mm | 1505.00 | 1503.31 | 1536.10 | 50174.26 |
| 4 | 128 | max-autotune-no-cudagraphs | batched_mm | 572.37 | 570.81 | 589.74 | 27762.45 |
| 4 | 128 | False | grouped_mm | 80.95 | 81.06 | 81.70 | 27762.45 |
| 4 | 128 | True | grouped_mm | 79.67 | 79.69 | 80.54 | 27762.45 |
| 4 | 128 | max-autotune-no-cudagraphs | grouped_mm | 83.29 | 79.83 | 111.83 | 27762.46 |
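The `grouped_mm` numbers above rely on arranging each expert's tokens contiguously before a single grouped matmul. A minimal sketch of that routing step (function name and the exact offsets convention are my own assumptions, not the PR's code; it also uses `histc` instead of `bincount`, as the commit log mentions, converting to float since CPU `histc` doesn't support long):

```python
import torch

def group_tokens_by_expert(top_idx, num_experts):
    # Stable-sort token indices so each expert's tokens form one contiguous group,
    # and compute cumulative group-end offsets of the kind a grouped matmul
    # (e.g. torch._grouped_mm) consumes.
    order = torch.argsort(top_idx, stable=True)
    counts = torch.histc(top_idx.float(), bins=num_experts, min=0, max=num_experts - 1)
    offsets = counts.long().cumsum(0)  # offsets[e] = end of expert e's token group
    return order, offsets
```

With tokens gathered via `order`, each expert's slice `[offsets[e-1]:offsets[e]]` can be fed to its weight matrix in one fused call, which avoids both the Python loop of `eager` and the O(num_experts) redundant compute of `batched_mm`.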

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil IlyasMoutawwakil changed the title BMM MoE implementation batched and grouped MoE implementations Dec 8, 2025

# --- Down projection per expert (grouped_mm) ---
mat_a_down = hidden_after_activation
mat_b_down = down_proj.transpose(-2, -1)
Collaborator:

Same here, the way we want v5 is to have "perfect" weights with the weight converter -> this can be done in the weight converter

Member Author:

I have a commit locally with the transposition removed from both eager (using matmul instead of linear) and grouped_mm. I reran the same benchmark as above, but can't say for sure whether it's faster.

Experts Implementations Benchmark Results

| Batch Size | Seq Length | Torch Compile | Implementation | Mean Latency (ms) | Median Latency (ms) | P90 Latency (ms) | Peak Mem (MB) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 16 | False | eager | 264.84 | 264.32 | 288.75 | 27324.65 |
| 1 | 16 | True | eager | 350.00 | 349.41 | 376.13 | 27329.29 |
| 1 | 16 | max-autotune-no-cudagraphs | eager | 341.95 | 342.61 | 374.87 | 27329.29 |
| 1 | 16 | False | batched_mm | 51.97 | 51.99 | 52.64 | 28382.50 |
| 1 | 16 | True | batched_mm | 53.10 | 53.07 | 53.39 | 28029.63 |
| 1 | 16 | max-autotune-no-cudagraphs | batched_mm | 23.48 | 23.51 | 23.54 | 27329.29 |
| 1 | 16 | False | grouped_mm | 63.18 | 63.08 | 64.82 | 27329.29 |
| 1 | 16 | True | grouped_mm | 59.43 | 59.41 | 60.98 | 27329.29 |
| 1 | 16 | max-autotune-no-cudagraphs | grouped_mm | 60.68 | 60.76 | 62.36 | 27329.29 |
| 1 | 128 | False | eager | 490.27 | 488.27 | 504.51 | 27396.46 |
| 1 | 128 | True | eager | 671.43 | 637.76 | 1008.31 | 27429.82 |
| 1 | 128 | max-autotune-no-cudagraphs | eager | 618.78 | 620.04 | 639.69 | 27429.82 |
| 1 | 128 | False | batched_mm | 316.88 | 317.25 | 317.86 | 35854.56 |
| 1 | 128 | True | batched_mm | 370.47 | 370.36 | 371.28 | 33031.64 |
| 1 | 128 | max-autotune-no-cudagraphs | batched_mm | 152.58 | 150.86 | 159.43 | 27429.82 |
| 1 | 128 | False | grouped_mm | 77.71 | 77.88 | 78.76 | 27429.82 |
| 1 | 128 | True | grouped_mm | 72.99 | 73.06 | 74.21 | 27429.82 |
| 1 | 128 | max-autotune-no-cudagraphs | grouped_mm | 72.67 | 72.94 | 73.60 | 27429.82 |
| 4 | 16 | False | eager | 433.87 | 431.03 | 455.82 | 27391.57 |
| 4 | 16 | True | eager | 569.67 | 571.82 | 586.81 | 27372.12 |
| 4 | 16 | max-autotune-no-cudagraphs | eager | 553.63 | 557.50 | 579.73 | 27372.12 |
| 4 | 16 | False | batched_mm | 164.00 | 164.10 | 164.77 | 31585.54 |
| 4 | 16 | True | batched_mm | 189.24 | 189.19 | 189.62 | 30173.45 |
| 4 | 16 | max-autotune-no-cudagraphs | batched_mm | 79.59 | 79.45 | 80.01 | 27372.11 |
| 4 | 16 | False | grouped_mm | 75.40 | 75.04 | 78.37 | 27372.11 |
| 4 | 16 | True | grouped_mm | 69.93 | 70.10 | 71.27 | 27372.12 |
| 4 | 16 | max-autotune-no-cudagraphs | grouped_mm | 69.79 | 69.97 | 71.58 | 27372.12 |
| 4 | 128 | False | eager | 524.48 | 520.53 | 561.31 | 27632.62 |
| 4 | 128 | True | eager | 702.86 | 702.79 | 716.28 | 27762.46 |
| 4 | 128 | max-autotune-no-cudagraphs | eager | 687.21 | 682.97 | 716.89 | 27762.45 |
| 4 | 128 | False | batched_mm | 1236.74 | 1236.75 | 1239.48 | 61465.86 |
| 4 | 128 | True | batched_mm | 1469.06 | 1468.82 | 1470.12 | 50174.26 |
| 4 | 128 | max-autotune-no-cudagraphs | batched_mm | 570.72 | 570.03 | 576.75 | 27762.45 |
| 4 | 128 | False | grouped_mm | 81.61 | 81.54 | 82.90 | 27762.45 |
| 4 | 128 | True | grouped_mm | 79.41 | 79.44 | 79.85 | 27762.46 |
| 4 | 128 | max-autotune-no-cudagraphs | grouped_mm | 79.44 | 79.50 | 80.12 | 27762.45 |

Collaborator:

Missing licence.

If we go down that road, which I like TBH, we should also add kernels support + try at least FP8 to see how this would work.

The other solution is to use the use_hf_hub_kernel decorator as well, but that looks more cumbersome. So following FA2 we want to support kernels from the hub here as well + quantization.

This does look nice; for the bench, can you add compile cases as well please?

Also we need to make sure this works with TP / EP.

@IlyasMoutawwakil IlyasMoutawwakil changed the title batched and grouped MoE implementations batched and grouped experts implementations Dec 15, 2025
Comment on lines +81 to +93
All three backends (`"eager"`, `"batched_mm"`, `"grouped_mm"`) are compatible with `torch.compile` to varying extents. The following table summarizes compatibility:

| Implementation | compilation modes | dtypes | `fullgraph=True` |
| -------------- | ------------------------------------ | -------------------------------- | ---------------- |
| `grouped_mm` | `None`, `max-autotune-no-cudagraphs` | `bfloat16` | Yes |
| `batched_mm` | all | `bfloat16`, `float16`, `float32` | Yes |
| `eager` | all | `bfloat16`, `float16`, `float32` | No |

Notes:

- The `grouped_mm` experts backend currently only supports `bfloat16` when compiled with `torch.compile`. Additionally, it is not compatible with CUDA graphs, so you must use `mode=None` or `mode="max-autotune-no-cudagraphs"` when compiling.
- The `eager` experts backend uses a data-dependent operation to find which experts are used in a forward pass. This operation is not compatible with full graph compilation (`fullgraph=True`).
- When using `float16` or `float32` with `grouped_mm`, the model will automatically fall back to `batched_mm` when compiled.

@github-actions

github-actions bot commented Jan 1, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, bamba, dbrx, deepseek_v2, deepseek_v3, dots1, ernie4_5_moe, ernie4_5_vl_moe, falcon_mamba, flex_olmo, glm4_moe, glm4v_moe, gpt_oss, granitemoe, granitemoehybrid

@github-actions

github-actions bot commented Jan 1, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42697&sha=2ebaba

Comment on lines 270 to +287
@@ -281,16 +281,18 @@ def lazy_initialization(self, key_states: torch.Tensor):
         i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
         not be compiled anyway for performances!
         """
-        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
         self.dtype, self.device = key_states.dtype, key_states.device
+        self.max_batch_size, self.num_heads = key_states.shape[:2]
+        self.v_head_dim = value_states.shape[-1]
+        self.k_head_dim = key_states.shape[-1]
Member Author:
Had to add `value_states` to the signature since some models like DeepSeek have different k/v `head_dim`.

)

# Finally: if we can compile, disable tokenizers parallelism and check for FA2 + static cache
if can_compile:
Member Author:

Seemed to me like this `if` was missing from the logic.

Comment on lines +2195 to +2202
# If we use grouped_mm with a dtype other than bfloat16, we fall back to batched_mm
if self.config._experts_implementation == "grouped_mm":
if self.dtype != torch.bfloat16:
logger.warning_once(
"torch._grouped_mm currently only supports bfloat16 when being compiled with torch.compile. "
"Falling back to batched_mm implementation for compilation."
)
self.set_experts_implementation("batched_mm")
@IlyasMoutawwakil (Member Author) commented Jan 5, 2026:

Preferred falling back to batched_mm here only when we are optimizing the model for compiled generation; falling back during the forward pass seemed like very implicit behavior that might go against user intention.

@IlyasMoutawwakil IlyasMoutawwakil merged commit 0642963 into main Jan 5, 2026
26 checks passed
@IlyasMoutawwakil IlyasMoutawwakil deleted the moe-imp branch January 5, 2026 09:53
@ArthurZucker (Collaborator):
Kudos!

sniper35 pushed a commit to sniper35/transformers that referenced this pull request Jan 5, 2026
* moe implementation

* support more MoEs

* tests

* add comments

* add grouped_mm support

* typing act_fn and adding stride 16 note

* style

* fix dbrx config

* fix config test

* add licence and better stride conditions

* comment

* no need to pad tensors to 16 byte strides if we made sure our tiny testing models have 16 byte aligned weights

* use a class decorator with a registration interface

* remove line

* remove unnecessary

* register config with the decorator

* fix redundant

* reduce changes some more

* fix

* fix

* import from integrations

* remove empty lines

* use histc instead of bincount

* fix cpu histc not supporting long

* docs

* added benchmark to docs

* add to from_pretrained's docstring

* make grouped_mm the default when possible

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/experts_interface.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* make qwen3 vl moe inherit its experts and sparse moe blocks from qwen3 moe, making it use experts implementation

* create _supports_grouped_mm flag and use it for testing

* fix copies

* better grouped mm checks

* fix model size failure

* better docs

* get rid of class property _supports_grouped_mm

* add method calling checks and fix models that didn't have experts

* fix copies

* fix

* fix

* more cleanup

* clean

* document compilation behaviour

* docs

* fix new moe after merge

* fix the new ernie 4.5 vl moe testing

* support fullgraph automatic compilation for MoEs

* fix lazy initialization

* disable fullgraph for granitemoe and jetmoe because of topk gating

* avoid implicit fallback in experts implementation and only do it when auto-compiling

* style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@vasqu vasqu mentioned this pull request Jan 14, 2026
5 tasks
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
(same commit list as above)
@IlyasMoutawwakil IlyasMoutawwakil mentioned this pull request Mar 5, 2026
5 tasks


4 participants