
Support piecewise cuda graph for MLA #11812

Merged: ispobock merged 19 commits into main from ke/mla-compile on Nov 10, 2025

Conversation

@ispobock (Collaborator) commented on Oct 18, 2025

Motivation

Support piecewise CUDA graph for models using MLA (Multi-head Latent Attention).

Triton

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --enable-piecewise-cuda-graph --piecewise-cuda-graph-max-tokens 8192 --attention-backend triton
python3 benchmark/gsm8k/bench_sglang.py --parallel 1319 --num-questions 1319

Accuracy: 0.387
Invalid: 0.005
Latency: 16.465 s
Output throughput: 9663.436 token/s

Flashinfer

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --enable-piecewise-cuda-graph --piecewise-cuda-graph-max-tokens 8192 --attention-backend flashinfer
python3 benchmark/gsm8k/bench_sglang.py --parallel 1319 --num-questions 1319

Accuracy: 0.384
Invalid: 0.005
Latency: 15.206 s
Output throughput: 10550.984 token/s

FA3

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --enable-piecewise-cuda-graph --piecewise-cuda-graph-max-tokens 8192 --attention-backend fa3
python3 benchmark/gsm8k/bench_sglang.py --parallel 1319 --num-questions 1319

Accuracy: 0.387
Invalid: 0.004
Latency: 15.164 s
Output throughput: 10469.733 token/s

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @ispobock, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly advances support for piecewise CUDA graphs, particularly for Multi-head Latent Attention (MLA) and DeepSeek models. It refines the integration of torch.compile by addressing stability issues and ensuring that model compilation and execution are properly managed within the CUDA graph context. The changes also make the attention mechanisms more flexible by allowing additional parameters and improve compatibility with various model architectures, leading to more efficient and robust model execution.

Highlights

  • Piecewise CUDA Graph Integration: Refactored and enhanced the piecewise CUDA graph context management, ensuring that model compilation and execution correctly utilize this feature for improved performance.
  • DeepSeek Model Compatibility: Added specific support for DeepSeek models by recognizing their unique attention layer attribute (attn_mqa) and adjusting attention backend behavior within the CUDA graph mode to maintain compatibility.
  • torch.compile Stability Improvements: Introduced configuration changes to torch.compile to mitigate FailOnRecompileLimitHit errors, significantly improving the robustness and reliability of graph compilation.
  • Flexible Attention Parameters: Extended the unified_attention_with_output function to accept additional attention-related parameters like RoPE (Rotary Positional Embedding) and sinks, allowing for more complex and customizable attention mechanisms.
  • Dynamic Argument Dimension Inference: Improved the inference of dynamic argument dimensions for torch.Tensor and Optional[torch.Tensor] parameters, especially when `from __future__ import annotations` turns annotations into strings, which is crucial for correct torch.compile behavior (see the sketch after this list).
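
A minimal sketch of how the dynamic-dimension inference in the last highlight might look, assuming the compile machinery inspects a wrapped function's signature to find tensor arguments; find_dynamic_tensor_args and unified_attention_example are illustrative names, not the actual sglang API:

from __future__ import annotations

import inspect
import typing
from typing import Optional

import torch


def find_dynamic_tensor_args(fn) -> list[str]:
    # With `from __future__ import annotations`, every annotation is stored as a
    # string, so it must be resolved with typing.get_type_hints() before it can
    # be compared against torch.Tensor.
    hints = typing.get_type_hints(fn)
    dynamic = []
    for name in inspect.signature(fn).parameters:
        hint = hints.get(name)
        if hint is torch.Tensor:
            dynamic.append(name)
        elif typing.get_origin(hint) is typing.Union and torch.Tensor in typing.get_args(hint):
            # Optional[torch.Tensor] resolves to Union[torch.Tensor, None].
            dynamic.append(name)
    return dynamic


def unified_attention_example(q: torch.Tensor, k: torch.Tensor,
                              sinks: Optional[torch.Tensor] = None) -> None:
    ...


# Prints ['q', 'k', 'sinks']: the arguments whose token dimension would be
# marked dynamic before torch.compile captures the piecewise graph.
print(find_dynamic_tensor_args(unified_attention_example))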

ispobock marked this pull request as draft on October 18, 2025 18:14
@gemini-code-assist (Bot) left a comment

Code Review

This pull request adds support for piecewise CUDA graphs for models with Multi-head Latent Attention (MLA), which is a valuable performance optimization. The changes correctly adapt the codebase for torch.compile: handling string-based type annotations, working around torch.compile limitations with certain tensor operations, and ensuring static code paths for graph capture. The PR also includes correctness fixes for attention layers and broadens model support. I've identified one minor inconsistency in an optimization that could be addressed for completeness.

query_pass = query[..., self.rotary_dim :]
key_pass = key[..., self.rotary_dim :]

self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
@gemini-code-assist (Contributor) commented (severity: medium)

The removal of self.cos_sin_cache.to(positions.device) in forward_native is a good optimization to avoid a redundant device transfer. However, a similar redundant call still exists in the forward_npu method of the same DeepseekScalingRotaryEmbedding class (at line 825). For consistency and to apply the same optimization for NPU devices, this line should also be removed from forward_npu.
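
A minimal sketch of the pattern behind this suggestion, assuming the rotary cache can be registered as a buffer so the model's device placement handles the transfer once; the class and argument names here are illustrative, not the actual DeepseekScalingRotaryEmbedding code:

import torch
from torch import nn


class RotaryCacheSketch(nn.Module):
    def __init__(self, rotary_dim: int, max_positions: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
        t = torch.arange(max_positions, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Registering the cache as a buffer lets model.to(device) move it once at
        # load time, so the hot path never needs a per-call .to(positions.device).
        self.register_buffer("cos_sin_cache", torch.cat([freqs.cos(), freqs.sin()], dim=-1), persistent=False)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # No device transfer here: the buffer already lives on the module's device,
        # which keeps the captured graph free of redundant copies.
        return self.cos_sin_cache[positions]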

ispobock marked this pull request as ready for review on November 1, 2025 09:50
ispobock added and then removed the run-ci label on Nov 1, 2025
@Oasis-Git (Collaborator) commented

LGTM. However, I suggest postponing the merge until:

  1. [PieceWise CUDA Graph] Support awq/gptq model in piecewise cudagraph #12518 is merged, since the overall modification to context control is heavy in this branch
  2. The capture problem of the MLA model is understood

not get_global_server_args().flashinfer_mla_disable_ragged
and extend_no_prefix
# Piecewise cuda graph should use paged prefill to be compatible with prefix cache
and not is_in_piecewise_cuda_graph()
@Edenzzzz (Contributor) commented

Wondering why the piecewise CUDA graph can impact attention execution?

@ispobock (Collaborator, Author) commented on Nov 8, 2025

Hi @Edenzzzz, good question! In the forward pass of the DeepSeek model (ref), there are essentially two types of extend: MHA for the no-prefix case and MLA for the with-prefix case, but we can only capture one of them in the prefill CUDA graph. We currently choose MLA since it can be used both with and without a prefix. So in the flashinfer_mla attention backend, we only use the paged prefill kernel for the MLA forward.
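
A minimal sketch of the dispatch described above, based on the condition shown in the diff hunk; flashinfer_mla_disable_ragged, extend_no_prefix, and is_in_piecewise_cuda_graph mirror the names in that snippet, while the surrounding function is purely illustrative:

def choose_extend_kernel(flashinfer_mla_disable_ragged: bool,
                         extend_no_prefix: bool,
                         in_piecewise_cuda_graph: bool) -> str:
    # Ragged (MHA-style) prefill is only valid when the request has no cached
    # prefix, and it cannot be captured in the piecewise CUDA graph because the
    # captured graph must also serve requests that do hit the prefix cache.
    use_ragged = (
        not flashinfer_mla_disable_ragged
        and extend_no_prefix
        and not in_piecewise_cuda_graph
    )
    return "ragged_mha_prefill" if use_ragged else "paged_mla_prefill"


# Inside a piecewise CUDA graph, the paged MLA path is always selected,
# regardless of whether the prefix is empty.
assert choose_extend_kernel(False, True, True) == "paged_mla_prefill"
assert choose_extend_kernel(False, True, False) == "ragged_mha_prefill"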

@Edenzzzz (Contributor) replied

Got it, ragged is MHA.

@Oasis-Git (Collaborator) commented

LGTM. I think it could be merged.

ispobock merged commit db24d34 into main on Nov 10, 2025 (81 of 83 checks passed)
ispobock deleted the ke/mla-compile branch on November 10, 2025 01:13