
Add pre-shuffle weight for new aiter MoE support. #12908

Merged
HaiShaw merged 10 commits into sgl-project:main from sogalin:preshuffle-moe
Nov 10, 2025

Conversation

@sogalin
Contributor

@sogalin sogalin commented Nov 9, 2025

Motivation

This PR introduces a pre-shuffle step for MoE weights to improve runtime performance and memory access efficiency.

Modifications

Added shuffle_weight from aiter.ops.shuffle to pre-shuffle w13_weight and w2_weight at a (16, 16) granularity that matches the aiter kernel tile size, improving locality and performance.
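For illustration, a minimal sketch of what such a load-time pre-shuffle can look like. The shuffle_weight import, the (16, 16) layout, and the w13_weight/w2_weight attribute names follow the description above; the helper function itself and the exact keyword signature are assumptions, not the merged code:

```python
# Sketch only: pre-shuffle the MoE expert weights into the (16, 16) tile layout
# that the aiter fused-MoE kernel reads, once, right after weight loading.
# Attribute names (w13_weight, w2_weight) follow the PR description; the helper
# itself is illustrative rather than the merged implementation.
import torch
from aiter.ops.shuffle import shuffle_weight


def preshuffle_moe_weights(layer: torch.nn.Module) -> None:
    for name in ("w13_weight", "w2_weight"):
        w = getattr(layer, name)
        # Re-layout once at load time so the kernel reads tiles contiguously.
        shuffled = shuffle_weight(w.data, layout=(16, 16))
        setattr(layer, name, torch.nn.Parameter(shuffled, requires_grad=False))
```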

Accuracy Tests

Model: amd/DeepSeek-R1-MXFP4-Preview

python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --port 8000
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [01:50<00:00, 11.96it/s]
Accuracy: 0.946
Invalid: 0.000
Latency: 110.434 s
Output throughput: 1180.643 token/s

Benchmarking and Profiling

Machine: 8 × MI355 GPUs
Docker Image: rocm/sgl-dev:v0.5.4.post3-rocm700-mi35x-20251106

Command:
SGLANG_USE_AITER=1 RCCL_MSCCL_ENABLE=0 SGLANG_INT4_WEIGHT=0 SGLANG_MOE_PADDING=1 SGLANG_USE_ROCM700A=1 SGLANG_SET_CPU_AFFINITY=1 SGLANG_ROCM_FUSED_DECODE_MLA=1 python3 -m sglang.launch_server --model-path /data2/deepseek-ai/DeepSeek-R1-MXFP4-Preview/ --tensor-parallel-size 8 --trust-remote-code --chunked-prefill-size 131072 --host 0.0.0.0 --port 8000 --log-requests --disable-radix-cache --mem-fraction-static 0.5 --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4

| Prefill | 1 | 8 | 16 |
| --- | --- | --- | --- |
| W/o MoE update | 0.0109 | 0.6524 | 1.2924 |
| With MoE update | 0.1071 | 0.6505 | 1.2965 |

| Decode | 1 | 8 | 16 |
| --- | --- | --- | --- |
| W/o MoE update | 0.0109 | 0.0123 | 0.0139 |
| With MoE update | 0.0103 | 0.0118 | 0.0133 |

Checklist

@github-actions github-actions Bot added the amd label Nov 9, 2025
@sogalin sogalin marked this pull request as ready for review November 9, 2025 06:34
@HaiShaw HaiShaw added the run-ci label Nov 10, 2025
)

_is_hip = is_hip()
_is_shuffle_moe = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
Collaborator

suggest change to _is_shuffle_moe_mxfp4

logger = logging.getLogger(__name__)

_is_hip = is_hip()
_is_shuffle_moe = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
Collaborator

change to _is_shuffle_moe_mxfp4
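Applied to the gate quoted above, the suggested rename would read roughly like this (a sketch; only the variable name changes, and the value is still driven by the AITER_MXFP4_MOE_SF environment variable on HIP, using the same is_hip/get_bool_env_var helpers as in the quoted snippet):

```python
# Sketch of the suggested rename: same gate, name now calls out the MXFP4 path.
_is_hip = is_hip()
_is_shuffle_moe_mxfp4 = get_bool_env_var("AITER_MXFP4_MOE_SF") and _is_hip
```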

@HaiShaw
Collaborator

HaiShaw commented Nov 10, 2025

@BowenBao @kkHuang-amd please have a review.

layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

# Pre-shuffle weight
if _is_shuffle_moe_mxfp4:
Collaborator

@sogalin for my understanding, why would this not need any changes to the kernel call code down in the apply method?

@HaiShaw Would it make sense to keep shuffling as default since it has better perf?

Collaborator

@BowenBao it is set in the Docker ENV

@HaiShaw HaiShaw merged commit 661c1c9 into sgl-project:main Nov 10, 2025
10 of 36 checks passed