batched and grouped experts implementations#42697
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
src/transformers/integrations/moe.py
Outdated

```python
# --- Down projection per expert (grouped_mm) ---
mat_a_down = hidden_after_activation
mat_b_down = down_proj.transpose(-2, -1)
```
same here, the way we want v5 is to have "perfect" weights with the weight converter -> this can be done in the weight converter
I have a commit locally with the transposition removed from both eager (using matmul instead of linear) and grouped_mm. I reran the same benchmark above but can't say for sure if it's faster
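A toy pure-Python sketch of the equivalence being relied on here (hypothetical tiny matrices, not the actual transformers code): pre-transposing the weight once, e.g. in the weight converter at load time, gives the same projection as transposing on every forward call.

```python
# Toy illustration: a linear-layout weight (out_features, in_features) can be
# transposed once at load time instead of at every forward pass.

def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

hidden = [[1.0, 2.0], [3.0, 4.0]]                # (tokens, in_features)
weight = [[0.5, -1.0], [2.0, 0.0], [1.0, 1.0]]   # (out_features, in_features)

# Runtime transpose (what the current grouped_mm path does):
out_runtime = matmul(hidden, transpose(weight))

# Load-time transpose (weight stored pre-transposed by the converter):
weight_pre = transpose(weight)
out_pretransposed = matmul(hidden, weight_pre)

assert out_runtime == out_pretransposed
```

Whether this actually speeds things up is exactly what the benchmark below re-checks; the math is identical either way.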
Experts Implementations Benchmark Results
| Batch Size | Seq Length | Torch Compile | Implementation | Mean Latency (ms) | Median Latency (ms) | P90 Latency (ms) | Peak Mem (MB) |
|---|---|---|---|---|---|---|---|
| 1 | 16 | False | eager | 264.84 | 264.32 | 288.75 | 27324.65 |
| 1 | 16 | True | eager | 350.00 | 349.41 | 376.13 | 27329.29 |
| 1 | 16 | max-autotune-no-cudagraphs | eager | 341.95 | 342.61 | 374.87 | 27329.29 |
| 1 | 16 | False | batched_mm | 51.97 | 51.99 | 52.64 | 28382.50 |
| 1 | 16 | True | batched_mm | 53.10 | 53.07 | 53.39 | 28029.63 |
| 1 | 16 | max-autotune-no-cudagraphs | batched_mm | 23.48 | 23.51 | 23.54 | 27329.29 |
| 1 | 16 | False | grouped_mm | 63.18 | 63.08 | 64.82 | 27329.29 |
| 1 | 16 | True | grouped_mm | 59.43 | 59.41 | 60.98 | 27329.29 |
| 1 | 16 | max-autotune-no-cudagraphs | grouped_mm | 60.68 | 60.76 | 62.36 | 27329.29 |
| 1 | 128 | False | eager | 490.27 | 488.27 | 504.51 | 27396.46 |
| 1 | 128 | True | eager | 671.43 | 637.76 | 1008.31 | 27429.82 |
| 1 | 128 | max-autotune-no-cudagraphs | eager | 618.78 | 620.04 | 639.69 | 27429.82 |
| 1 | 128 | False | batched_mm | 316.88 | 317.25 | 317.86 | 35854.56 |
| 1 | 128 | True | batched_mm | 370.47 | 370.36 | 371.28 | 33031.64 |
| 1 | 128 | max-autotune-no-cudagraphs | batched_mm | 152.58 | 150.86 | 159.43 | 27429.82 |
| 1 | 128 | False | grouped_mm | 77.71 | 77.88 | 78.76 | 27429.82 |
| 1 | 128 | True | grouped_mm | 72.99 | 73.06 | 74.21 | 27429.82 |
| 1 | 128 | max-autotune-no-cudagraphs | grouped_mm | 72.67 | 72.94 | 73.60 | 27429.82 |
| 4 | 16 | False | eager | 433.87 | 431.03 | 455.82 | 27391.57 |
| 4 | 16 | True | eager | 569.67 | 571.82 | 586.81 | 27372.12 |
| 4 | 16 | max-autotune-no-cudagraphs | eager | 553.63 | 557.50 | 579.73 | 27372.12 |
| 4 | 16 | False | batched_mm | 164.00 | 164.10 | 164.77 | 31585.54 |
| 4 | 16 | True | batched_mm | 189.24 | 189.19 | 189.62 | 30173.45 |
| 4 | 16 | max-autotune-no-cudagraphs | batched_mm | 79.59 | 79.45 | 80.01 | 27372.11 |
| 4 | 16 | False | grouped_mm | 75.40 | 75.04 | 78.37 | 27372.11 |
| 4 | 16 | True | grouped_mm | 69.93 | 70.10 | 71.27 | 27372.12 |
| 4 | 16 | max-autotune-no-cudagraphs | grouped_mm | 69.79 | 69.97 | 71.58 | 27372.12 |
| 4 | 128 | False | eager | 524.48 | 520.53 | 561.31 | 27632.62 |
| 4 | 128 | True | eager | 702.86 | 702.79 | 716.28 | 27762.46 |
| 4 | 128 | max-autotune-no-cudagraphs | eager | 687.21 | 682.97 | 716.89 | 27762.45 |
| 4 | 128 | False | batched_mm | 1236.74 | 1236.75 | 1239.48 | 61465.86 |
| 4 | 128 | True | batched_mm | 1469.06 | 1468.82 | 1470.12 | 50174.26 |
| 4 | 128 | max-autotune-no-cudagraphs | batched_mm | 570.72 | 570.03 | 576.75 | 27762.45 |
| 4 | 128 | False | grouped_mm | 81.61 | 81.54 | 82.90 | 27762.45 |
| 4 | 128 | True | grouped_mm | 79.41 | 79.44 | 79.85 | 27762.46 |
| 4 | 128 | max-autotune-no-cudagraphs | grouped_mm | 79.44 | 79.50 | 80.12 | 27762.45 |
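For reference, the summary statistics reported in the table (mean / median / P90 latency) can be computed from raw per-iteration timings with a small helper like this. This is an illustrative stand-in, not the actual benchmark script; the function name is hypothetical.

```python
import statistics

def latency_summary(samples_ms):
    """Summarize raw per-iteration latencies (ms) into mean / median / P90.

    P90 is taken with linear interpolation between closest ranks, mirroring
    numpy's default percentile behavior.
    """
    ordered = sorted(samples_ms)
    rank = 0.9 * (len(ordered) - 1)
    lo, hi = int(rank), min(int(rank) + 1, len(ordered) - 1)
    p90 = ordered[lo] + (rank - lo) * (ordered[hi] - ordered[lo])
    return {
        "mean": statistics.fmean(ordered),
        "median": statistics.median(ordered),
        "p90": p90,
    }

stats = latency_summary([50.0, 52.0, 51.0, 53.0, 90.0])
```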
missing license.
If we go down that road, which I like TBH, we should also add kernels support + at least try FP8 to see how this would work.
Because the other solution is to use the use_hf_hub_kernel decorator as well, but that looks more cumbersome. So following FA2, we want to support kernels from the hub in this as well + quantization
This does look nice, for the bench can you add compile cases as well please?
Also we need to make sure this works with TP / EP
…sting models have 16 byte aligned weights
docs/source/en/experts_interface.md
Outdated
All three backends (`"eager"`, `"batched_mm"`, `"grouped_mm"`) are compatible with `torch.compile` to certain extents. The following table summarizes compatibility:

| Implementation | Compilation modes | dtypes | `fullgraph=True` |
| -------------- | ------------------------------------ | -------------------------------- | ---------------- |
| `grouped_mm` | `None`, `max-autotune-no-cudagraphs` | `bfloat16` | Yes |
| `batched_mm` | all | `bfloat16`, `float16`, `float32` | Yes |
| `eager` | all | `bfloat16`, `float16`, `float32` | No |

Notes:

- The `grouped_mm` experts backend currently only supports `bfloat16` when compiled with `torch.compile`. Additionally, it is not compatible with CUDA graphs, so you must use `mode=None` or `mode="max-autotune-no-cudagraphs"` when compiling.
- The `eager` experts backend uses a data-dependent operation to find which experts are used in a forward pass. This operation is not compatible with full graph compilation (`fullgraph=True`).
- When using `float16` or `float32` with `grouped_mm`, the model will automatically fall back to `batched_mm` when compiled.
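To make the table concrete, here is an illustrative helper that encodes the same compatibility rules as a plain function. The function name and string-typed arguments are hypothetical and not part of the transformers API:

```python
def is_compile_compatible(implementation, dtype, mode=None, fullgraph=False):
    """Encode the compatibility table: which (implementation, dtype, mode,
    fullgraph) combinations work under torch.compile. dtypes and modes are
    passed as strings purely for illustration."""
    if implementation == "grouped_mm":
        # grouped_mm: bfloat16 only, and no CUDA graphs
        return dtype == "bfloat16" and mode in (None, "max-autotune-no-cudagraphs")
    if implementation == "batched_mm":
        return dtype in ("bfloat16", "float16", "float32")
    if implementation == "eager":
        # eager's data-dependent expert lookup breaks fullgraph compilation
        return dtype in ("bfloat16", "float16", "float32") and not fullgraph
    raise ValueError(f"unknown experts implementation: {implementation}")
```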
[For maintainers] Suggested jobs to run (before merge): run-slow: afmoe, bamba, dbrx, deepseek_v2, deepseek_v3, dots1, ernie4_5_moe, ernie4_5_vl_moe, falcon_mamba, flex_olmo, glm4_moe, glm4v_moe, gpt_oss, granitemoe, granitemoehybrid

View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42697&sha=2ebaba
```diff
@@ -281,16 +281,18 @@ def lazy_initialization(self, key_states: torch.Tensor):
         i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
         not be compiled anyway for performances!
         """
-        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
         self.dtype, self.device = key_states.dtype, key_states.device
+        self.max_batch_size, self.num_heads = key_states.shape[:2]
+        self.v_head_dim = value_states.shape[-1]
+        self.k_head_dim = key_states.shape[-1]
```
had to add value_states to the signature since some models like deepseek have different k/v head_dim
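A minimal pure-Python sketch of why the split is needed (shapes are illustrative, taken from DeepSeek-V3-style MLA where the key head dim is larger than the value head dim): reading a single `head_dim` from `key_states.shape` cannot represent both.

```python
def cache_dims(key_shape, value_shape):
    """Extract cache dimensions from (batch, num_heads, seq_len, head_dim)
    shape tuples, allowing key and value head dims to differ."""
    max_batch_size, num_heads = key_shape[:2]
    k_head_dim = key_shape[-1]
    v_head_dim = value_shape[-1]
    return max_batch_size, num_heads, k_head_dim, v_head_dim

# DeepSeek-style shapes: keys carry extra rope dims, values are smaller
dims = cache_dims((1, 128, 16, 192), (1, 128, 16, 128))
```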
```python
        )

        # Finally: if we can compile, disable tokenizers parallelism and check for FA2 + static cache
        if can_compile:
```
seemed to me like this `if` was missing from the logic
```python
# If we use grouped_mm with a dtype other than bfloat16, we fall back to batched_mm
if self.config._experts_implementation == "grouped_mm":
    if self.dtype != torch.bfloat16:
        logger.warning_once(
            "torch._grouped_mm currently only supports bfloat16 when being compiled with torch.compile. "
            "Falling back to batched_mm implementation for compilation."
        )
        self.set_experts_implementation("batched_mm")
```
preferred falling back to batched_mm here only when we are optimizing the model for compiled generation.
falling back during the forward pass seemed like very implicit behavior that might go against user intention.
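The fallback rule can be sketched as a tiny standalone decision function (illustrative only; the function name and string-typed dtype are hypothetical, not the transformers API):

```python
def resolve_compiled_experts_impl(requested_impl, dtype):
    """Decide which experts implementation to compile with.

    torch._grouped_mm only supports bfloat16 under torch.compile, so a
    grouped_mm request with any other dtype falls back to batched_mm. The
    fallback is applied only when optimizing the model for compiled
    generation, never silently inside a forward pass.
    """
    if requested_impl == "grouped_mm" and dtype != "bfloat16":
        return "batched_mm"
    return requested_impl
```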
Kudos!
* meo implementation
* support more MoEs
* tests
* add comments
* add grouped_mm support
* typing act_fn and adding stride 16 note
* style
* fix dbrx config
* fix config test
* add licence and better stride conditions
* comment
* no need to pad tensors to 16 byte strides if we made sure our tiny testing models have 16 byte aligned weights
* use a class decorator with a registration interface
* remove line
* remove unnecessary
* register config with the decorator
* fix redundant
* reduce changes some more
* fix
* fix
* import from integrations
* remove empty lines
* use histc instead of bincount
* fix cpu histc not supporting long
* docs
* added benchmark to docs
* add to from_pretrained's docstring
* make grouped_mm the default when possible
* Update docs/source/en/experts_interface.md (×13) Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* Apply suggestion from @stevhliu (×2) Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* make qwen3 vl moe inherit its experts and sparse moe blocks from qwen3 moe, making it use experts implementation
* create _supports_grouped_mm flag and use it for testing
* fix copies
* better grouped mm checks
* fix model size failure
* better docs
* get rid of class property _supports_grouped_mm
* add method calling checks and fix models that didn't have experts
* fix copies
* fix
* fix
* more cleanup
* clean
* document compilation behaviour
* docs
* fix new moe after merge
* fix the new ernie 4.5 vl moe testing
* support fullgraph automatic compilation for MoEs
* fix lazy initialization
* disable fullgraph for granitemoe and jetmoe because of topk gating
* avoid implicit fallback in experts implementation and only do it when auto-compiling
* style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
What does this PR do?
I have started experimenting with pure PyTorch MoE implementations following the HF exporters PR, while trying to find a traceable/exportable variant for ONNX/OpenVINO.
In this PR I copy the `attn_implementation` API into a similar `experts_implementation` API, and add two new implementations:

- `batched_mm` (the exportable one), which uses `torch.bmm` and is fastest on single batch size / small inputs.
- `grouped_mm` (the PyTorch custom-kernel one), inspired by torchtitan's MoE implementation (using `torch._grouped_mm`), which is generally fastest.

Benchmark

An initial benchmark shows promising results (on A100). I know that `torch._grouped_mm` uses bfloat16 or something under the hood, so these might not be apples to apples (I'm still looking for more references on this function and how to use it "equivalently").

MoE Implementations Benchmark
Benchmark script: bench.py
It uses qwen2_moe ("Qwen/Qwen1.5-MoE-A2.7B", bfloat16), where latency and memory are for the forward pass / prefill.
TL;DR: for very small inputs `batched_mm` can be extremely fast (and even faster with compilation); for bigger inputs `grouped_mm` is unbeatable, but it doesn't seem to get much faster with torch compilation.
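For intuition on what the `grouped_mm` path has to do before any matmul runs: tokens are regrouped per expert, and the grouped kernel consumes cumulative per-expert offsets (this is what the histc-based counting computes on GPU; `bincount` was replaced by `histc` in this PR). A pure-Python stand-in, with a hypothetical function name:

```python
from itertools import accumulate

def group_tokens_by_expert(expert_ids, num_experts):
    """Sort token indices by their assigned expert and compute cumulative
    group end-offsets of the kind a grouped matmul kernel consumes.
    Pure-Python stand-in for the histc-based counting done on GPU."""
    order = sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    offsets = list(accumulate(counts))  # end offset of each expert's group
    return order, offsets

# 6 tokens routed to 4 experts; expert 2 receives no tokens
order, offsets = group_tokens_by_expert([1, 0, 3, 1, 0, 3], num_experts=4)
```

Each expert `e` then multiplies the contiguous slice `offsets[e-1]:offsets[e]` of the regrouped tokens against its own weight matrix in one grouped call, instead of looping over experts in eager mode.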
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.