
DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill#11892

Merged
Fridge003 merged 22 commits into sgl-project:main from YAMY1234:dpsk_mla_opt on Nov 6, 2025

Conversation

@YAMY1234 (Collaborator) commented Oct 21, 2025

Motivation

For DeepSeek-V3.2 models, using MLA (Multi-head Latent Attention) uniformly across all sequence lengths during prefill is suboptimal. For short sequences, the overhead of MLA's compression/decompression and absorbed attention mechanism outweighs its benefits, making standard MHA (Multi-Head Attention) more efficient. This PR selects the attention mechanism adaptively based on sequence length to optimize inference performance across different workloads.

Modifications

This PR implements sequence-length-based adaptive attention mechanism selection in the NSA (Native Sparse Attention) backend:

Core Changes

  1. Add MHA mode detection and processing in nsa_backend.py

    • Implement _forward_standard_mha() method using FlashAttention varlen interface for standard MHA
    • Support both prefix chunk processing and regular token processing
    • Correctly handle LSE (log-sum-exp) return values for DP attention compatibility
  2. Implement intelligent attention mechanism selection in deepseek_v2.py

    • Add adaptive selection logic in handle_attention_nsa() based on sequence length:
      • Decode phase: Always use MLA (avoids per-token decompression overhead)
      • Prefill with seq_len < 2048: Use MHA_CHUNKED_KV (lower computational overhead)
      • Prefill with seq_len >= 2048: Use MLA (absorbed attention mechanism becomes beneficial at scale)
    • Add configurable nsa_seq_len_threshold parameter (default: 2048) for flexible threshold tuning

Note: This threshold is also chosen because of NSA's top-k sparse filtering, which operates within the MLA pathway. Minimal sketches of the selection logic and of the MHA path are given below.
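To make the dispatch concrete, here is a minimal sketch of the kind of selection handle_attention_nsa() performs. It is not the actual sglang code: the enum members, the forward_batch fields, and the helper name are illustrative stand-ins.

```python
from enum import Enum, auto

class AttnForwardMethod(Enum):
    # Illustrative names; the real enum in sglang may differ.
    MLA = auto()
    MHA_CHUNKED_KV = auto()

def choose_attention_method(forward_batch, nsa_seq_len_threshold: int = 2048):
    """Sketch of sequence-length-based attention selection for DeepSeek-V3.2."""
    if forward_batch.forward_mode.is_decode():
        # Decode always uses MLA: re-expanding the compressed KV latent for
        # every generated token would cost more than it saves.
        return AttnForwardMethod.MLA

    # Prefill/extend: compare the longest extend length in the batch to the threshold.
    max_extend_len = max(forward_batch.extend_seq_lens_cpu)
    if max_extend_len < nsa_seq_len_threshold:
        # Short prefill: plain MHA over uncompressed K/V (chunked over any cached
        # prefix) avoids MLA's compression/decompression and absorption overhead.
        return AttnForwardMethod.MHA_CHUNKED_KV

    # Long prefill: absorbed MLA plus NSA's top-k sparse filtering wins at scale.
    return AttnForwardMethod.MLA
```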
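And a rough sketch of what the MHA path boils down to, assuming the upstream flash-attn package's flash_attn_varlen_func (sglang's _forward_standard_mha() wraps its own backend, and prefix-chunk handling is omitted here); return_attn_probs=True is used to also obtain the per-row log-sum-exp that DP attention needs when merging outputs:

```python
import torch
from flash_attn import flash_attn_varlen_func  # assumes flash-attn 2.x is installed

def standard_mha_varlen(q, k, v, seq_lens, softmax_scale):
    """q, k, v: packed [total_tokens, num_heads, head_dim]; seq_lens: per-request lengths."""
    # Cumulative sequence lengths delimit each request inside the packed tensors.
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=q.device)
    cu_seqlens[1:] = torch.cumsum(
        torch.as_tensor(seq_lens, device=q.device, dtype=torch.int32), dim=0
    )
    max_seqlen = int(max(seq_lens))

    result = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        softmax_scale=softmax_scale,
        causal=True,
        return_attn_probs=True,  # flash-attn 2.x then also returns the softmax LSE
    )
    out, lse = result[0], result[1]  # (out, softmax_lse, ...) in flash-attn 2.x
    return out, lse
```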

Accuracy Tests

GSM8K:

hostuser@1ad3348c50a4:/sgl-workspace/sglang$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 4 --num-questions 100 --parallel 100
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.10it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 8.016 s
Output throughput: 1080.888 token/s

GPQA:

(optional:  export SGLANG_MHA_USE_WKC_WVC=1)
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 36000 --repeat 8 --thinking-mode deepseek-v3
With export SGLANG_MHA_USE_WKC_WVC=1: 
Repeat: 8, mean: 0.798
Scores: ['0.793', '0.788', '0.818', '0.783', '0.813', '0.763', '0.813', '0.813']
====================
Writing report to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.html
{'chars': np.float64(14716.31313131313), 'chars:std': np.float64(11968.213414131249), 'score:std': np.float64(0.3898060809385348), 'score': np.float64(0.8131313131313131)}
Writing results to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.json
Total latency: 1666.190 s
Score: 0.813
Without export SGLANG_MHA_USE_WKC_WVC=1: 
Repeat: 8, mean: 0.790
Scores: ['0.793', '0.768', '0.788', '0.808', '0.773', '0.793', '0.818', '0.783']
====================
Writing report to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.html
{'chars': np.float64(15173.823232323231), 'chars:std': np.float64(11838.867709818003), 'score:std': np.float64(0.41232046084617835), 'score': np.float64(0.7828282828282829)}
Writing results to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.json
Total latency: 1652.339 s
Score: 0.783

Benchmarking and Profiling

1024 tokens, 100 requests

| Metric | Original (MLA) | This PR (MHA) | Improvement |
|---|---|---|---|
| Mean E2E Latency | 4334.95 ms | 3820.75 ms | 11.9% |
| Mean TTFT | 4279.27 ms | 3820.73 ms | 10.7% |
| Input Throughput | 14213.70 tok/s | 15479.94 tok/s | 8.9% |
| Benchmark Duration | 7.20 s | 6.62 s | 8.1% |

512 tokens, 100 requests

| Metric | Original (MLA) | This PR (MHA) | Improvement |
|---|---|---|---|
| Mean E2E Latency | 3060.32 ms | 2819.46 ms | 7.9% |
| Mean TTFT | 3060.29 ms | 2977.96 ms | 2.7% |
| Input Throughput | 9909.07 tok/s | 10662.89 tok/s | 7.6% |
| Benchmark Duration | 5.17 s | 5.30 s | −2.5% |

2000 tokens, 100 requests

| Metric | Original (MLA) | This PR (MHA) | Improvement |
|---|---|---|---|
| Mean E2E Latency | 5797.70 ms | 4819.40 ms | 16.9% |
| Mean TTFT | 5601.00 ms | 4819.37 ms | 14.0% |
| Input Throughput | 19331.60 tok/s | 22249.25 tok/s | 15.1% |
| Benchmark Duration | 10.35 s | 8.99 s | 13.1% |

Kernel-level analysis:

Forward Total Time

  • MHA + Skip TopK: 2-2.5 ms
  • Original: 3.5-4 ms
  • Improvement: ~30-40% reduction

Prepare Stage

  • MHA + Skip TopK (forward_normal_prepare/forward_normal_chunked_kv_prepare + top-k-skipped indexer): 1.5-2 ms
  • Original (forward_absorb_prepare): 2.5-3.5 ms
  • Improvement: ~30% reduction

Core Stage

  • This PR (forward_normal_chunked_kv_core): ~300 μs
  • Original (forward_absorb_core): 500-600 μs
  • Improvement: ~40-50% reduction

@gemini-code-assist (Contributor)

Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the inference performance of DeepSeek-V3.2 models by introducing an adaptive attention mechanism. It intelligently selects between Multi-Head Attention (MHA) and Multi-head Latent Attention (MLA) based on the sequence length during the prefill stage. This ensures that the more efficient MHA is used for shorter sequences, where MLA's overhead is detrimental, while MLA's benefits for longer sequences are preserved. The change results in notable reductions in latency and increases in throughput across various sequence lengths.

Highlights

  • Adaptive Attention Mechanism: Introduces an adaptive attention mechanism for DeepSeek-V3.2 models, dynamically switching between Multi-head Latent Attention (MLA) and Multi-Head Attention (MHA) during the prefill phase based on sequence length.
  • Optimized Prefill for Short Sequences: For short sequences (below a configurable threshold, default 2048 tokens) during prefill, the system now utilizes MHA_CHUNKED_KV for improved efficiency by avoiding MLA's compression/decompression overhead.
  • MLA for Long Sequences and Decode: For longer sequences (at or above the threshold) during prefill, MLA is retained to leverage its absorbed attention mechanism and sparse filtering benefits. The decode phase consistently uses MLA to avoid per-token decompression overhead.
  • New MHA Implementation: A new _forward_standard_mha() method has been implemented in nsa_backend.py to support standard MHA using the FlashAttention varlen interface, including handling prefix chunks and Log-Sum-Exp (LSE) return values.
  • Configurable Threshold: A configurable nsa_seq_len_threshold parameter (default 2048) has been added, allowing for flexible tuning of the switching point between MHA and MLA.
  • Performance Improvements: Benchmarks show end-to-end latency reductions of 7.9-16.9% across 512- to 2000-token prefills, along with substantial gains in input throughput and kernel-level efficiency.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces an adaptive attention mechanism for DeepSeek-V3.2 models, switching between Multi-Head Attention (MHA) for short sequences and Multi-Latent Attention (MLA) for longer ones to optimize performance. The changes are well-structured, with clear logic for selecting the attention pathway based on sequence length during the prefill phase. The new _forward_standard_mha method in the NSA backend is a clean implementation for handling the MHA path.

My review focuses on the correctness and efficiency of the new logic. I've found a minor opportunity for improvement in python/sglang/srt/models/deepseek_v2.py to make the code more efficient and idiomatic. Overall, this is a solid contribution with impressive performance gains demonstrated in the benchmarks.

@YAMY1234 marked this pull request as draft on October 21, 2025 05:01
@Fridge003 (Collaborator)

@YAMY1234 Can you post some accuracy results on GPQA?

@hlu1 (Collaborator) commented Oct 24, 2025

Please test with MTP and make sure it's not broken.

@YAMY1234 force-pushed the dpsk_mla_opt branch 2 times, most recently from 1eaabcc to 8a8cb9a on October 28, 2025 18:11
@YAMY1234 marked this pull request as ready for review on October 28, 2025 18:16
@YAMY1234 (Collaborator, Author)

> @YAMY1234 Can you post some accuracy results on GPQA?

Thanks! Just added the GPQA result.

@YAMY1234 (Collaborator, Author)

> Please test with MTP and make sure it's not broken.

Confirmed that the MHA path is currently only triggered when MTP is not in use. Will test with MTP later to confirm it runs correctly.

@hlu1 (Collaborator) commented Oct 29, 2025

Please update the benchmark data after adding back self.indexer.

@YAMY1234 force-pushed the dpsk_mla_opt branch 2 times, most recently from ad79486 to df02529 on October 30, 2025 23:30
@YAMY1234 (Collaborator, Author)

> Please update the benchmark data after adding back self.indexer.

Updated the results after adding back the indexer but skipping top-k!

@YAMY1234 force-pushed the dpsk_mla_opt branch 2 times, most recently from f610f1e to a2b76e6 on November 2, 2025 00:19
@Fridge003 (Collaborator)

@YAMY1234 Please fix lint

@YAMY1234 (Collaborator, Author) commented Nov 3, 2025

> @YAMY1234 Please fix lint

@Fridge003 Done! Thanks

@Fridge003 (Collaborator)

Great job! @YAMY1234

@Fridge003 merged commit f235498 into sgl-project:main on Nov 6, 2025
177 of 188 checks passed
@thqq479 commented Dec 11, 2025

I noticed that decoding with a context shorter than 2048 tokens also spends a lot of time in the indexer() function. Why is that? Can't we skip this step during decoding? @YAMY1234

@hlu1 (Collaborator) commented Dec 11, 2025

In the case of decoding, because of the limitations of CUDA Graphs, we'd need to instantiate two versions of the CUDA graph: one with the indexer, and one that skips the logits computation. Given that prompts with <2k tokens are not the majority of cases, I don't think it's worth the effort or complexity.
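For readers unfamiliar with the constraint described above, here is a tiny PyTorch illustration (not sglang code; the function names are made up). Python control flow executes only while the graph is being captured, so whichever branch runs at capture time is baked into the graph and replayed verbatim afterwards; a length-dependent indexer skip during decode would therefore require capturing a second graph.

```python
import torch

def attn_with_indexer(x):
    return x * 2  # stand-in for the indexer + sparse attention path

def attn_skip_indexer(x):
    return x * 3  # stand-in for a path that skips the indexer logits

static_x = torch.zeros(8, device="cuda")
seq_len = 1024  # evaluated once, at capture time

# (Real code warms the kernels up on a side stream before capturing.)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Ordinary Python `if`: it is resolved NOW, during capture. The replayed
    # graph will always run the branch chosen here, even if the actual
    # sequence length differs on later iterations.
    if seq_len < 2048:
        out = attn_skip_indexer(static_x)
    else:
        out = attn_with_indexer(static_x)

g.replay()  # always re-executes the captured (skip-indexer) branch
```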

@thqq479 commented Dec 12, 2025

> In the case of decoding, because of the limitations of CUDA Graphs, we'd need to instantiate two versions of the CUDA graph: one with the indexer, and one that skips the logits computation. Given that prompts with <2k tokens are not the majority of cases, I don't think it's worth the effort or complexity.

So you mean that inside a CUDA graph, an if branch is captured only once, and therefore branching on sequence length doesn't work within a captured graph?

@thqq479 commented Dec 12, 2025

Perhaps piecewise CUDA graphs can solve this issue. I believe skipping the attention graph would be much more cost-effective than performing the default indexing.
