Skip to content

[Performance] Fuse DeepSeek shared experts and gate operations#28540

Closed
Red-Caesar wants to merge 4 commits into
vllm-project:mainfrom
axeltec-software:deepseek_optimizations
Closed

[Performance] Fuse DeepSeek shared experts and gate operations#28540
Red-Caesar wants to merge 4 commits into
vllm-project:mainfrom
axeltec-software:deepseek_optimizations

Conversation

@Red-Caesar

@Red-Caesar Red-Caesar commented Nov 12, 2025

Copy link
Copy Markdown

Purpose

This PR contains optimizations from two other PRs that we aligned with current main:

  • Fused shared expert (RFC#26108)
  • Fused operation for gating

This effort is driven by Nebius AI, as part of ongoing optimization work.

Command to run the optimization version:

VLLM_USE_FUSED_MOE_GROUPED_TOPK=0 VLLM_USE_FUSED_MOE_ROUTER=1 VLLM_USE_CUDA_FUSION_SHARED_EXPERTS=1 python3 -m vllm.entrypoints.openai.api_server --model <deepseek-model> --tensor-parallel-size 8 --trust-remote-code

Results

Our measurements show that these optimizations improve performance. Specifically, they give a significant boost to the TTFT metric. Results on the ITL metric are generally positive, with the observed slight performance drawdowns in rare configurations.

Fused shared expert (PR#15502)

This has already been implemented for AMD in by PR#24097. Our PR makes it available to CUDA as well.

Description

Currently, we use an env variable (VLLM_USE_CUDA_FUSION_SHARED_EXPERTS) to enable or disable the feature.
When this feature is enabled, we internally set

expert number = config.n_routed_experts + config.n_shared_experts
topk = config.num_experts_per_tok + config.n_shared_experts

for FusedMoE module. During the weight loading stage, we clone the shared expert weights into the experts (using the loading code from PR#24097). Shared expert id and expert weights are manually assigned to ensure that experts are balanced and accuracy is guaranteed.

Fused operation for gating (PR#21107)

This fuses the DeepSeek-style grouped top-k experts selection process (3 top-k ops and bells and whistles) into one custom kernel.

The kernel is imported from sglang sgl-project/sglang#4530

Currently we use an env variable (VLLM_USE_FUSED_MOE_ROUTER) to enable or disable the feature.

Notes

When these optimizations are enabled we can’t use the fused grouped_topk kernel for MoE (#23274) which is now the default.

Tests

All tests have been performed on DeepSeek V3 on the h200.

gsm8k results

Original version:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9181|±  |0.0076|
|     |       |strict-match    |     5|exact_match|↑  |0.8067|±  |0.0109|

Optimization version:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9105|±  |0.0079|
|     |       |strict-match    |     5|exact_match|↑  |0.7976|±  |0.0111|

vLLM benchmark results

We tests releases/v0.11.1 original vs releases/v0.11.1 + our optimization
image

python benchmarks/kernels/benchmark_moe_fused_gate.py 
moe-fused-gate-performance:
   seq_length    Original  SGL Kernel
0      5000.0   55.936001   27.936000
1     10000.0   61.535999   38.304001
2     15000.0   84.927998   45.600001
3     20000.0  104.704000   54.880001
4     25000.0  126.255997   61.503999
5     30000.0  148.095995   71.392000
6     35000.0  167.840004   81.568003
7     40000.0  187.215999   87.007999

Our benchmark results

Our benchmark results you can find in the doc.

@mergify

mergify Bot commented Nov 12, 2025

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Red-Caesar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Nov 12, 2025

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant performance optimizations for DeepSeek models by fusing shared experts and gate operations. The changes are well-structured, introducing new environment variables to control the features and adding corresponding benchmarks and tests. However, I've identified a critical logic error in the dispatch logic of the new CUDA kernel csrc/moe/moe_fused_gate.cu that makes some of the templated kernel specializations unreachable. This needs to be fixed to ensure correctness for all supported configurations.

Comment thread csrc/moe/moe_fused_gate.cu

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 1248 to 1265
assert routed_scaling_factor is not None, \
"With num_fused_shared_experts>0"
", routed_scaling_factor need to be provided"
topk_ids[:, -1] = torch.randint(low=num_experts,
high=num_experts +
num_fused_shared_experts,
size=(topk_ids.size(0), ),
dtype=topk_ids.dtype,
device=topk_ids.device)
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0:
topk_weights_sum = topk_weights.sum(dim=-1, keepdim=True)
else:
topk_weights_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
topk_weights = topk_weights / topk_weights_sum

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Renormalization ignores fused shared expert weight

When shared experts are fused the code overwrites the last top‑k slot with topk_weights[:, :-1].sum()/routed_scaling_factor and then divides the entire vector by the sum of only the non‑shared weights. This makes the non‑shared weights sum to 1 but leaves the shared weight at 1/routed_scaling_factor, so the probabilities add up to 1 + num_fused_shared_experts/routed_scaling_factor and are no longer scaled by routed_scaling_factor as the non‑fused path does. Because DeepseekV2MoE disables external scaling whenever this feature is enabled, tokens routed through fused shared experts will be over‑weighted, producing larger activations than the unfused path. Consider including the shared weight in the normalization or renormalizing after inserting the shared expert so the weights still form a properly scaled distribution.

Useful? React with 👍 / 👎.

Comment thread csrc/moe/moe_fused_gate.cu
@Red-Caesar Red-Caesar force-pushed the deepseek_optimizations branch from 5aa1381 to b149d9c Compare November 12, 2025 10:21
@hmellor

hmellor commented Nov 12, 2025

Copy link
Copy Markdown
Member

cc @bnellnm for FusedMoE

cc @alexm-redhat for RFC

@Red-Caesar Red-Caesar force-pushed the deepseek_optimizations branch from b149d9c to 0c89e7f Compare November 12, 2025 13:39
@mergify mergify Bot removed the needs-rebase label Nov 12, 2025
@Red-Caesar Red-Caesar force-pushed the deepseek_optimizations branch from cf57507 to 02fef8e Compare November 12, 2025 14:34
@@ -1769,6 +1824,7 @@ def forward_impl(
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
enable_fused_moe_router=self.enable_fused_moe_router,

@bnellnm bnellnm Nov 13, 2025

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since enable_fused_moe_router is a FusedMoE layer attribute and apply gets the layer as a parameter, can we avoid passing it around to all the apply methods?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the way I pass the enable_fused_moe_router parameter - it's now an env variable only in the select_experts function. This has made the code more readable.

@mergify

mergify Bot commented Nov 13, 2025

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Red-Caesar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Nov 13, 2025
@heheda12345

Copy link
Copy Markdown
Collaborator

Also CC @tlrmchlsmth

@Red-Caesar Red-Caesar force-pushed the deepseek_optimizations branch from 02fef8e to f9d6dc5 Compare November 14, 2025 09:27
@mergify

mergify Bot commented Nov 19, 2025

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Red-Caesar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@alexm-redhat alexm-redhat left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I have verified its correctness and did some initial performance verifications - all looks good. Left some comments. The PR needs rebase and it should be ready to be merged.

@@ -0,0 +1,484 @@
#include <ATen/cuda/CUDAContext.h>

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add a comment that it is taken from sglang

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added

@@ -477,6 +491,15 @@ def __init__(
self.global_num_experts,
get_compressed_expert_map(self.expert_map),
)
if self.num_fused_shared_experts > 0:
logger.warning(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work with EP? If not, then maybe we should assert here some conditions.

@@ -1360,27 +1383,51 @@ def select_experts(
elif use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if rocm_aiter_ops.is_fused_moe_enabled():
if hidden_states.shape[0] == 0:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this shape[0] == 0 case is necessary? Would be good to have some quick doc explaining this piece of code.

and e_score_correction_bias is not None
and is_power_of_two(e_score_correction_bias.shape[0])
):
# The fused kernel can only work with 128/256 experts

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a check for E=128 or E=256 in the if statement as well?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this comment was a bit misleading, so I've deleted it. The kernel works with routed_experts, which are powers of 2. So, if the shape of e_score_correction_bias is a power of two, I suppose it will also be true for the experts.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see makes sense

apply_routed_scaling_factor_on_output=False,
)
else:
topk_weights, topk_ids = grouped_topk_impl(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which case the moe_fused_gate is faster than the usual grouped_topk_impl?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment, it's not faster on its own anymore, but it provides a boost when used together with fused_experts as the default grouped_topk isn't adapted to it

@@ -1809,7 +1856,7 @@ def combine_output(states: torch.Tensor) -> torch.Tensor:
states = get_ep_group().combine(states, self.is_sequence_parallel)
return states

if self.shared_experts is not None:
if self.shared_experts is not None and self.num_fused_shared_experts == 0:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is it possible to have self.shared_experts != None and self.num_fused_shared_experts > 0? It seems like if self.num_fused_shared_experts > 0 then elf.shared_experts must be None

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you're right, fix it

@@ -325,14 +337,15 @@ def __init__(
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not self.is_rocm_aiter_moe_enabled
if not used_inside_scaling

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some documentation about how "used_inside_scaling" changes the functionality of the scaling application

Signed-off-by: Barbara Suslova <barbara.suslova@axel-t.com>

lint

Signed-off-by: Barbara Suslova <barbara.suslova@axel-t.com>

change the logic of passing variables

Signed-off-by: Barbara Suslova <barbara.suslova@axel-t.com>

aligning

Signed-off-by: Barbara Suslova <barbara.suslova@axel-t.com>
Signed-off-by: Barbara Suslova <barbara.suslova@axel-t.com>
@Red-Caesar Red-Caesar force-pushed the deepseek_optimizations branch from de72969 to b7a869d Compare November 28, 2025 17:24
@mergify mergify Bot removed the needs-rebase label Nov 28, 2025
Red-Caesar and others added 2 commits November 28, 2025 17:38
Signed-off-by: Barbara Suslova <barbara.suslova@axel-t.com>
@mergify

mergify Bot commented Dec 9, 2025

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Red-Caesar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Dec 9, 2025
@github-actions

Copy link
Copy Markdown

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions Bot added the stale Over 90 days of inactivity label Mar 10, 2026
@github-actions

Copy link
Copy Markdown

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!

@github-actions github-actions Bot closed this Apr 10, 2026
omirosh added a commit to omirosh/vllm that referenced this pull request Jun 5, 2026
## Purpose

Extend the AITER Fused Shared Expert (FSE) path - originally added for
DeepSeek-V2/V3 (vllm-project#28540) and Qwen3-Next (vllm-project#39280) - to the GLM-4 MoE family
(GLM-4.5, GLM-4.6, GLM-4.7). When `VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=1`
the shared expert is folded into the AITER FusedMoE kernel as
`n_shared_experts` extra expert slots, eliminating the separate shared-expert
MLP forward pass at low/medium concurrency.

## Changes

Single-file model wiring in `vllm/model_executor/models/glm4_moe.py`, mirroring
the canonical `deepseek_v2.py` FSE pattern:

* `Glm4MoE.__init__`
  - Cache `is_rocm_aiter_moe_enabled` and `is_fusion_moe_shared_experts_enabled`
    from `rocm_aiter_ops`.
  - When FSE is enabled, skip building the separate `shared_experts` MLP and
    pass `n_shared_experts=config.n_shared_experts` to `FusedMoE` so the
    AITER kernel routes the shared expert(s) as extra slots in the routed
    tensor.
  - Switch `apply_routed_scale_to_output` to
    `not self.is_rocm_aiter_moe_enabled`. AITER applies `routed_scaling_factor`
    internally, per routed slot; applying it again post-fusion would also
    scale the FSE shared-expert slot (which the kernel inserts with unit
    weight), producing a structural magnitude error in every MoE layer.
    This matches `deepseek_v2.py`. (`routed_scaling_factor=2.5` for GLM-4.7,
    so the unfixed path showed a ~48 pp gsm8k regression.)

* `Glm4MoeModel.get_expert_mapping`
  - Widen `num_experts` by `config.n_shared_experts` when FSE is on so the
    weight loader enumerates the appended slots.

* `Glm4MoeModel.load_weights`
  - Treat `mlp.shared_experts.{gate,up,down}_proj.*` as expert-style tensors
    when FSE is on (skip the stacked QKV/gate_up linear path).
  - Split each widened shared-expert tensor into `n_shared_experts` chunks
    along the intermediate-size axis (dim 0 for ColumnParallel
    gate/up_proj, dim 1 for RowParallel down_proj) and route each chunk to
    `mlp.experts.{n_routed_experts + j}.*` via the FusedMoE expert-aware
    weight loader.

No changes to FusedMoE / AITER plumbing - all of that landed earlier with
vllm-project#39280 (Qwen3-Next FSE).

## Test Plan

* Model: `zai-org/GLM-4.7-FP8`
* Hardware: 1x MI355X node, TP=4
* Container: ROCm vLLM image (AITER >= v0.1.13.post1, PR vllm-project#44265)
* Accuracy: `lm_eval --tasks gsm8k --num_fewshot 5`
* Throughput: `vllm bench serve --dataset-name random` sweep over
  (ISL, OSL, MC) in {1000/100, 5000/500, 10000/1000} x {4, 16, 64}

Server launch:

```
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=<0|1> \
vllm serve zai-org/GLM-4.7-FP8 \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.92 \
  --max-model-len 32768 \
  --max-num-seqs 256
```

## Test Result

### Accuracy (gsm8k, 5-shot, exact_match)

| Config              | flexible-extract | strict-match     |
|---------------------|-----------------:|-----------------:|
| FSE=0 (baseline)    | 0.9469 ± 0.0062  | 0.9439 ± 0.0063  |
| FSE=1               | 0.9439 ± 0.0063  | 0.9416 ± 0.0065  |

All deltas within standard error. No accuracy regression.

### Throughput (`vllm bench serve`, random)

| ISL  | OSL  | MC | TPOT mean (ms) FSE=0 -> FSE=1 (Δ) | TPOT p99 (ms) FSE=0 -> FSE=1 (Δ) | Output tok/s FSE=0 -> FSE=1 (Δ) | Total tok/s FSE=0 -> FSE=1 (Δ) |
|-----:|-----:|---:|----------------------------------:|---------------------------------:|--------------------------------:|-------------------------------:|
|  1000|   100|   4| 17.76 -> 14.36  (**-19.2%**)      | 19.43 -> 15.93 (**-18.0%**)      | 199.4 -> 243.6  (**+22.1%**)    | 2193.7 -> 2679.1 (**+22.1%**)  |
|  1000|   100|  16| 20.96 -> 18.48  (**-11.9%**)      | 24.29 -> 22.77 (-6.3%)           | 631.0 -> 673.4  (**+6.7%**)     | 6940.6 -> 7407.9 (**+6.7%**)   |
|  1000|   100|  64| 30.74 -> 30.23  (-1.7%)           | 42.85 -> 43.44 (+1.4%)           | 1452.7 -> 1424.3 (-2.0%)        | 15980.1 -> 15667.6 (-2.0%)     |
|  5000|   500|   4| 17.82 -> 14.50  (**-18.7%**)      | 18.63 -> 15.50 (**-16.8%**)      | 211.5 -> 253.5  (**+19.9%**)    | 2326.1 -> 2788.7 (**+19.9%**)  |
|  5000|   500|  16| 22.73 -> 20.76  (**-8.7%**)       | 25.38 -> 23.07 (**-9.1%**)       | 619.1 -> 657.7  (**+6.2%**)     | 6810.4 -> 7234.6 (**+6.2%**)   |
|  5000|   500|  64| 39.79 -> 40.15  (+0.9%)           | 46.15 -> 46.78 (+1.4%)           | 1363.8 -> 1339.1 (-1.8%)        | 15001.9 -> 14730.4 (-1.8%)     |
| 10000|  1000|   4| 18.00 -> 14.70  (**-18.3%**)      | 18.68 -> 15.50 (**-17.0%**)      | 210.3 -> 251.8  (**+19.7%**)    | 2313.5 -> 2769.4 (**+19.7%**)  |
| 10000|  1000|  16| 24.47 -> 22.87  (-6.5%)           | 26.66 -> 25.56 (-4.1%)           | 589.6 -> 615.1  (**+4.3%**)     | 6485.6 -> 6766.2 (**+4.3%**)   |
| 10000|  1000|  64| 46.37 -> 46.33  (-0.1%)           | 51.14 -> 51.78 (+1.3%)           | 1233.6 -> 1211.9 (-1.8%)        | 13570.0 -> 13330.7 (-1.8%)     |

Verdict: FSE delivers +20-22% output throughput and -18-19% TPOT at low
concurrency (MC=4), modest gains at MC=16, and is roughly break-even
(<2% regression) at MC=64. No accuracy regression.

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Olga Miroshnichenko <olga.miroshnichenko@amd.com>
omirosh added a commit to omirosh/vllm that referenced this pull request Jun 5, 2026
## Purpose

Extend the AITER Fused Shared Expert (FSE) path - originally added for
DeepSeek-V2/V3 (vllm-project#28540) and Qwen3-Next (vllm-project#39280) - to the GLM-4 MoE family
(GLM-4.5, GLM-4.6, GLM-4.7). When `VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=1`
the shared expert is folded into the AITER FusedMoE kernel as
`n_shared_experts` extra expert slots, eliminating the separate shared-expert
MLP forward pass at low/medium concurrency.

## Changes

Single-file model wiring in `vllm/model_executor/models/glm4_moe.py`, mirroring
the canonical `deepseek_v2.py` FSE pattern:

* `Glm4MoE.__init__`
  - Cache `is_rocm_aiter_moe_enabled` and `is_fusion_moe_shared_experts_enabled`
    from `rocm_aiter_ops`.
  - When FSE is enabled, skip building the separate `shared_experts` MLP and
    pass `n_shared_experts=config.n_shared_experts` to `FusedMoE` so the
    AITER kernel routes the shared expert(s) as extra slots in the routed
    tensor.
  - Switch `apply_routed_scale_to_output` to
    `not self.is_rocm_aiter_moe_enabled`. AITER applies `routed_scaling_factor`
    internally, per routed slot; applying it again post-fusion would also
    scale the FSE shared-expert slot (which the kernel inserts with unit
    weight), producing a structural magnitude error in every MoE layer.
    This matches `deepseek_v2.py`. (`routed_scaling_factor=2.5` for GLM-4.7,
    so the unfixed path showed a ~48 pp gsm8k regression.)

* `Glm4MoeModel.get_expert_mapping`
  - Widen `num_experts` by `config.n_shared_experts` when FSE is on so the
    weight loader enumerates the appended slots.

* `Glm4MoeModel.load_weights`
  - Treat `mlp.shared_experts.{gate,up,down}_proj.*` as expert-style tensors
    when FSE is on (skip the stacked QKV/gate_up linear path).
  - Split each widened shared-expert tensor into `n_shared_experts` chunks
    along the intermediate-size axis (dim 0 for ColumnParallel
    gate/up_proj, dim 1 for RowParallel down_proj) and route each chunk to
    `mlp.experts.{n_routed_experts + j}.*` via the FusedMoE expert-aware
    weight loader.

No changes to FusedMoE / AITER plumbing - all of that landed earlier with
vllm-project#39280 (Qwen3-Next FSE).

## Test Plan

* Model: `zai-org/GLM-4.7-FP8`
* Hardware: 1x MI355X node, TP=4
* Container: ROCm vLLM image (AITER >= v0.1.13.post1, PR vllm-project#44265)
* Accuracy: `lm_eval --tasks gsm8k --num_fewshot 5`
* Throughput: `vllm bench serve --dataset-name random` sweep over
  (ISL, OSL, MC) in {1000/100, 5000/500, 10000/1000} x {4, 16, 64}

Server launch:

```
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=<0|1> \
vllm serve zai-org/GLM-4.7-FP8 \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.92 \
  --max-model-len 32768 \
  --max-num-seqs 256
```

## Test Result

### Accuracy (gsm8k, 5-shot, exact_match)

| Config              | flexible-extract | strict-match     |
|---------------------|-----------------:|-----------------:|
| FSE=0 (baseline)    | 0.9469 ± 0.0062  | 0.9439 ± 0.0063  |
| FSE=1               | 0.9439 ± 0.0063  | 0.9416 ± 0.0065  |

All deltas within standard error. No accuracy regression.

### Throughput (`vllm bench serve`, random)

| ISL  | OSL  | MC | TPOT mean (ms) FSE=0 -> FSE=1 (Δ) | TPOT p99 (ms) FSE=0 -> FSE=1 (Δ) | Output tok/s FSE=0 -> FSE=1 (Δ) | Total tok/s FSE=0 -> FSE=1 (Δ) |
|-----:|-----:|---:|----------------------------------:|---------------------------------:|--------------------------------:|-------------------------------:|
|  1000|   100|   4| 17.76 -> 14.36  (**-19.2%**)      | 19.43 -> 15.93 (**-18.0%**)      | 199.4 -> 243.6  (**+22.1%**)    | 2193.7 -> 2679.1 (**+22.1%**)  |
|  1000|   100|  16| 20.96 -> 18.48  (**-11.9%**)      | 24.29 -> 22.77 (-6.3%)           | 631.0 -> 673.4  (**+6.7%**)     | 6940.6 -> 7407.9 (**+6.7%**)   |
|  1000|   100|  64| 30.74 -> 30.23  (-1.7%)           | 42.85 -> 43.44 (+1.4%)           | 1452.7 -> 1424.3 (-2.0%)        | 15980.1 -> 15667.6 (-2.0%)     |
|  5000|   500|   4| 17.82 -> 14.50  (**-18.7%**)      | 18.63 -> 15.50 (**-16.8%**)      | 211.5 -> 253.5  (**+19.9%**)    | 2326.1 -> 2788.7 (**+19.9%**)  |
|  5000|   500|  16| 22.73 -> 20.76  (**-8.7%**)       | 25.38 -> 23.07 (**-9.1%**)       | 619.1 -> 657.7  (**+6.2%**)     | 6810.4 -> 7234.6 (**+6.2%**)   |
|  5000|   500|  64| 39.79 -> 40.15  (+0.9%)           | 46.15 -> 46.78 (+1.4%)           | 1363.8 -> 1339.1 (-1.8%)        | 15001.9 -> 14730.4 (-1.8%)     |
| 10000|  1000|   4| 18.00 -> 14.70  (**-18.3%**)      | 18.68 -> 15.50 (**-17.0%**)      | 210.3 -> 251.8  (**+19.7%**)    | 2313.5 -> 2769.4 (**+19.7%**)  |
| 10000|  1000|  16| 24.47 -> 22.87  (-6.5%)           | 26.66 -> 25.56 (-4.1%)           | 589.6 -> 615.1  (**+4.3%**)     | 6485.6 -> 6766.2 (**+4.3%**)   |
| 10000|  1000|  64| 46.37 -> 46.33  (-0.1%)           | 51.14 -> 51.78 (+1.3%)           | 1233.6 -> 1211.9 (-1.8%)        | 13570.0 -> 13330.7 (-1.8%)     |

Verdict: FSE delivers +20-22% output throughput and -18-19% TPOT at low
concurrency (MC=4), modest gains at MC=16, and is roughly break-even
(<2% regression) at MC=64. No accuracy regression.

Signed-off-by: Olga Miroshnichenko <olga.miroshnichenko@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build deepseek Related to DeepSeek models needs-rebase performance Performance-related issues stale Over 90 days of inactivity

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants