
[NVIDIA] upstream FA4 #15182

Merged
Fridge003 merged 26 commits into sgl-project:main from johnnynunez:flash-attention
Jan 11, 2026

Conversation

@johnnynunez
Contributor

  1. Added num_splits_heuristic function (lines 39-46)
    Heuristic to determine the number of splits for split KV computation (a sketch of this kind of heuristic follows the list)
  2. Added new parameters to _flash_attn_fwd
    num_splits: int = 1 (line 70)
    mask_mod: Optional[Callable] = None (line 74)
  3. Updated window_size handling logic (lines 225-236)
    Supports mask_mod parameter
    Handles causal/local window logic with mask_mod
  4. Added split KV logic (lines 237-260)
    Computes is_split_kv based on num_splits
    Creates out_partial and lse_partial tensors when split KV is enabled
    Automatic split calculation when num_splits < 1
  5. Updated compile key (lines 350-376)
    Added score_mod_hash and mask_mod_hash
    Added is_split_kv flag
    Added paged_kv_non_tma flag
    Added buffer count instead of aux_tensors count
  6. Updated kernel initialization calls
    FlashAttentionForwardSm90 (lines 382-401): Added mask_mod, intra_wg_overlap=True, mma_pv_is_rs=True
    FlashAttentionForwardSm100 (lines 407-428): Added mask_mod, is_split_kv, paged_kv_non_tma, is_varlen_q, m_block_size, n_block_size
  7. Added _flash_attn_fwd_combine function (lines 486-616)
    Combines partial outputs from split KV computation
    Includes caching logic
  8. Added combine call (lines 473-481)
    Calls _flash_attn_fwd_combine when is_split_kv is True
  9. Added import (line 25)
    from flash_attn_origin.cute.flash_fwd_combine import FlashAttentionForwardCombine
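
For orientation, here is a minimal sketch of the kind of occupancy-based heuristic item 1 describes. The function name follows the PR, but the signature, thresholds, and scoring below are illustrative assumptions, not the exact code added here:

```python
import math

def num_splits_heuristic(total_mblocks: int, num_sms: int,
                         num_n_blocks: int, max_splits: int) -> int:
    """Illustrative split-KV heuristic: pick the smallest split count
    whose SM occupancy is close to the best achievable one."""
    if total_mblocks <= 0:
        return 1
    # If the unsplit grid already nearly fills the GPU, splitting
    # only adds combine overhead.
    if total_mblocks >= 0.8 * num_sms:
        return 1
    max_splits = min(max_splits, num_sms, num_n_blocks)
    candidates, best_eff = [], 0.0
    for num_splits in range(1, max_splits + 1):
        n_waves = total_mblocks * num_splits / num_sms
        # Efficiency = useful fraction of the last (partially full) wave.
        eff = n_waves / math.ceil(n_waves)
        candidates.append((num_splits, eff))
        best_eff = max(best_eff, eff)
    # Choose the smallest split count within 85% of the best efficiency.
    for num_splits, eff in candidates:
        if eff >= 0.85 * best_eff:
            return num_splits
    return 1
```

The intuition: splitting along KV only pays off when the unsplit grid would leave most SMs idle, and among split counts with similar occupancy the smallest one minimizes the cost of the combine pass.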

@gemini-code-assist
Contributor

Summary of Changes

Hello @johnnynunez, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant enhancements to the FlashAttention v4 implementation, primarily by adding robust support for split Key-Value (KV) computation and more flexible mask handling. These changes are designed to optimize performance and adaptability for various attention patterns and sequence lengths. The update also includes a new heuristic for determining optimal splits, refines the kernel compilation process to accommodate these advanced features, and integrates a new function to combine results from split operations.

Highlights

  • Split KV Heuristic: Introduced a new num_splits_heuristic function to dynamically determine the optimal number of splits for Key-Value (KV) computation, improving efficiency.
  • New Parameters for FlashAttention: Added num_splits and mask_mod as new parameters to the _flash_attn_fwd function, enhancing control over attention splitting and masking.
  • Enhanced Window Size Handling: Updated the window_size logic to support the new mask_mod parameter and to correctly handle causal and local attention mechanisms.
  • Split KV Logic Implementation: Implemented the core split KV logic, including the computation of is_split_kv, automatic split calculation when num_splits is less than 1, and the creation of out_partial and lse_partial tensors for intermediate results.
  • Updated Kernel Compile Key: Modified the kernel compile key to include new hashes for score_mod and mask_mod, along with flags for is_split_kv and paged_kv_non_tma, ensuring proper caching for diverse configurations.
  • Kernel Initialization Updates: Adjusted the initialization calls for FlashAttentionForwardSm90 and FlashAttentionForwardSm100 to pass the newly introduced mask_mod, is_split_kv, and other relevant parameters.
  • New Combine Function: Added a new function, _flash_attn_fwd_combine, responsible for merging partial outputs and log-sum-exp values from split KV computations into final results (a reference sketch of this combine step follows the list).
  • Integration of Combine Function: Integrated the _flash_attn_fwd_combine call into _flash_attn_fwd, ensuring that partial results are combined when split KV is enabled.
  • New Import: Included a new import statement for FlashAttentionForwardCombine to support the new combine functionality.
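
For reference, the math that the split-KV combine step performs can be written in a few lines of PyTorch. This is a hedged sketch of the combine semantics only, not the fused FlashAttentionForwardCombine kernel; the [num_splits, batch, seqlen_q, heads, head_dim] layout and the function name are assumptions for illustration:

```python
import torch

def combine_splits_reference(out_partial: torch.Tensor,
                             lse_partial: torch.Tensor):
    """Merge split-KV partial results.

    out_partial: [num_splits, batch, seqlen_q, heads, head_dim]
    lse_partial: [num_splits, batch, seqlen_q, heads]
    """
    # Global log-sum-exp over the split dimension.
    lse = torch.logsumexp(lse_partial, dim=0)
    # Each split's contribution is reweighted by exp(lse_i - lse).
    scale = torch.exp(lse_partial - lse.unsqueeze(0))
    out = (out_partial * scale.unsqueeze(-1)).sum(dim=0)
    return out, lse
```

Because each split produces an output that is correctly normalized over its own KV slice, the global result is simply a log-sum-exp-weighted average of the partials.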

@gemini-code-assist (Bot) left a comment

Code Review

This pull request brings in upstream changes from FlashAttention 4, notably adding support for split KV computation. This includes a new heuristic for determining the number of splits, new parameters to _flash_attn_fwd, updated logic for windowing and masking, and a new _flash_attn_fwd_combine function to merge partial results. The compile keys and kernel initializations are also updated accordingly. The changes are well-structured and the new functionality is a significant addition. I have one suggestion to improve input validation in the new combine function.

Comment on lines +540 to +544

```python
for t, name in [
    (cu_seqlens, "cu_seqlens"),
    (seqused, "seqused"),
    (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
]:
```
Severity: medium

For completeness and safety, the semaphore_to_reset tensor should also be validated in this loop, similar to the other optional tensor arguments. While it's not used in the current call sites, adding this validation will prevent potential issues if it's used in the future.

Suggested change

```diff
 for t, name in [
     (cu_seqlens, "cu_seqlens"),
     (seqused, "seqused"),
     (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
+    (semaphore_to_reset, "semaphore_to_reset"),
 ]:
```
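
The suggestion shows only the loop header; its body is elided in the excerpt. Purely as a hypothetical sketch of what such a validation body might check (the actual checks inside _flash_attn_fwd_combine are not shown in this thread):

```python
# Hypothetical loop body -- not the code from this PR. Assumes the
# optional tensors (cu_seqlens, seqused, ...) are int32 GPU metadata
# defined in the enclosing function scope.
for t, name in [
    (cu_seqlens, "cu_seqlens"),
    (seqused, "seqused"),
    (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
    (semaphore_to_reset, "semaphore_to_reset"),
]:
    if t is not None:
        assert t.dtype == torch.int32, f"{name} must be int32"
        assert t.is_cuda, f"{name} must be a CUDA tensor"
```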

Comment on the lse_partial allocation:

```python
    device=device,
)
lse_partial = torch.empty(
    num_splits, *lse_shape, dtype=torch.float32, device=device
```
Is this cudagraph-friendly?

@Qiaolin-Yu
Collaborator

/tag-and-rerun-ci

@johnnynunez
Contributor Author

johnnynunez commented Dec 29, 2025

@Qiaolin-Yu I'm on vacation right now... I updated to the latest flash-attention commit, thanks to the other fixes. The basics in sgl-kernel are working.

@johnnynunez changed the title from [NVIDIA] upstream FA4 12/15/25 to [NVIDIA] upstream FA4 12/29/25 on Dec 29, 2025
@Qiaolin-Yu
Collaborator

> @Qiaolin-Yu I'm on vacation right now... I updated to the latest flash-attention commit, thanks to the other fixes. The basics in sgl-kernel are working.

Enjoy your vacation! Don’t worry, I’ll take care of wrapping things up.

@Fridge003 mentioned this pull request on Jan 8, 2026
@Fridge003 changed the title from [NVIDIA] upstream FA4 12/29/25 to [NVIDIA] upstream FA4 on Jan 10, 2026
@Fridge003
Collaborator

Will move directly to the latest upstream:
sgl-project/sgl-flash-attn#28

@Fridge003
Collaborator

B200 tests passed with the newest cutedsl and tvm-ffi:
https://github.com/sgl-project/sglang/actions/runs/20890229706/job/60022405517?pr=15182

@johnnynunez
Contributor Author

johnnynunez commented Jan 11, 2026

> B200 tests passed with the newest cutedsl and tvm-ffi: https://github.com/sgl-project/sglang/actions/runs/20890229706/job/60022405517?pr=15182

The latest FA4 commit changed interface.py a bit. We can update it, I think.

@Qiaolin-Yu
Collaborator

Test results with #16034:

```shell
python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 --trust-remote-code --attention-backend fa4 --tp 4

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
```

100%|█████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:38<00:00, 34.36it/s]
Accuracy: 0.958
Invalid: 0.000
Latency: 38.535 s
Output throughput: 4972.801 token/s

@Qiaolin-Yu
Collaborator

> > B200 tests passed with the newest cutedsl and tvm-ffi: https://github.com/sgl-project/sglang/actions/runs/20890229706/job/60022405517?pr=15182
>
> The latest FA4 commit changed interface.py a bit. We can update it, I think.

I have updated it.

@johnnynunez
Contributor Author

> > > B200 tests passed with the newest cutedsl and tvm-ffi: https://github.com/sgl-project/sglang/actions/runs/20890229706/job/60022405517?pr=15182
> >
> > The latest FA4 commit changed interface.py a bit. We can update it, I think.
>
> I have updated it.

Thanks, that fixes some bugs on SM100.

@Fridge003 merged commit b5493f6 into sgl-project:main on Jan 11, 2026
18 of 45 checks passed

Labels

dependencies, high priority, run-ci, sgl-kernel
