[Performance] Reduce DeepGEMM N dim restriction from 128 to 64 multiplier #28687
Conversation
Documentation preview: https://vllm--28687.org.readthedocs.build/en/28687/
💡 Codex Review
Here are some automated review suggestions for this pull request.
```diff
 if supports_deep_gemm is None:
     supports_deep_gemm = is_deep_gemm_supported()
 return (
-    supports_deep_gemm
-    and output_dtype == torch.bfloat16
-    and weight.shape[0] % 128 == 0
-    and weight.shape[1] % 128 == 0
+    supports_deep_gemm and output_dtype == torch.bfloat16
+    # TODO [alexm-redhat]: Verify in more detail why this
+    # restriction was here in the first place for fp8_gemm_nt
+    # function of deepgemm.
+    # and weight.shape[0] % 128 == 0
+    # and weight.shape[1] % 128 == 0
```
Guard deepgemm call for non-aligned FP8 weight shapes
Removing the weight.shape[0] % 128 and weight.shape[1] % 128 checks means should_use_deepgemm_for_fp8_linear now returns True for any weight size as long as DeepGEMM is available. The FP8 warm-up code still treats DeepGEMM as only valid when both dimensions are multiples of get_mk_alignment_for_contiguous_layout()[0] (currently 128), and the DeepGEMM kernels require that block alignment to operate correctly. With this change, layers whose weight shapes are not aligned (e.g. 96x64) will skip the alignment check, call torch.ops.vllm.fp8_gemm_nt_op, and can raise a runtime error or access misaligned buffers. Please keep the divisibility guard, or replace it with a query to the DeepGEMM alignment API before dispatching.
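The guard the reviewer asks to keep can be sketched as a small helper. This is illustrative only: in practice the alignment value would come from get_mk_alignment_for_contiguous_layout() rather than the hard-coded default, and the helper name is an assumption, not vLLM's actual API.

```python
def weight_is_deepgemm_aligned(shape: tuple[int, int], alignment: int = 128) -> bool:
    """Sketch of a shape guard: both weight dims must be multiples of the
    DeepGEMM block alignment (assumed 128 here) before dispatching."""
    n, k = shape
    return n % alignment == 0 and k % alignment == 0


# An aligned shape passes; the 96x64 example from the review does not.
assert weight_is_deepgemm_aligned((256, 512))
assert not weight_is_deepgemm_aligned((96, 64))
```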
Code Review
This pull request aims to improve performance by removing a dimension restriction in the should_use_deepgemm_for_fp8_linear function. While the performance gains are significant, this change introduces a critical risk of correctness issues. The removed alignment check is likely a crucial safeguard for the deepgemm kernel. I've left a detailed comment explaining the risk and suggesting that the check be restored until its removal can be proven safe through more thorough verification and testing.
```python
supports_deep_gemm and output_dtype == torch.bfloat16
# TODO [alexm-redhat]: Verify in more detail why this
# restriction was here in the first place for fp8_gemm_nt
# function of deepgemm.
# and weight.shape[0] % 128 == 0
# and weight.shape[1] % 128 == 0
```
This change introduces a high risk of correctness issues. The removed alignment check (% 128 == 0) is a common requirement for high-performance GEMM kernels to ensure correct memory access and computation, especially when using hardware features like Tensor Cores. Without this check, the deepgemm kernel may produce incorrect results or crash for inputs with dimensions that are not multiples of 128.
Evidence from within the repository supports this concern:
- The MoE deep_gemm integration in vllm/model_executor/layers/fused_moe/deep_gemm_moe.py explicitly checks for 128-byte alignment in _valid_deep_gemm_shape.
- The per_block_cast_to_fp8 utility function, taken from the DeepGEMM library, pads tensors to be multiples of 128.
Given the potential for silent data corruption, removing this safeguard is a critical issue. I recommend restoring the original alignment checks to prevent potential correctness bugs.
Suggested change:
```diff
-supports_deep_gemm and output_dtype == torch.bfloat16
-# TODO [alexm-redhat]: Verify in more detail why this
-# restriction was here in the first place for fp8_gemm_nt
-# function of deepgemm.
-# and weight.shape[0] % 128 == 0
-# and weight.shape[1] % 128 == 0
+supports_deep_gemm
+and output_dtype == torch.bfloat16
+and weight.shape[0] % 128 == 0
+and weight.shape[1] % 128 == 0
```
We should check this on Hopper and Blackwell, and make sure to expand the kernel unit test to check for edge cases.

If we successfully remove the size restriction, it would be good if we could also remove the …
Force-pushed: 3a0125d → eeff31d
yewentao256
left a comment
We should be more careful about removing this.
I tried removing it before and saw a significant performance loss; it seemed to be the dp=8 case of R1.
@yewentao256 I have verified dp==8 DSR1 (on H200) and the performance is actually better with this PR, about 1.5% for TPOT.
Force-pushed: eeff31d → 1db7125
Also, I did a manual inspection of the DeepGEMM code and did not find any 128-division restriction for the GEMM we use (fp8_gemm_nt). The code internally aligns the K dimension to a multiple of 128, but has no special logic for M and N.
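The internal K-alignment described here can be sketched in Python for intuition. DeepGEMM's actual implementation is CUDA/C++; align_up below is a hypothetical helper used only to illustrate the round-up-to-128 behavior, not DeepGEMM's API.

```python
def align_up(x: int, multiple: int = 128) -> int:
    """Round x up to the next multiple (sketch of DeepGEMM-style K padding)."""
    return (x + multiple - 1) // multiple * multiple


# K already on a 128 boundary is untouched; otherwise it is padded up.
assert align_up(4096) == 4096
assert align_up(4097) == 4224
```

Because this padding happens internally for K, only M and N alignment are visible to the caller, which is why relaxing the caller-side N check is plausible.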
@alexm-redhat Sounds good, could you also test on Blackwell?
@yewentao256 no problem, checking Blackwell.
@yewentao256 @robertgshaw2-redhat Verified performance for DSR1 on an 8xB200 system that has DeepGEMM installed. For TP==8 and DP==8 (+EP), performance with or without the PR is the same.

This is strange, could you show the command you used and the corresponding result?
@yewentao256 commands used: For client: … For server: … Results: for TP, TPOT is 38.12 ms with the PR vs 38.89 ms without; for DP, TPOT is 37.9 ms with the PR vs 37.6 ms without.
yewentao256
left a comment
vllm bench serve --model deepseek-ai/DeepSeek-R1 --dataset-name random --host 127.0.0.1 --random-input-len 4 --random-output-len 1024 --request-rate inf --num-prompts 1024 --port 9256
For larger batch size, seems not hurt in performance now
This PR
============ Serving Benchmark Result ============
Successful requests: 1024
Failed requests: 0
Benchmark duration (s): 80.35
Total input tokens: 3072
Total generated tokens: 1048576
Request throughput (req/s): 12.74
Output token throughput (tok/s): 13050.34
Peak output token throughput (tok/s): 14209.00
Peak concurrent requests: 1024.00
Total Token throughput (tok/s): 13088.57
---------------Time to First Token----------------
Mean TTFT (ms): 1401.40
Median TTFT (ms): 1430.52
P99 TTFT (ms): 1462.59
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 77.08
Median TPOT (ms): 77.07
P99 TPOT (ms): 77.21
---------------Inter-token Latency----------------
Mean ITL (ms): 77.08
Median ITL (ms): 76.91
P99 ITL (ms): 116.87
==================================================

Without
============ Serving Benchmark Result ============
Successful requests: 1024
Failed requests: 0
Benchmark duration (s): 80.23
Total input tokens: 3072
Total generated tokens: 1048576
Request throughput (req/s): 12.76
Output token throughput (tok/s): 13069.67
Peak output token throughput (tok/s): 14143.00
Peak concurrent requests: 1024.00
Total Token throughput (tok/s): 13107.96
---------------Time to First Token----------------
Mean TTFT (ms): 1141.24
Median TTFT (ms): 1143.51
P99 TTFT (ms): 1225.30
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 77.23
Median TPOT (ms): 77.24
P99 TPOT (ms): 77.35
---------------Inter-token Latency----------------
Mean ITL (ms): 77.23
Median ITL (ms): 77.24
P99 ITL (ms): 118.71
==================================================

And lm_eval looks good.
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9507|± | 0.006|
| | |strict-match | 5|exact_match|↑ |0.9492|± | 0.006|
@yewentao256 thanks for checking
@alexm-redhat Can we please add unit test cases before landing? See where we currently skip many shapes
vllm/tests/kernels/quantization/test_block_fp8.py
Lines 159 to 162 in c3e2978
And I think the current N and K in this file don't test all shapes of interest. If this test isn't currently running due to lack of Hopper in CI, we should add it to the Blackwell runner
I also agree with the suggestion from @ElizaWszola to simplify W8A8BlockFp8LinearOp.
@mgoin I was not aware of this test, will enable the shapes and see how to run it.
… of 128
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Force-pushed: a858cd5 → d182444
@mgoin after enabling the test I found that N/K are restricted to the sizes the test is using. So we do need to check that N % 64 == 0 and K % 128 == 0. In general, it worked because the fused_qkv N dim is 2112, which is divisible by 64. Thanks for pointing out this test.
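The arithmetic behind that observation is easy to verify: 2112 clears the relaxed 64-multiple check but would have failed the original 128-multiple check.

```python
# fused_qkv N dim cited in this PR
n = 2112

assert n % 64 == 0    # passes the relaxed N check (2112 = 33 * 64)
assert n % 128 != 0   # would have failed the original 128 check
```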
Fixed the check to N % 64 and K % 128.
Thanks Alex, can you update the title and description to the final state?
@mgoin Updated title and description, thanks for the detailed review!
Signed-off-by: mgoin <mgoin64@gmail.com>
…lier (vllm-project#28687) Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
The W8A8BlockFp8LinearOp class uses the should_use_deepgemm_for_fp8_linear() function to determine when to run the DeepGEMM fp8_gemm_nt kernel. Originally it checked that the N dim is divisible by 128; however, after some inspection of the DeepGEMM tests, I found that 64 works as well. This allows fused_qkv_a_proj to run with the DeepGEMM kernel (instead of cutlass) and is ~1.83x faster (after inspecting profiles).
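A minimal sketch of the final shape condition described in this PR (N % 64 == 0 and K % 128 == 0). The real check lives inside should_use_deepgemm_for_fp8_linear in vLLM and also gates on DeepGEMM support and bfloat16 output; deepgemm_shape_ok here is a stand-alone illustration of the shape logic only.

```python
def deepgemm_shape_ok(n: int, k: int) -> bool:
    """Shape condition from this PR: N must be a multiple of 64
    (relaxed from 128) and K a multiple of 128."""
    return n % 64 == 0 and k % 128 == 0


# A fused_qkv-like shape (N=2112) now qualifies for the DeepGEMM path,
# while an unaligned K still falls back.
assert deepgemm_shape_ok(2112, 7168)
assert not deepgemm_shape_ok(2112, 100)
```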
With cutlass:

With deepgemm:
