[DeepSeek-V3.2][NSA] Enable MHA Pathway for Short Sequence Prefill on B200 (SM100) #12788
Fridge003 merged 3 commits into sgl-project:main from
Conversation
Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request extends the adaptive Multi-Head Attention (MHA) pathway to NVIDIA B200 (SM100) GPUs for DeepSeek-V3.2 models. It integrates TRT-LLM ragged attention to ensure accuracy and performance for short sequence prefill on the new architecture, while also streamlining the attention backend logic and aligning KV cache handling for improved maintainability and consistency.
Code Review
This pull request enables the MHA pathway for short sequence prefill on B200 (SM100) devices by integrating TRT-LLM's ragged attention kernel. This is a solid optimization for newer hardware, improving both performance and accuracy. The code changes are well-structured, simplifying the MHA forward function and correctly handling device-specific logic. My main feedback concerns the thread-safety of the global workspace buffer initialization, which could be improved to prevent potential race conditions in multi-threaded environments.
```python
global global_workspace_buffer
if global_workspace_buffer is None:
    global_workspace_buffer = torch.empty(
        envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
        dtype=torch.uint8,
        device=model_runner.device,
    )
self.workspace_buffer = global_workspace_buffer
```
The initialization of `global_workspace_buffer` is not thread-safe. If multiple `NativeSparseAttnBackend` instances are created concurrently (e.g., in a multi-worker server setup), the check `if global_workspace_buffer is None:` followed by the allocation can create a race condition. This could lead to the buffer being allocated multiple times or other concurrency-related issues.
To make this thread-safe, I recommend using a double-checked locking pattern.
First, add the following at the top of the file:
```python
import threading

# ...

# Reuse this workspace buffer across all NSA backend instances
global_workspace_buffer = None
_workspace_buffer_lock = threading.Lock()
```

Then, update the initialization logic as follows:
```diff
 global global_workspace_buffer
 if global_workspace_buffer is None:
-    global_workspace_buffer = torch.empty(
-        envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
-        dtype=torch.uint8,
-        device=model_runner.device,
-    )
+    with _workspace_buffer_lock:
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
 self.workspace_buffer = global_workspace_buffer
```
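For illustration, here is a self-contained sketch of the same double-checked locking pattern outside SGLang (all names below are made up for the demo, not taken from this PR); it shows the shared buffer being allocated exactly once even when several threads race through the initializer:

```python
import threading

import torch

_global_buffer = None
_buffer_lock = threading.Lock()
_alloc_count = 0


def get_workspace_buffer(size: int = 1024) -> torch.Tensor:
    """Return the shared buffer, allocating it at most once."""
    global _global_buffer, _alloc_count
    if _global_buffer is None:          # fast path: no lock once initialized
        with _buffer_lock:
            if _global_buffer is None:  # re-check while holding the lock
                _alloc_count += 1
                _global_buffer = torch.empty(size, dtype=torch.uint8)
    return _global_buffer


threads = [threading.Thread(target=get_workspace_buffer) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert _alloc_count == 1, "buffer should be allocated exactly once"
```

The outer `None` check keeps the common path lock-free after the first allocation, while the inner check under the lock guarantees that only one thread ever allocates the buffer.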
Motivation
Follow-up to this PR: DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill. Enable and optimize MHA on B200 (SM100) in the NSA backend via TRT-LLM ragged attention; a rough sketch of the resulting dispatch follows.
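The sketch below is illustrative only: the function names are hypothetical, and the real logic lives in `_forward_standard_mha()` / `_forward_trtllm_ragged_mha()` in `nsa_backend.py` (see Modifications below). It assumes short-sequence prefill has already been routed to the MHA pathway and only shows how a kernel could be picked by GPU generation:

```python
import torch


def _is_sm100() -> bool:
    # B200 (Blackwell, SM100) reports compute capability major version 10.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 10


def forward_short_prefill_mha(q, k, v, metadata):
    # Short-sequence prefill takes the MHA pathway; the kernel is then
    # chosen by GPU generation.
    if _is_sm100():
        # SM100: TRT-LLM ragged attention (via FlashInfer) for better accuracy.
        return _trtllm_ragged_mha(q, k, v, metadata)
    # SM90 (Hopper): existing FA3 pathway.
    return _fa3_mha(q, k, v, metadata)


# Stubs standing in for the real kernels; the actual backend calls
# flashinfer.prefill.trtllm_ragged_attention_deepseek and FA3 respectively.
def _trtllm_ragged_mha(q, k, v, metadata):
    raise NotImplementedError


def _fa3_mha(q, k, v, metadata):
    raise NotImplementedError
```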
Modifications
- Use `flashinfer.prefill.trtllm_ragged_attention_deepseek` for the Blackwell (B200) architecture to provide better accuracy than FA4
- Key changes in `nsa_backend.py`:
  - Unified `_forward_standard_mha()` to handle both SM90 (FA3) and SM100 (TRT-LLM) in a single function
  - Added a `_forward_trtllm_ragged_mha()` helper function for a cleaner code structure

Accuracy Tests
Benchmarking and Profiling
This PR: MHA (`trtllm_ragged_attention_deepseek`) + skip topk
main: MLA + skip topk
Checklist