
[DeepSeek-V3.2][NSA] Enable MHA Pathway for Short Sequence Prefill on B200 (SM100) #12788

Merged

Fridge003 merged 3 commits into sgl-project:main from YAMY1234:mha_b200 on Nov 7, 2025
Conversation

@YAMY1234 (Collaborator) commented on Nov 6, 2025

Motivation

Follow-up to the PR "DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill". This change enables and optimizes MHA on B200 (SM100) in the NSA backend via TRT-LLM ragged attention.

Modifications

  • Add TRT-LLM ragged attention for SM100: Integrate flashinfer.prefill.trtllm_ragged_attention_deepseek for the Blackwell (B200) architecture to provide better accuracy than FA4
  • Remove chunked prefix cache support: Align the NSA backend with PR #11892 (DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill) by removing the chunked KV cache logic, keeping only the MHA_ONE_SHOT mode for standard multi-head attention

Key changes in nsa_backend.py (a hedged sketch of the resulting dispatch follows the list below):

  • Simplified _forward_standard_mha() to handle both SM90 (FA3) and SM100 (TRT-LLM) in a single function
  • Removed _forward_trtllm_ragged_mha() helper function for cleaner code structure
  • Added device capability check to conditionally allocate workspace buffer
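
For readers skimming the diff, the following is a minimal sketch of the unified dispatch described above. The function name, keyword arguments, and the SM100 capability cutoff are illustrative assumptions rather than the exact code merged in nsa_backend.py; in particular, the real trtllm_ragged_attention_deepseek signature should be taken from flashinfer's documentation.

```python
import torch
from flashinfer.prefill import trtllm_ragged_attention_deepseek


def forward_standard_mha_sketch(q, k, v, workspace_buffer,
                                cu_seqlens_q, cu_seqlens_k,
                                max_q_len, max_kv_len, sm_scale):
    """Hypothetical unified MHA prefill dispatch; illustration only."""
    major, _ = torch.cuda.get_device_capability()
    if major >= 10:
        # Blackwell (SM100): use the TRT-LLM ragged attention kernel.
        # Keyword names below are assumptions -- consult flashinfer's docs
        # for the actual trtllm_ragged_attention_deepseek signature.
        return trtllm_ragged_attention_deepseek(
            q, k, v,
            workspace_buffer=workspace_buffer,  # allocated only on SM100
            cum_seq_lens_q=cu_seqlens_q,
            cum_seq_lens_kv=cu_seqlens_k,
            max_q_len=max_q_len,
            max_kv_len=max_kv_len,
            bmm1_scale=sm_scale,
        )
    # Hopper (SM90): fall back to the FlashAttention-3 varlen prefill kernel
    # (call elided here; the NSA backend keeps its existing FA3 path).
    raise NotImplementedError("SM90 path uses the FA3 varlen prefill kernel")
```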

Accuracy Tests

Repeat: 8, mean: 0.792                                           | 24/198 [17:23<1:38:01, 33.80s/it]
Scores: ['0.783', '0.788', '0.783', '0.803', '0.823', '0.788', '0.773', '0.793']
====================
Writing report to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.html
{'chars': np.float64(15007.823232323231), 'chars:std': np.float64(11964.504345200807), 'score:std': np.float64(0.40520665017240837), 'score': np.float64(0.7929292929292929)}
Writing results to /tmp/gpqa_deepseek-ai_DeepSeek-V3.2-Exp.json
Total latency: 1104.677 s
Score: 0.793

Benchmarking and Profiling

python -m sglang.launch_server \
  --model deepseek-ai/DeepSeek-V3.2-Exp \
  --tp 8 --dp 8 --enable-dp-attention \
  --kv-cache-dtype bf16
python3 -m sglang.bench_serving \
  --backend sglang \
  --dataset-name random \
  --num-prompts 100 \
  --random-input-len 2000 \
  --random-output-len 1 \
  --disable-ignore-eos \
  --random-range-ratio 1.0

This PR: MHA (trtllm_ragged_attention_deepseek) + skip topk

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     100       
Benchmark duration (s):                  6.56      
Total input tokens:                      200000    
Total input text tokens:                 200000    
Total input vision tokens:               0         
Total generated tokens:                  100       
Total generated tokens (retokenized):    95        
Request throughput (req/s):              15.25     
Input token throughput (tok/s):          30507.62  
Output token throughput (tok/s):         15.25     
Total token throughput (tok/s):          30522.87  
Concurrency:                             62.85     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   4120.22   
Median E2E Latency (ms):                 4129.15   
---------------Time to First Token----------------
Mean TTFT (ms):                          3950.18   
Median TTFT (ms):                        3915.54   
P99 TTFT (ms):                           6528.78   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00      
Median ITL (ms):                         0.00      
P95 ITL (ms):                            0.00      
P99 ITL (ms):                            0.00      
Max ITL (ms):                            0.00      
==================================================

main: MLA + skip topk

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     100
Benchmark duration (s):                  7.04
Total input tokens:                      200000
Total input text tokens:                 200000
Total input vision tokens:               0
Total generated tokens:                  100
Total generated tokens (retokenized):    95
Request throughput (req/s):              14.20
Input token throughput (tok/s):          28399.30
Output token throughput (tok/s):         14.20
Total token throughput (tok/s):          28413.50
Concurrency:                             60.48
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   4259.29
Median E2E Latency (ms):                 4230.05
---------------Time to First Token----------------
Mean TTFT (ms):                          4090.11
Median TTFT (ms):                        4229.76
P99 TTFT (ms):                           7012.75
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P95 ITL (ms):                            0.00
P99 ITL (ms):                            0.00
Max ITL (ms):                            0.00
==================================================

Checklist

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request extends the adaptive Multi-Head Attention (MHA) pathway to NVIDIA B200 (SM100) GPUs for DeepSeek-V3.2 models. It integrates TRT-LLM ragged attention to ensure accuracy and performance for short sequence prefill on the new architecture, while also streamlining the attention backend logic and aligning KV cache handling for improved maintainability and consistency.

Highlights

  • MHA Pathway for B200 (SM100): Enabled the Multi-Head Attention (MHA) pathway for short sequence prefill on NVIDIA B200 (SM100) GPUs, extending support beyond Hopper (SM90).
  • TRT-LLM Ragged Attention Integration: Integrated flashinfer.prefill.trtllm_ragged_attention_deepseek for the Blackwell (B200) architecture to provide improved accuracy compared to Flash Attention v4 (FA4).
  • Code Simplification and Unification: Refactored _forward_standard_mha() to uniformly handle attention for both SM90 (using FA3) and SM100 (using TRT-LLM ragged attention) within a single function, leading to cleaner code by removing a dedicated helper function.
  • KV Cache Alignment: Removed chunked prefix cache support in the NSA backend to align with previous PR DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill #11892, ensuring consistency and focusing on the MHA_ONE_SHOT mode.
  • Dynamic Workspace Buffer Allocation: Implemented a device capability check to conditionally allocate a global workspace buffer for the TRT-LLM ragged attention kernel only on SM100 devices, optimizing resource usage (a minimal sketch of such a check is shown after this list).
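
As a rough illustration of the capability gating mentioned in the last bullet, a check along the following lines could decide whether the TRT-LLM workspace buffer is needed at all. The helper name and the (10, 0) threshold are assumptions for illustration, not the merged code.

```python
import torch


def is_sm100_or_newer() -> bool:
    """Hypothetical helper: True on Blackwell-class (SM100+) GPUs."""
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (10, 0)


# The NSA backend would allocate the shared TRT-LLM workspace buffer only
# when this returns True; the FA3 path on SM90 does not need it.
```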

@Fridge003 self-assigned this on Nov 6, 2025
@gemini-code-assist (Bot) left a comment


Code Review

This pull request enables the MHA pathway for short sequence prefill on B200 (SM100) devices by integrating TRT-LLM's ragged attention kernel. This is a solid optimization for newer hardware, improving both performance and accuracy. The code changes are well-structured, simplifying the MHA forward function and correctly handling device-specific logic. My main feedback concerns the thread-safety of the global workspace buffer initialization, which could be improved to prevent potential race conditions in multi-threaded environments.

Comment on lines +242 to +249
global global_workspace_buffer
if global_workspace_buffer is None:
    global_workspace_buffer = torch.empty(
        envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
        dtype=torch.uint8,
        device=model_runner.device,
    )
self.workspace_buffer = global_workspace_buffer

Severity: medium

The initialization of the global_workspace_buffer is not thread-safe. If multiple NativeSparseAttnBackend instances are created concurrently (e.g., in a multi-worker server setup), the check if global_workspace_buffer is None: followed by the allocation can create a race condition. This could lead to the buffer being allocated multiple times or other concurrency-related issues.

To make this thread-safe, I recommend using a double-checked locking pattern.

First, add the following at the top of the file:

import threading

# ...

# Reuse this workspace buffer across all NSA backend instances
global_workspace_buffer = None
_workspace_buffer_lock = threading.Lock()

Then, update the initialization logic as follows:

Suggested change:

  global global_workspace_buffer
  if global_workspace_buffer is None:
-     global_workspace_buffer = torch.empty(
-         envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
-         dtype=torch.uint8,
-         device=model_runner.device,
-     )
+     with _workspace_buffer_lock:
+         if global_workspace_buffer is None:
+             global_workspace_buffer = torch.empty(
+                 envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                 dtype=torch.uint8,
+                 device=model_runner.device,
+             )
  self.workspace_buffer = global_workspace_buffer


@Fridge003 merged commit 7257525 into sgl-project:main on Nov 7, 2025
81 of 107 checks passed
