Conversation
Summary of Changes: Hello @ispobock, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request significantly advances support for piecewise CUDA graphs, particularly for Multi-Latent Attention (MLA) and DeepSeek models.
Code Review
This pull request adds support for piecewise CUDA graphs for models with Multi-Latent Attention (MLA), which is a valuable performance optimization. The changes correctly adapt the codebase for torch.compile, handling string-based type annotations, working around torch.compile limitations with certain tensor operations, and ensuring static code paths for graph capture. The PR also includes correctness fixes for attention layers and broadens model support. I've identified one minor inconsistency in an optimization that could be addressed for completeness.
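To make the torch.compile-related points above concrete, here is a minimal, hedged illustration (not code from this PR) of the kind of tensor-operation workaround the review refers to: data-dependent `.item()` calls are a common source of graph breaks under torch.compile, so values are kept on-device as tensors instead.

```python
import torch

def scale_with_graph_break(x: torch.Tensor) -> torch.Tensor:
    # .item() forces a host sync and typically causes a graph break
    # under torch.compile.
    max_val = x.abs().max().item()
    return x / max_val

def scale_compile_friendly(x: torch.Tensor) -> torch.Tensor:
    # Keeping the maximum as a 0-d tensor avoids the sync, so the whole
    # function can be traced into a single graph.
    max_val = x.abs().max()
    return x / max_val

compiled = torch.compile(scale_compile_friendly)
print(compiled(torch.randn(4, 4)))
```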
```diff
  query_pass = query[..., self.rotary_dim :]
  key_pass = key[..., self.rotary_dim :]

- self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
```
The removal of self.cos_sin_cache.to(positions.device) in forward_native is a good optimization to avoid a redundant device transfer. However, a similar redundant call still exists in the forward_npu method of the same DeepseekScalingRotaryEmbedding class (at line 825). For consistency and to apply the same optimization for NPU devices, this line should also be removed from forward_npu.
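As a rough sketch (hypothetical class and names, not the actual sglang implementation), the pattern being suggested is to move the cos/sin cache to the target device once, rather than calling `.to(positions.device)` on every forward:

```python
import torch

class RotaryEmbeddingSketch:
    def __init__(self, cos_sin_cache: torch.Tensor, device: torch.device):
        # Transfer once up front; repeating .to() in forward is redundant
        # when the cache is already on the right device.
        self.cos_sin_cache = cos_sin_cache.to(device)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # No per-call device transfer here, which also keeps the op static
        # for cuda graph capture.
        return self.cos_sin_cache[positions]
```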
LGTM. However I suggest postponing the merge until:
```python
not get_global_server_args().flashinfer_mla_disable_ragged
and extend_no_prefix
# Piecewise cuda graph should use paged prefill to be compatible with prefix cache
and not is_in_piecewise_cuda_graph()
```
Wonder why piecewise cuda graph can impact attention execution?
Hi @Edenzzzz, good question! The reason is that in the forward pass of the deepseek model (ref), there are mainly two types of extend (MHA for the no-prefix case and MLA for the with-prefix case), but we can only capture one of them in the prefill cuda graph. Currently we choose MLA since it works in both the with-prefix and no-prefix cases. So in the flashinfer_mla attention backend, we only use the paged prefill kernel for the MLA forward.
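A minimal sketch of this dispatch (function names are placeholders, not sglang's actual API) might look like the following:

```python
import torch

def mha_ragged_prefill(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    # Placeholder for the ragged MHA extend kernel (no-prefix case only).
    return q + kv

def mla_paged_prefill(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    # Placeholder for the paged MLA kernel, usable with or without a prefix.
    return q + kv

def extend_attention(q, kv, extend_no_prefix: bool, in_piecewise_cuda_graph: bool):
    if in_piecewise_cuda_graph:
        # Only one path can be baked into the captured graph, so always take
        # the paged MLA path, which is compatible with the prefix cache.
        return mla_paged_prefill(q, kv)
    if extend_no_prefix:
        # Outside capture, the ragged kernel can still be used when there is no prefix.
        return mha_ragged_prefill(q, kv)
    return mla_paged_prefill(q, kv)
```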
LGTM. I think it could be merged.
Motivation
Support piecewise cuda graph for MLA.
Triton
Flashinfer
FA3