[2/2] support dp mla #5001
Conversation
The performance results of MLA from early March. The latest performance may have changed, but the pattern should be similar.

Could you add DP attention to the benchmarks?

Attempted to load this with AWQ on 16x3090 and got an error. I'm possibly doing something wrong with the commands, or maybe it's just not available for AWQ yet. Thanks for any help.
You can remove this code in deepseek_v2.py. However, the custom all-to-all we currently implement supports a maximum of 8 devices, so the relevant kernel implementation would need to be changed for 16, and other code that depends on the number of devices may also need to be adapted.

```python
if self_attn.enable_dp_mla:
    # Use kv_b_proj's space to avoid doubling memory.
    # 512, 128/tp*(128+128)
    weight_data = self_attn.kv_b_proj.weight.data
    kv_stack = [
        w_kc.to(torch.bfloat16),
        w_vc.view(*w_kc.shape).to(torch.bfloat16),
    ]
    new_weight = torch.cat(kv_stack).to(w_kc.dtype).contiguous()
    weight_data.copy_(new_weight.view(weight_data.shape))
    weight_data = weight_data.view(2, *w_kc.shape)
    w_kc = weight_data[0].view(*w_kc.shape)
    w_vc = weight_data[1].view(*w_vc.shape)
```
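For context on where `w_kc` and `w_vc` come from: they are the two halves of `kv_b_proj`'s weight after the MLA weight-absorption step. A minimal sketch of that split, with DeepSeek-V2-style per-rank shapes assumed (not exact values from this PR):

```python
import torch

# Assumed per-TP-rank dimensions (DeepSeek-V2-style, tp=8); illustrative only.
num_heads = 16            # 128 total heads / tp=8
qk_nope_head_dim = 128
v_head_dim = 128
kv_lora_rank = 512

# kv_b_proj maps the 512-dim compressed latent to per-head K(nope) and V.
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
w = w.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
w_kc, w_vc = w.split([qk_nope_head_dim, v_head_dim], dim=1)

# w_kc (heads, 128, 512) is absorbed into the query path and w_vc
# (heads, 128, 512) into the output path. Both halves hold the same number of
# elements, which is why the snippet above can view w_vc as w_kc's shape and
# stack the pair back into kv_b_proj's storage.
```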
On 8*H20 (96GB), weight mem usage = 87.19 GB with --dp-size 4 --enable-dp-attention.
Could you provide the specific error in detail?
For example: it may depend on an older version of repo-flash-attention rather than the latest one, such as GIT_TAG bf1e4ce51fc85370b9489976e8ccbe72eb71fefe.
This PR has been rebased onto the newest main.
Motivation
Based on the dp_mla_kernel PR #5000.
Description:
On an 8*H20 (96GB) node, weight memory usage is 87.19 GB with --dp-size 4 --enable-dp-attention, which does not leave enough memory. This optimization is similar to data-parallel attention, but it applies to the MLA core instead of the entire attention module. Compared with data-parallel attention, it does not additionally increase the memory occupied by weights. It allows for a significant reduction in KV cache size and enables larger batch sizes, with only 1-3 ms of additional decode latency.
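As a rough single-process illustration of the idea (a sketch under assumed shapes, not this PR's actual kernel): the projections stay head-sharded under TP, and an all-to-all-style exchange re-shards the MLA core input by tokens, so each rank's KV cache only has to hold the compressed latents of its own tokens.

```python
import torch

# Illustrative sizes (assumptions, not this PR's exact configuration).
world = 8            # devices: TP for the linear layers, DP for the MLA core
heads = 128          # total MLA query heads (DeepSeek-style)
tokens = 16          # tokens in this step, sharded across DP ranks for the core
dim = 512 + 64       # compressed KV latent dim + decoupled RoPE dim

# TP layout: every "rank" sees all tokens but only heads // world query heads.
q_tp = [torch.randn(tokens, heads // world, dim) for _ in range(world)]

# Simulated all-to-all: rank r receives, from every source rank, the head slice
# for r's token shard, ending up with all heads for tokens // world tokens.
q_dp = [
    torch.cat([q.chunk(world, dim=0)[r] for q in q_tp], dim=1)
    for r in range(world)
]
assert q_dp[0].shape == (tokens // world, heads, dim)

# After the core attention, a mirror all-to-all restores the head-sharded (TP)
# layout so the output projection can stay tensor-parallel.
```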
On an 8×H20 (96GB) node with data-parallel MLA enabled, we achieved up to a 1.85x (dp=4) to 2.34x (dp=8) decoding throughput improvement over the previous version. In addition, KV cache capacity increased by 3.3x (dp=4) and 6.6x (dp=8).
On 8*H20 (96GB), with an input length of 4000, an output length of 1500, and 128 concurrent requests, per-card decode output throughput can reach 266 tokens/s.
GSM8K:
# dp=8
Accuracy: 0.955  Invalid: 0.000  Latency: 165.833 s  Output throughput: 781.248 token/s
# tp=8
Accuracy: 0.954  Invalid: 0.000  Latency: 177.669 s  Output throughput: 719.521 token/s

Usage:
On an 8×H20 (96GB) node, on top of --attention-backend flashinfer or --attention-backend fa3, remove --enable-dp-lm-head and add the following parameters:
# dp=4: --max-running-requests 128 --mem-fraction-static 0.94 --chunked-prefill-size 8192 --enable-dp-mla --dp-size 4
# dp=8: --max-running-requests 128 --mem-fraction-static 0.94 --chunked-prefill-size 8192 --enable-dp-mla --dp-size 8
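For concreteness, a complete dp=4 launch command might look like the following sketch; the model path is a placeholder assumption, not taken from this PR:

```
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1 \
    --tp 8 --trust-remote-code --attention-backend fa3 \
    --enable-dp-mla --dp-size 4 \
    --max-running-requests 128 --mem-fraction-static 0.94 --chunked-prefill-size 8192
```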
Restrictions:
We set the switch flashinfer_mla_disable_ragged to True when --enable-dp-mla is on, to avoid the following issue: while testing the model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct with sglang.bench_one_batch_server, we found that after the warmup process there are NaN values in the KV cache, which lead to NaN values in the results of subsequent requests. After troubleshooting, it was determined that the result of the flashinfer-mla ragged process is NaN. The reproduction parameters are:

Background:
For DeepSeek R1 on a single machine with 8*H20 (96GB), we aimed to deploy with enable_dp_attention. We found:
TODO:
- Accuracy and performance testing and optimization for various scenarios.
- Performance optimization of the custom all-to-all.

Memory:
Performance Comparison:
python -m sglang.bench_one_batch_server --model None --base-url http://localhost:8188 --batch-size {bs} --input-len {input_seq} --output-len 1024

Modifications
Use DP for the MLA core computation and TP for everything else, keeping the memory usage and computation time of the linear layers in attention unchanged.
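To make the memory claim concrete, here is a back-of-the-envelope sketch of the per-token MLA KV cache size, using DeepSeek-R1-style dimensions as assumptions:

```python
# Assumed DeepSeek-R1-style dimensions: kv_lora_rank=512, qk_rope_head_dim=64,
# 61 layers, bf16 (2 bytes per element). Not exact values from this PR.
kv_lora_rank, rope_dim, layers, bytes_per_elem = 512, 64, 61, 2
per_token = (kv_lora_rank + rope_dim) * layers * bytes_per_elem
print(f"MLA latent cache: {per_token / 1024:.1f} KiB per token")  # ~68.6 KiB

# With plain TP attention this latent cache is replicated on every rank; with
# DP-MLA each DP rank stores only its own tokens, so total capacity grows with
# dp_size (the PR measures 3.3x at dp=4 and 6.6x at dp=8, below the ideal
# 4x/8x).
```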
Checklist