
[perf] optimize w4afp8 kernel on deepseek-v3-0324 #12921

Merged: BBuf merged 7 commits into sgl-project:main from GMISWE:bruce_optimize_w4a8_kernel_opt on Dec 18, 2025

Conversation

@Bruce-x-1997
Contributor

@Bruce-x-1997 Bruce-x-1997 commented Nov 9, 2025

Motivation

We serve DeepSeek-V3-0324 with w4afp8 online, and we found its performance is not good enough when the decode batch size is < 32.

Modifications

  • Fine-grained GEMM tiling configuration.
  • Building on https://github.com/sgl-project/sglang/pull/10027/files, use CUDA int4 (128-bit) vectorized memory access to reduce memory-access pressure (see the sketch below).
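As a rough illustration of the int4 memory-access idea (not the PR's actual kernel; the name `count_expert_tokens` and the counting logic are placeholders):

```cuda
// Minimal sketch: each thread loads four consecutive topk_ids with a single
// 128-bit int4 load instead of four scalar 32-bit loads. Assumes topk_ids is
// 16-byte aligned and padded to a multiple of 4 entries.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void count_expert_tokens(const int32_t* __restrict__ topk_ids,
                                    int32_t* __restrict__ expert_counts,
                                    int num_ids) {
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (base + 3 < num_ids) {
    int4 ids = *reinterpret_cast<const int4*>(topk_ids + base);
    atomicAdd(&expert_counts[ids.x], 1);
    atomicAdd(&expert_counts[ids.y], 1);
    atomicAdd(&expert_counts[ids.z], 1);
    atomicAdd(&expert_counts[ids.w], 1);
  }
}
```

According to the review summary later in this thread, the PR applies this vectorized-load idea to topk_ids in the data-preparation kernels, alongside CUB-based block reductions and scans.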

Accuracy Tests

deepseek-v3-0324 w4afp8

aime25          0.4/0.4
aime24          0.5/0.65/0.55
mmlu              0.8947

Benchmarking and Profiling

### Prefill

[2025-11-06 11:41:40 TP0] Prefill batch. #new-seq: 17, #new-token: 8192, #cached-token: 51, token usage: 0.02, #running-req: 0, #queue-req: 885, 
[2025-11-06 11:41:41 TP0] run dt: 566.5090084075928 ms
[2025-11-06 11:41:41 TP0] Prefill batch. #new-seq: 15, #new-token: 8192, #cached-token: 46, token usage: 0.02, #running-req: 0, #queue-req: 871, 
[2025-11-06 11:41:41 TP0] run dt: 566.9610500335693 ms
[2025-11-06 11:41:41 TP0] Prefill batch. #new-seq: 20, #new-token: 8192, #cached-token: 61, token usage: 0.02, #running-req: 0, #queue-req: 852, 
[2025-11-06 11:41:42 TP0] run dt: 563.4098052978516 ms
[2025-11-06 11:41:42 TP0] Prefill batch. #new-seq: 16, #new-token: 8192, #cached-token: 42, token usage: 0.03, #running-req: 0, #queue-req: 837, 
[2025-11-06 11:41:43 TP0] run dt: 567.133903503418 ms

to

[2025-11-06 13:23:07 TP0] run dt: 501.59597396850586 ms
[2025-11-06 13:23:07 TP0] Prefill batch. #new-seq: 17, #new-token: 8192, #cached-token: 43, token usage: 0.00, #running-req: 0, #queue-req: 938, 
[2025-11-06 13:23:08 TP0] run dt: 500.20575523376465 ms
[2025-11-06 13:23:08 TP0] Prefill batch. #new-seq: 17, #new-token: 8192, #cached-token: 45, token usage: 0.02, #running-req: 0, #queue-req: 922, 
[2025-11-06 13:23:08 TP0] run dt: 503.3257007598877 ms
[2025-11-06 13:23:08 TP0] Prefill batch. #new-seq: 19, #new-token: 8192, #cached-token: 59, token usage: 0.03, #running-req: 0, #queue-req: 904, 
[2025-11-06 13:23:09 TP0] run dt: 507.9646110534668 ms
[2025-11-06 13:23:09 TP0] Prefill batch. #new-seq: 19, #new-token: 8192, #cached-token: 58, token usage: 0.03, #running-req: 0, #queue-req: 886, 
[2025-11-06 13:23:09 TP0] run dt: 505.6488513946533 ms
[2025-11-06 13:23:09 TP0] Prefill batch. #new-seq: 16, #new-token: 8192, #cached-token: 50, token usage: 0.03, #running-req: 0, #queue-req: 871, 
[2025-11-06 13:23:10 TP0] run dt: 508.23521614074707 ms
[2025-11-06 13:23:10 TP0] Prefill batch. #new-seq: 17, #new-token: 8192, #cached-token: 49, token usage: 0.03, #running-req: 0, #queue-req: 855, 
[2025-11-06 13:23:10 TP0] run dt: 507.51590728759766 ms
[2025-11-06 13:23:10 TP0] Prefill batch. #new-seq: 18, #new-token: 8192, #cached-token: 51, token usage: 0.03, #running-req: 0, #queue-req: 838, 
[2025-11-06 13:23:11 TP0] run dt: 506.1039924621582 ms
[2025-11-06 13:23:11 TP0] Prefill batch. #new-seq: 17, #new-token: 8192, #cached-token: 53, token usage: 0.03, #running-req: 0, #queue-req: 822, 
[2025-11-06 13:23:11 TP0] run dt: 506.6485404968262 ms

~10% faster prefill (~566 ms → ~505 ms per step)

### Decode

[2025-11-06 11:32:09 TP0] Decode batch. #running-req: 18, #token: 2724, token usage: 0.01, cuda graph: True, gen throughput (token/s): 482.05, #queue-req: 0, 
[2025-11-06 11:32:09 TP0] run dt: 36.68785095214844 ms
[2025-11-06 11:32:09 TP0] Decode batch. #running-req: 18, #token: 2742, token usage: 0.01, cuda graph: True, gen throughput (token/s): 482.72, #queue-req: 0, 
[2025-11-06 11:32:09 TP0] run dt: 36.653757095336914 ms
[2025-11-06 11:32:09 TP0] Decode batch. #running-req: 18, #token: 2760, token usage: 0.01, cuda graph: True, gen throughput (token/s): 484.31, #queue-req: 0, 
[2025-11-06 11:32:09 TP0] run dt: 36.479949951171875 ms

to

[2025-11-06 13:25:44 TP0] Decode batch. #running-req: 20, #token: 821, token usage: 0.00, cuda graph: True, gen throughput (token/s): 553.01, #queue-req: 0, 
[2025-11-06 13:25:44 TP0] run dt: 33.98561477661133 ms
[2025-11-06 13:25:44 TP0] Decode batch. #running-req: 20, #token: 841, token usage: 0.00, cuda graph: True, gen throughput (token/s): 565.43, #queue-req: 0, 
[2025-11-06 13:25:44 TP0] run dt: 33.72383117675781 ms
[2025-11-06 13:25:44 TP0] Decode batch. #running-req: 20, #token: 861, token usage: 0.00, cuda graph: True, gen throughput (token/s): 555.80, #queue-req: 0, 
[2025-11-06 13:25:44 TP0] run dt: 33.110857009887695 ms

~5% faster decode

bench_serving end-to-end cases

case 1

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 20 --random-input 100 --random-output 1000 --max-concurrency 20 --port 30001

from

Request throughput (req/s):              0.69      
Input token throughput (tok/s):          26.82     
Output token throughput (tok/s):         354.28    
Total token throughput (tok/s):          381.10    
Concurrency:                             11.84     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   17193.17  
Median E2E Latency (ms):                 17988.45 

to

Request throughput (req/s):              0.71      
Input token throughput (tok/s):          27.61     
Output token throughput (tok/s):         364.82    
Total token throughput (tok/s):          392.44    
Concurrency:                             11.84     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   16699.33  
Median E2E Latency (ms):                 17450.05 

case 2

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 1000 --random-output 1 --max-concurrency 20 --port 30001

from

Request throughput (req/s):              18.13     
Input token throughput (tok/s):          9112.65   
Output token throughput (tok/s):         9.10      
Total token throughput (tok/s):          9121.75   
Concurrency:                             19.87     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1095.47   
Median E2E Latency (ms):                 1091.97   
---------------Time to First Token----------------
Mean TTFT (ms):                          544.65    
Median TTFT (ms):                        100.08    
P99 TTFT (ms):                           1301.67   


to

Request throughput (req/s):              19.62     
Input token throughput (tok/s):          9856.97   
Output token throughput (tok/s):         9.85      
Total token throughput (tok/s):          9866.82   
Concurrency:                             19.86     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1012.46   
Median E2E Latency (ms):                 1013.67   
---------------Time to First Token----------------
Mean TTFT (ms):                          503.63    
Median TTFT (ms):                        104.76    
P99 TTFT (ms):                           1214.27   

Launch command

python3 -m sglang.launch_server --model-path ${PATH_OF_DeepSeek_V3_0324_W4AFP8} --tool-call-parser deepseekv3 --quantization w4afp8 --tp 4 --port 30001 --trust-remote-code --context-length 10240 --page-size 1 --disable-overlap-schedule --prefill-attention-backend flashinfer --decode-attention-backend flashinfer --enable-hierarchical-cache --hicache-size 120 --enable-mixed-chunk --flashinfer-mla-disable-ragged --mem-fraction-static 0.75 --max-running-requests 20 --cuda-graph-max-bs 32 --cuda-graph-bs 1 4 8 16 20 24 32

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @Bruce-x-1997, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a series of optimizations to the w4afp8 quantization kernel, crucial for the deepseek-v3-0324 model. The changes aim to boost the kernel's efficiency, particularly for scenarios involving smaller decode batch sizes, by refining internal configurations and optimizing data access strategies. The modifications are expected to yield tangible performance improvements in both prefill and decode operations.

Highlights

  • Performance Optimization: This pull request significantly optimizes the w4afp8 kernel for the deepseek-v3-0324 model, specifically targeting improved performance when the decode batch size is less than 32.
  • Kernel Configuration & Memory Access: Key modifications include implementing fine-grained tiling configurations within the kernel and leveraging cuda-int4 memory access patterns to alleviate memory pressure, enhancing overall efficiency.
  • Performance Gains: Benchmarking results demonstrate a notable performance improvement, with approximately a 10% reduction in prefill time and a 5% reduction in decode time.
  • CUDA Kernel Refactoring: The CUDA kernels for computing problem sizes and expert offsets have been refactored to use CUB library primitives for block-wide reductions and scans, and vectorized memory access for topk_ids, leading to more efficient parallel computations.
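For readers unfamiliar with the CUB primitives mentioned above, here is a minimal, illustrative sketch (not the PR's actual code; the kernel name and launch configuration are assumptions) of turning per-expert token counts into expert offsets with a block-wide exclusive prefix sum:

```cuda
// One thread block converts per-expert token counts into expert offsets via
// CUB's BlockScan exclusive prefix sum. Assumes num_experts <= BLOCK_SIZE.
#include <cstdint>
#include <cub/cub.cuh>

template <int BLOCK_SIZE>
__global__ void expert_offsets_from_counts(const int32_t* __restrict__ counts,
                                           int32_t* __restrict__ offsets,
                                           int num_experts) {
  using BlockScan = cub::BlockScan<int32_t, BLOCK_SIZE>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  int e = threadIdx.x;
  int32_t count = (e < num_experts) ? counts[e] : 0;
  int32_t offset = 0;
  // Exclusive scan across the block: offsets[e] = sum of counts[0..e-1].
  BlockScan(temp_storage).ExclusiveSum(count, offset);
  if (e < num_experts) offsets[e] = offset;
}
```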

@Bruce-x-1997
Contributor Author

@yuhyao could you please help review this one? Thanks.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces significant performance optimizations for the w4afp8 kernel, particularly for the deepseek-v3-0324 model. The changes include fine-tuning GEMM tiling configurations and replacing slow data preparation logic with highly efficient, parallelized versions using CUB. The use of vectorized memory access and parallel prefix sums are excellent improvements. The code is also made more maintainable by reducing duplication. Overall, these are solid enhancements that deliver the benchmarked performance gains.

Comment thread on sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu (outdated)
@FlamingoPg FlamingoPg self-assigned this Nov 10, 2025
@FlamingoPg
Collaborator

When this PR is ready for review, ping me.

@Bruce-x-1997 Bruce-x-1997 force-pushed the bruce_optimize_w4a8_kernel_opt branch from 6cb5b59 to 5ebd9b6 Compare November 11, 2025 03:23
@Bruce-x-1997
Contributor Author

@FlamingoPg hello, could you help trigger CI again? I don't see any w4afp8-related errors in the failed cases. Thanks.

@Bruce-x-1997 Bruce-x-1997 force-pushed the bruce_optimize_w4a8_kernel_opt branch 2 times, most recently from fbf9b89 to c6462f7 Compare November 13, 2025 10:10
@Bruce-x-1997
Contributor Author

@FlamingoPg hello, could you help trigger it again? Thanks. I see the failed case is not related to my changes.

@yuhyao
Contributor

yuhyao commented Nov 14, 2025

@yuhyao could you please help review this one? Thanks.

Sorry for the late reply. I will help review this PR today.

@Bruce-x-1997
Contributor Author

@FlamingoPg hello, could you help trigger it again? Thanks. I see the failed case is not related to my changes.

hello, could you help trigger CI again? I see all the failed cases are not related to my changes. Thanks.

@yuhyao
Contributor

yuhyao commented Nov 14, 2025

@Bruce-x-1997 Hi, please check the comments.
Also, should we change "bugfix" in the title to "perf"?

@Bruce-x-1997
Contributor Author

@Bruce-x-1997 Hi, please check the comments. Also, should we change "bugfix" in the title to "perf"?

ok, thanks for your comment, I will fix it asap.

@Bruce-x-1997 Bruce-x-1997 changed the title [bugfix]optimize w4afp8 kernel on deepseek-v3-0324 [perf]optimize w4afp8 kernel on deepseek-v3-0324 Nov 17, 2025
Signed-off-by: bruce.xu <bruce.x@gmicloud.ai>
@Bruce-x-1997 Bruce-x-1997 force-pushed the bruce_optimize_w4a8_kernel_opt branch from c6462f7 to 7915d93 Compare November 17, 2025 04:44
@Bruce-x-1997 Bruce-x-1997 requested a review from AniZpZ as a code owner November 17, 2025 04:44
@Bruce-x-1997
Contributor Author

@yuhyao hello, I cannot see any comments from you)

@yuhyao
Contributor

yuhyao commented Nov 21, 2025

Changes look good to me now. Here are some suggestions for improving the PR description:

  • For completeness, it would be helpful to include the MMLU accuracy results (sglang already provides a benchmark script for this).
  • For performance, please consider adding the launch command so others can see the parallelism strategy you’re using. I also recommend including the output printed by sglang.bench_serving for easier comparison.

done

Thanks for adding the information to the description.
Regarding the bench_serving results, it would be helpful to include the full output of the script (including both TTFT and ITL).
Also, could you clarify the test environment, such as the GPU model and the number of GPUs used?

@Bruce-x-1997
Contributor Author

@FlamingoPg hello, could you help trigger CI again? I see all the failures are not related to my changes; do they occur in other CI runs as well? Thanks.

@Bruce-x-1997
Contributor Author

@FlamingoPg could you help trigger CI again? Thanks.

@Bruce-x-1997
Contributor Author

@FlamingoPg hello, could you help trigger CI again? Thanks. Or is there any way I can trigger CI myself?

@AniZpZ
Collaborator

AniZpZ commented Nov 27, 2025

@FlamingoPg hello, could you help trigger CI again? Thanks. Or is there any way I can trigger CI myself?

You can try re-running the failed CI jobs.

@AniZpZ AniZpZ self-assigned this Nov 27, 2025
@Bruce-x-1997
Contributor Author

Bruce-x-1997 commented Nov 28, 2025

@FlamingoPg hello, could you help trigger CI again? Thanks. Or is there any way I can trigger CI myself?

You can try re-running the failed CI jobs.

hello, how can I re-run the failed CI jobs @AniZpZ? I don't find any button to re-run my failed jobs, could you tell me?

@slin1237
Collaborator

/tag-and-rerun-ci

1 similar comment
@BBuf
Collaborator

BBuf commented Dec 11, 2025

/tag-and-rerun-ci

@BBuf BBuf added run-ci and removed run-ci labels Dec 11, 2025
@BBuf
Collaborator

BBuf commented Dec 15, 2025

/tag-and-rerun-ci

@BBuf BBuf removed the run-ci label Dec 15, 2025

void compute_expert_offsets_w4a8(
cudaStream_t stream, const int32_t* problem_sizes1, int32_t* expert_offsets, int n, int stride = 1, int off = 0) {
#define compute_expert_offsets_w4a8_call(BLOCK_SIZE) \
Collaborator


Suggested change:

#define compute_expert_offsets_w4a8_call(BLOCK_SIZE) ...
...
#undef compute_expert_offsets_w4a8_call

Use #undef to clean up the macro definition inside the function.
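For illustration, a minimal sketch of the suggested pattern (placeholder kernel body and simplified signature, not the actual w4a8_moe_data.cu code):

```cuda
// The function-local #define is paired with an #undef so the helper macro
// cannot leak past the function.
#include <cstdint>
#include <cuda_runtime.h>

template <int BLOCK_SIZE>
__global__ void expert_offsets_kernel(const int32_t* problem_sizes, int32_t* offsets, int n) {
  // Placeholder body; a real kernel would compute a prefix sum over problem_sizes.
  if (threadIdx.x == 0 && n > 0) offsets[0] = 0;
}

void compute_expert_offsets_w4a8(
    cudaStream_t stream, const int32_t* problem_sizes1, int32_t* expert_offsets, int n) {
#define compute_expert_offsets_w4a8_call(BLOCK_SIZE) \
  expert_offsets_kernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(problem_sizes1, expert_offsets, n)

  if (n <= 256) {
    compute_expert_offsets_w4a8_call(256);
  } else {
    compute_expert_offsets_w4a8_call(512);
  }

#undef compute_expert_offsets_w4a8_call  // macro scope ends with the function
}
```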

Contributor Author


ok)

Contributor Author


done

@BBuf
Collaborator

BBuf commented Dec 15, 2025

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Dec 17, 2025

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Dec 17, 2025

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Dec 18, 2025

@BBuf BBuf merged commit 793c96c into sgl-project:main Dec 18, 2025
132 of 150 checks passed
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 23, 2025
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026