[Performance] Reduce DeepGEMM N dim restriction from 128 to 64 multiplier #28687

Merged
vllm-bot merged 6 commits into main from relax_deepgemm
Nov 19, 2025
Conversation

@alexm-redhat
Collaborator

@alexm-redhat alexm-redhat commented Nov 13, 2025

The W8A8BlockFp8LinearOp class uses the should_use_deepgemm_for_fp8_linear() function to determine which GEMMs to run with the DeepGEMM fp8_gemm_nt kernel. Originally it checked that the N dim is divisible by 128; however, after inspecting the DeepGEMM tests, I found that 64 works as well. This allows fused_qkv_a_proj to run with the DeepGEMM kernel (rather than CUTLASS), which is ~1.83X faster (based on profiles).

With cutlass: (profiler screenshot)

With deepgemm: (profiler screenshot)

  • lm_eval correctness verification shows no accuracy loss.
  • e2e TPOT improves from 28.32ms to 27.72ms, about 2.2%.
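The merged gate can be sketched in plain Python (a hedged simplification: the real should_use_deepgemm_for_fp8_linear() in vllm/utils/deep_gemm.py operates on torch dtypes and tensors and also checks DeepGEMM availability; only the shape logic discussed in this PR is mirrored here):

```python
def should_use_deepgemm_sketch(
    out_dtype: str,
    weight_shape: tuple,
    supports_deep_gemm: bool,
) -> bool:
    """Simplified paraphrase of the dispatch gate described above."""
    n, k = weight_shape  # weight is laid out as (N, K)
    return (
        supports_deep_gemm
        and out_dtype == "bfloat16"
        # Relaxed by this PR: N must be a multiple of 64 (previously 128).
        and n % 64 == 0
        # K keeps its multiple-of-128 requirement.
        and k % 128 == 0
    )

# fused_qkv_a_proj in DeepSeek-R1 has N = 2112 = 64 * 33, which is not a
# multiple of 128, so it now dispatches to DeepGEMM instead of CUTLASS.
print(should_use_deepgemm_sketch("bfloat16", (2112, 7168), True))  # True
print(should_use_deepgemm_sketch("bfloat16", (2112, 7232), True))  # False: K % 128 != 0
```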

@mergify
Contributor

mergify bot commented Nov 13, 2025

Documentation preview: https://vllm--28687.org.readthedocs.build/en/28687/

@mergify mergify bot added the documentation Improvements or additions to documentation label Nov 13, 2025
@alexm-redhat alexm-redhat self-assigned this Nov 13, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment thread vllm/utils/deep_gemm.py Outdated
Comment on lines +343 to +351
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
    return (
-       supports_deep_gemm
-       and output_dtype == torch.bfloat16
-       and weight.shape[0] % 128 == 0
-       and weight.shape[1] % 128 == 0
+       supports_deep_gemm and output_dtype == torch.bfloat16
+       # TODO [alexm-redhat]: Verify in more detail why this
+       # restriction was here in the first place for fp8_gemm_nt
+       # function of deepgemm.
+       # and weight.shape[0] % 128 == 0
+       # and weight.shape[1] % 128 == 0

P1: Guard DeepGEMM call for non-aligned FP8 weight shapes

Removing the weight.shape[0] % 128 and weight.shape[1] % 128 checks means should_use_deepgemm_for_fp8_linear now returns True for any weight size as long as DeepGEMM is available. The FP8 warm-up code still treats DeepGEMM as only valid when both dimensions are multiples of get_mk_alignment_for_contiguous_layout()[0] (currently 128), and the DeepGEMM kernels require that block alignment to operate correctly. With this change, layers whose weight shapes are not aligned (e.g. 96x64) will skip the alignment check, call torch.ops.vllm.fp8_gemm_nt_op and can raise a runtime error or access misaligned buffers. Please keep the divisibility guard or replace it with a query to the DeepGEMM alignment API before dispatching.

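What the review asks for could look roughly like this (a sketch; `query_alignment` is a stand-in for DeepGEMM's get_mk_alignment_for_contiguous_layout() named in the comment, whose exact return shape is assumed here):

```python
def query_alignment() -> list:
    # Placeholder for the DeepGEMM alignment API; assumed to return a list
    # whose first entry is the required block alignment (currently 128).
    return [128]

def weight_shape_ok(weight_shape: tuple) -> bool:
    # Guard dispatch on the queried alignment instead of a hard-coded 128.
    align = query_alignment()[0]
    n, k = weight_shape
    return n % align == 0 and k % align == 0

print(weight_shape_ok((96, 64)))      # False -> fall back, don't call fp8_gemm_nt
print(weight_shape_ok((2048, 7168)))  # True  -> safe to dispatch to DeepGEMM
```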

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to improve performance by removing a dimension restriction in the should_use_deepgemm_for_fp8_linear function. While the performance gains are significant, this change introduces a critical risk of correctness issues. The removed alignment check is likely a crucial safeguard for the deepgemm kernel. I've left a detailed comment explaining the risk and suggesting that the check be restored until its removal can be proven safe through more thorough verification and testing.

Comment thread vllm/utils/deep_gemm.py Outdated
Comment on lines +346 to +351
        supports_deep_gemm and output_dtype == torch.bfloat16
        # TODO [alexm-redhat]: Verify in more detail why this
        # restriction was here in the first place for fp8_gemm_nt
        # function of deepgemm.
        # and weight.shape[0] % 128 == 0
        # and weight.shape[1] % 128 == 0
Contributor


Severity: critical

This change introduces a high risk of correctness issues. The removed alignment check (% 128 == 0) is a common requirement for high-performance GEMM kernels to ensure correct memory access and computation, especially when using hardware features like Tensor Cores. Without this check, the deepgemm kernel may produce incorrect results or crash for inputs with dimensions that are not multiples of 128.

Evidence from within the repository supports this concern:

  • The MoE deep_gemm integration in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py explicitly checks for 128-byte alignment in _valid_deep_gemm_shape.
  • The per_block_cast_to_fp8 utility function, taken from the DeepGEMM library, pads tensors to be multiples of 128.

Given the potential for silent data corruption, removing this safeguard is a critical issue. I recommend restoring the original alignment checks to prevent potential correctness bugs.

Suggested change:

-       supports_deep_gemm and output_dtype == torch.bfloat16
-       # TODO [alexm-redhat]: Verify in more detail why this
-       # restriction was here in the first place for fp8_gemm_nt
-       # function of deepgemm.
-       # and weight.shape[0] % 128 == 0
-       # and weight.shape[1] % 128 == 0
+       supports_deep_gemm
+       and output_dtype == torch.bfloat16
+       and weight.shape[0] % 128 == 0
+       and weight.shape[1] % 128 == 0

@mgoin
Member

mgoin commented Nov 13, 2025

We should check this on Hopper and Blackwell, and make sure to expand the kernel unit test to check for edge cases

@ElizaWszola
Contributor

If we successfully remove the size restriction, it would be good to also remove the should_use_deepgemm_for_fp8_linear condition from W8A8BlockFp8LinearOp.apply() and instead decide whether to run DeepGEMM inside _dispatch_w8a8_blockscale_op()

Comment thread examples/offline_inference/basic/basic.py Outdated
Member

@yewentao256 yewentao256 left a comment


We should be more careful about removing this.
I tried to remove it before and saw a significant performance loss; it seemed to be in the dp=8 case of R1.

@alexm-redhat
Collaborator Author

alexm-redhat commented Nov 17, 2025

@yewentao256 I have verified dp==8 DSR1 (on H200) and the performance is actually better with this PR, about 1.5% for TPOT.

@alexm-redhat
Collaborator Author

Also, I did a manual inspection of the DeepGEMM code and did not find any divide-by-128 restriction for the GEMM we use (fp8_gemm_nt). The code internally aligns the K dimension to a multiple of 128, but has no special logic for M and N.
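The K-side behavior described here (internal alignment to a multiple of 128) amounts to a simple round-up; a minimal sketch of the arithmetic (illustrative only — DeepGEMM performs this internally, this is not vLLM code):

```python
def align_up(x: int, multiple: int = 128) -> int:
    # Round x up to the nearest multiple of `multiple` using ceiling
    # division; this models the K-dimension padding described above.
    return -(-x // multiple) * multiple

print(align_up(7168))  # 7168: already a multiple of 128, no padding
print(align_up(2050))  # 2176
print(align_up(64))    # 128
```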

@yewentao256
Member

@alexm-redhat Sounds good, could you also test on Blackwell?

@alexm-redhat
Collaborator Author

@yewentao256 No problem, checking Blackwell

@alexm-redhat
Collaborator Author

@yewentao256 @robertgshaw2-redhat Verified performance for DSR1 on 8xB200 system that has DeepGEMM installed. For TP==8 and DP==8 (+EP), performance with the PR or without is the same.

@yewentao256
Member

> @yewentao256 @robertgshaw2-redhat Verified performance for DSR1 on 8xB200 system that has DeepGEMM installed. For TP==8 and DP==8 (+EP), performance with the PR or without is the same.

This is strange, could you show the command you use and the corresponding result?

@alexm-redhat
Collaborator Author

@yewentao256 commands used:

For client:
vllm bench serve --port 8123 --model deepseek-ai/DeepSeek-R1-0528 --dataset-name random --max-concurrency 32 --random-input-len 4 --random-output-len 1024 --num-prompts 80 --seed 1763417423 --percentile-metrics ttft,tpot,itl,e2el --metric-percentiles 90,95,99 --ignore-eos --trust-remote-code

For server:

For TP
vllm serve deepseek-ai/DeepSeek-R1-0528 --port 8123 --no-enable-prefix-caching --max-num-seqs 128 --max-model-len 16384 --tensor-parallel-size 8
For DP
vllm serve deepseek-ai/DeepSeek-R1-0528 --port 8123 --no-enable-prefix-caching --max-num-seqs 128 --max-model-len 16384 --enable-expert-parallel --data-parallel-size 8

Results: for TP, TPOT is 38.12ms with the PR vs 38.89ms without; for DP, TPOT is 37.9ms with the PR vs 37.6ms without.

Member

@yewentao256 yewentao256 left a comment


vllm bench serve --model deepseek-ai/DeepSeek-R1 --dataset-name random --host 127.0.0.1 --random-input-len 4 --random-output-len 1024 --request-rate inf --num-prompts 1024 --port 9256

For larger batch sizes, performance doesn't seem to be hurt now

This PR

============ Serving Benchmark Result ============
Successful requests:                     1024      
Failed requests:                         0         
Benchmark duration (s):                  80.35     
Total input tokens:                      3072      
Total generated tokens:                  1048576   
Request throughput (req/s):              12.74     
Output token throughput (tok/s):         13050.34  
Peak output token throughput (tok/s):    14209.00  
Peak concurrent requests:                1024.00   
Total Token throughput (tok/s):          13088.57  
---------------Time to First Token----------------
Mean TTFT (ms):                          1401.40   
Median TTFT (ms):                        1430.52   
P99 TTFT (ms):                           1462.59   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          77.08     
Median TPOT (ms):                        77.07     
P99 TPOT (ms):                           77.21     
---------------Inter-token Latency----------------
Mean ITL (ms):                           77.08     
Median ITL (ms):                         76.91     
P99 ITL (ms):                            116.87    
==================================================

Without

============ Serving Benchmark Result ============
Successful requests:                     1024      
Failed requests:                         0         
Benchmark duration (s):                  80.23     
Total input tokens:                      3072      
Total generated tokens:                  1048576   
Request throughput (req/s):              12.76     
Output token throughput (tok/s):         13069.67  
Peak output token throughput (tok/s):    14143.00  
Peak concurrent requests:                1024.00   
Total Token throughput (tok/s):          13107.96  
---------------Time to First Token----------------
Mean TTFT (ms):                          1141.24   
Median TTFT (ms):                        1143.51   
P99 TTFT (ms):                           1225.30   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          77.23     
Median TPOT (ms):                        77.24     
P99 TPOT (ms):                           77.35     
---------------Inter-token Latency----------------
Mean ITL (ms):                           77.23     
Median ITL (ms):                         77.24     
P99 ITL (ms):                            118.71    
==================================================

And lm_eval looks good.

|Tasks|Version|     Filter     |n-shot|  Metric   |Value |Stderr|
|-----|------:|----------------|-----:|-----------|-----:|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|0.9507|±0.006|
|     |       |strict-match    |     5|exact_match|0.9492|±0.006|

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 18, 2025
@alexm-redhat
Collaborator Author

@yewentao256 thanks for checking

@alexm-redhat alexm-redhat enabled auto-merge (squash) November 18, 2025 18:15
Member

@mgoin mgoin left a comment


@alexm-redhat Can we please add unit test cases before landing? See where we currently skip many shapes

def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    # only aligned sizes
    if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")

And I think the current N and K in this file don't test all shapes of interest. If this test isn't currently running due to lack of Hopper in CI, we should add it to the Blackwell runner
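A sketch of the kind of edge-case coverage being asked for here (the shape grids are illustrative, not the repo's actual test parameters; the predicate mirrors the skip condition quoted above):

```python
from itertools import product

# Illustrative shape grids; NS deliberately includes values that are
# multiples of 64 but not of 128 (e.g. 2112) to exercise the relaxed check.
MS = [4, 64, 128]
NS = [64, 2112, 4096]
KS = [128, 7168]

def is_valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
    # Mirrors the unit test's skip condition: only aligned sizes run.
    return M % 4 == 0 and K % 128 == 0 and N % 64 == 0

valid = [s for s in product(MS, NS, KS) if is_valid_deep_gemm_shape(*s)]
print(len(valid))                # 18: every combination above is aligned
print((4, 2112, 128) in valid)   # True: N divisible by 64 but not 128
print(is_valid_deep_gemm_shape(4, 2144, 128))  # False: 2144 % 64 != 0
```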

@mgoin
Member

mgoin commented Nov 18, 2025

> If we successfully remove the size restriction, it would be good if we could also remove the should_use_deepgemm_for_fp8_linear condition from W8A8BlockFp8LinearOp.apply() and decide whether we want to run deepgemm with _dispatch_w8a8_blockscale_op()

I also agree with the suggestion from @ElizaWszola to simplify W8A8BlockFp8LinearOp

@alexm-redhat
Collaborator Author

@mgoin I was not aware of this test; I'll enable the shapes and see how to run it.

… of 128

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
@alexm-redhat
Collaborator Author

@mgoin After enabling the test, I found that N/K are restricted to the sizes the test is using, so we do need to check that N % 64 == 0 and K % 128 == 0. In general it worked because the fused_qkv N dim is 2112, which is divisible by 64. Thanks for pointing out this test.

@alexm-redhat
Collaborator Author

Fixed the check to N % 64 == 0 and K % 128 == 0

@mgoin
Member

mgoin commented Nov 19, 2025

Thanks Alex, can you update the title and description to the final state?

@alexm-redhat alexm-redhat changed the title [Performance] Remove deepgemm N/K dim restriction of being a multiple of 128 [Performance] Reduce DeepGEMM N dim restriction from 128 to 64 multiplier Nov 19, 2025
@alexm-redhat
Collaborator Author

@mgoin Updated title and description, thanks for the detailed review!

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin added performance Performance-related issues deepseek Related to DeepSeek models labels Nov 19, 2025
@vllm-bot vllm-bot merged commit 3aaa94a into main Nov 19, 2025
5 of 8 checks passed
@vllm-bot vllm-bot deleted the relax_deepgemm branch November 19, 2025 23:47
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…lier (vllm-project#28687)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
…lier (vllm-project#28687)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

Labels

ci/build deepseek Related to DeepSeek models documentation Improvements or additions to documentation performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed
