Support fa4 decoding #16034
Conversation
Summary of Changes

Hello @Qiaolin-Yu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces support for FlashAttention v4 (FA4) in decoding operations, enhancing the system's ability to leverage advanced attention mechanisms for improved performance during the decoding phase; previously, FA4 was restricted to prefill operations. The changes involve updating a core dependency and removing explicit code limitations that prevented FA4 from being used for decoding.
Code Review
This pull request enables FA4 decoding by updating the nvidia-cutlass-dsl dependency and removing checks that previously restricted FA4 to prefill only. The changes look good, but I've identified a couple of potential issues in flashattention_backend.py related to FP8 and MLA compatibility with FA4 that should be addressed for robustness. Please see my detailed comment.
        k_rope: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert self.fa_impl_ver in [3], "Only FA3 support decoding"
While removing this assertion enables FA4 for decoding, there are a couple of related inconsistencies that should be addressed to ensure correctness and prevent unexpected behavior:

- FP8 incompatibility with FA4: The `forward_extend` method correctly checks for and prevents the use of FA4 with an FP8 KV cache (line 747), as that combination is unsupported. This check is missing in `forward_decode` at line 1096 and could lead to errors. It's recommended to add `and self.fa_impl_ver != 4` to the condition at line 1096.
- MLA path inconsistency: The `forward_extend` method asserts that only FA3 is supported for the MLA path (line 949). This check is missing in `forward_decode`. If FA4 is not supported for MLA, this could lead to silent fallbacks or errors. An assertion should be added around line 1228 for consistency.
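For illustration, here is a minimal sketch of what the two suggested guards could look like. Only `fa_impl_ver`, the FP8 KV-cache condition, and the MLA path come from the review above; the helper name, the parameter names (`kv_cache_dtype`, `use_mla`), and the surrounding control flow are assumptions, since the actual code at lines 1096 and 1228 is not shown in this diff.

```python
from typing import Optional


def _decode_guards_sketch(
    fa_impl_ver: int,
    kv_cache_dtype: Optional[str],
    use_mla: bool,
) -> bool:
    """Hypothetical helper mirroring the forward_extend checks in forward_decode.

    Returns whether the FP8 KV-cache path may be taken; raises if an
    unsupported FA4 + MLA combination is requested. Names are placeholders.
    """
    # (1) Near line 1096: exclude FA4 from the FP8 KV-cache path, mirroring
    #     the existing forward_extend check at line 747.
    use_fp8_kv_path = (
        kv_cache_dtype is not None
        and kv_cache_dtype.startswith("fp8")
        and fa_impl_ver != 4  # suggested addition: FA4 lacks FP8 KV-cache support
    )

    # (2) Around line 1228: mirror the forward_extend MLA assertion (line 949)
    #     so an unsupported FA4 + MLA combination fails loudly instead of
    #     silently falling back.
    if use_mla:
        assert fa_impl_ver == 3, "Only FA3 supports the MLA path"

    return use_fp8_kv_path
```

In the actual backend these checks would live inline in `forward_decode` rather than in a separate helper; the sketch only shows where the `fa_impl_ver != 4` guard and the MLA assertion would go.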
/tag-and-rerun-ci
@Qiaolin-Yu great feature! May I ask whether any performance report is available for typical workloads? And does it also boost Hopper?
Motivation
Wait for #15182
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist