[NVIDIA] upstream FA4 #15182
Conversation
Summary of Changes

This pull request introduces significant enhancements to the FlashAttention v4 implementation, primarily by adding robust support for split Key-Value (KV) computation and more flexible mask handling. These changes are designed to optimize performance and adaptability for various attention patterns and sequence lengths. The update also includes a new heuristic for determining optimal splits, refines the kernel compilation process to accommodate these advanced features, and integrates a new function to combine results from split operations.
Code Review
This pull request brings in upstream changes from FlashAttention 4, notably adding support for split KV computation. This includes a new heuristic for determining the number of splits, new parameters to _flash_attn_fwd, updated logic for windowing and masking, and a new _flash_attn_fwd_combine function to merge partial results. The compile keys and kernel initializations are also updated accordingly. The changes are well-structured and the new functionality is a significant addition. I have one suggestion to improve input validation in the new combine function.
```python
for t, name in [
    (cu_seqlens, "cu_seqlens"),
    (seqused, "seqused"),
    (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
]:
```
For completeness and safety, the semaphore_to_reset tensor should also be validated in this loop, similar to the other optional tensor arguments. While it's not used in the current call sites, adding this validation will prevent potential issues if it's used in the future.
Suggested change:

```diff
 for t, name in [
     (cu_seqlens, "cu_seqlens"),
     (seqused, "seqused"),
     (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
+    (semaphore_to_reset, "semaphore_to_reset"),
 ]:
```
```python
    device=device,
)
lse_partial = torch.empty(
    num_splits, *lse_shape, dtype=torch.float32, device=device
)
```
Is this cudagraph friendly?
/tag-and-rerun-ci
@Qiaolin-Yu I'm on vacation right now... I updated to the latest FlashAttention commit, thanks to the other fixes. The basics in sgl-kernel are working.
Enjoy your vacation! Don’t worry, I’ll take care of wrapping things up.
Force-pushed from a0f7949 to c8bbf66
Will move directly to the latest upstream.
Force-pushed from 9e274de to 4b24aeb
B200 tests passed with the newest CuTe DSL and tvm-ffi.
The latest FA4 commit changed interface.py a little. I think we can update it.
Test results with #16034
I have updated it.
Thanks, that fixes some bugs on SM100.
- Heuristic to determine the number of splits for split-KV computation (a hedged sketch follows)
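To make the heuristic concrete, here is a minimal sketch in the spirit of FlashAttention's split-KV heuristics: when the unsplit grid would leave SMs idle, pick the split count that maximizes wave-quantization efficiency. The function name, signature, and the 0.8 occupancy threshold are illustrative assumptions, not the upstream code:

```python
import math

def num_splits_heuristic_sketch(total_mblocks: int, num_sms: int,
                                num_n_blocks: int, max_splits: int) -> int:
    # If the grid already keeps most SMs busy, splitting only adds
    # combine overhead.
    if total_mblocks >= 0.8 * num_sms:
        return 1
    # Never split finer than one KV block per split or one split per SM.
    max_splits = min(max_splits, num_sms, num_n_blocks)
    best_splits, best_eff = 1, 0.0
    for s in range(1, max_splits + 1):
        n_waves = (total_mblocks * s) / num_sms
        eff = n_waves / math.ceil(n_waves)  # wave-quantization efficiency
        if eff > best_eff + 1e-4:
            best_eff, best_splits = eff, s
    return best_splits
```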
- New `_flash_attn_fwd` parameters:
  - `num_splits: int = 1` (line 70)
  - `mask_mod: Optional[Callable] = None` (line 74)
- Supports the `mask_mod` parameter
- Handles causal/local-window logic with `mask_mod`
- Computes `is_split_kv` based on `num_splits`
- Creates `out_partial` and `lse_partial` tensors when split KV is enabled
- Automatic split calculation when `num_splits < 1` (see the sketch after this list)
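A minimal sketch of how these pieces could fit together. The helper name `_prepare_split_kv` and the `out_partial` shape are assumptions; only the `lse_partial` allocation and the `num_splits < 1` trigger are taken from the diff, and the heuristic is the sketch above:

```python
import torch

def _prepare_split_kv(q, lse_shape, num_splits, total_mblocks,
                      num_sms, num_n_blocks):
    # num_splits < 1 requests the automatic split calculation.
    if num_splits < 1:
        num_splits = num_splits_heuristic_sketch(
            total_mblocks, num_sms, num_n_blocks, max_splits=128)
    is_split_kv = num_splits > 1
    out_partial = lse_partial = None
    if is_split_kv:
        # Partial results carry a leading num_splits dimension and are
        # accumulated in fp32, then merged by the combine kernel.
        out_partial = torch.empty(num_splits, *q.shape,
                                  dtype=torch.float32, device=q.device)
        lse_partial = torch.empty(num_splits, *lse_shape,
                                  dtype=torch.float32, device=q.device)
    return num_splits, is_split_kv, out_partial, lse_partial
```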
- Compile-key updates (see the sketch after this list):
  - Added `score_mod_hash` and `mask_mod_hash`
  - Added an `is_split_kv` flag
  - Added a `paged_kv_non_tma` flag
  - Added a buffer count in place of the `aux_tensors` count
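For illustration, a hedged sketch of what such a compile key could look like. The field names come from the list above; the container type, the `dtype`/`head_dim` fields, and the example values are assumptions:

```python
from typing import NamedTuple, Optional

class CompileKeySketch(NamedTuple):
    dtype: str                     # assumed base field
    head_dim: int                  # assumed base field
    score_mod_hash: Optional[str]  # hash of the score_mod callable, if any
    mask_mod_hash: Optional[str]   # hash of the mask_mod callable, if any
    is_split_kv: bool              # selects the split-KV kernel variant
    paged_kv_non_tma: bool         # paged-KV path without TMA loads
    num_buffers: int               # buffer count, replacing the aux_tensors count

# Distinct keys map to separately compiled and cached kernels.
key = CompileKeySketch("bfloat16", 128, None, None, True, False, 2)
```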
- Kernel initialization updates (a construction sketch follows):
  - `FlashAttentionForwardSm90` (lines 382-401): added `mask_mod`, `intra_wg_overlap=True`, `mma_pv_is_rs=True`
  - `FlashAttentionForwardSm100` (lines 407-428): added `mask_mod`, `is_split_kv`, `paged_kv_non_tma`, `is_varlen_q`, `m_block_size`, `n_block_size`
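A hedged sketch of the SM100 construction: only the keyword names come from the summary above; the import path is guessed by analogy with the `flash_fwd_combine` import below, and every value shown is a placeholder rather than the upstream call:

```python
# Import path assumed by analogy with the flash_fwd_combine import below.
from flash_attn_origin.cute.flash_fwd import FlashAttentionForwardSm100

fwd_sm100 = FlashAttentionForwardSm100(
    mask_mod=None,           # optional mask callable from the new API
    is_split_kv=True,        # compile the partial-output code path
    paged_kv_non_tma=False,  # paged-KV path without TMA loads
    is_varlen_q=False,       # variable-length query sequences
    m_block_size=128,        # query tile size (assumed value)
    n_block_size=128,        # key/value tile size (assumed value)
)
```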
- Combines partial outputs from the split-KV computation (reference semantics sketched below)
- Includes caching logic
- Calls `_flash_attn_fwd_combine` when `is_split_kv` is True
- New import: `from flash_attn_origin.cute.flash_fwd_combine import FlashAttentionForwardCombine`
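The combine step follows the standard split-KV reduction: each split writes a partial output and a partial log-sum-exp (LSE), and the final output is the LSE-weighted sum across splits. A minimal reference sketch of those semantics, with layout assumptions noted in the docstring; the actual `FlashAttentionForwardCombine` kernel fuses this reduction:

```python
import torch

def combine_splits(out_partial: torch.Tensor, lse_partial: torch.Tensor):
    """Reference semantics only. Assumes out_partial is
    (num_splits, ..., head_dim) and lse_partial is (num_splits, ...),
    with the non-split dimensions of the two tensors aligned."""
    lse = torch.logsumexp(lse_partial, dim=0)            # global LSE
    weights = torch.exp(lse_partial - lse.unsqueeze(0))  # per-split scales
    out = (weights.unsqueeze(-1) * out_partial).sum(dim=0)
    return out, lse
```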