
[FlashAttn] Add fused triton kernel for normal_decode_set_metadata#20778

Merged
BBuf merged 3 commits into sgl-project:main from libowen2121:flashattn/fuse-normal-decode-set-metadata-triton
Mar 22, 2026

Conversation

@libowen2121 (Contributor) commented Mar 17, 2026

Motivation

This PR introduces a fused Triton kernel for the normal_decode_set_metadata function, addressing the long-standing "TODO: fuse these kernels" comment in flashattention_backend.py.

# @torch.compile(dynamic=True, backend=get_compiler_backend())
# TODO: fuse these kernels
# NOTE: torch.compile makes it slower in speculative decoding
def normal_decode_set_metadata(
    cache_seqlens_int32: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    page_table: torch.Tensor,
    ...
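For context, the original body is a short sequence of small tensor operations, each dispatched as its own kernel. A minimal sketch of that pattern, assuming helper names (seq_lens, seq_len_delta, req_to_token, req_pool_indices, max_seq_pages) that are illustrative rather than copied from the file:

# Hedged sketch of the pre-fusion sequence; each line below is a separate
# small CUDA launch, which is the overhead the fused Triton kernel removes.
cache_seqlens_int32.copy_((seq_lens + seq_len_delta).to(torch.int32))
cu_seqlens_k[0] = 0
cu_seqlens_k[1:].copy_(
    torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32)
)
page_table[:, :max_seq_pages].copy_(
    req_to_token[req_pool_indices, :max_seq_pages]  # gather pages per request
)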

Modifications

Introduced two fused Triton kernels in flashattention_backend.py to replace the original sequential operations in normal_decode_set_metadata:

  • _fused_metadata_kernel_ps1_no_swa — specialized fast path for the common case (page_size=1, no SWA)
  • _fused_metadata_kernel_general — general path supporting arbitrary power-of-two page sizes and optional Sliding Window Attention (SWA)

Added unit tests in python/sglang/test/attention/test_normal_decode_set_metadata.py.
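To make the structure concrete, here is a sketch of how the wrapper can dispatch between the two kernels. The kernel names come from this PR, but the argument list, grid shape, and BLOCK_COLS value are assumptions for illustration, not the merged code:

import triton

def normal_decode_set_metadata(cache_seqlens_int32, cu_seqlens_k, page_table,
                               req_to_token, req_pool_indices, seq_lens,
                               seq_len_delta, page_size, window_size=None):
    B = seq_lens.shape[0]
    BLOCK_COLS = 128  # assumed tile width along the page-table columns
    grid = (B, triton.cdiv(page_table.shape[1], BLOCK_COLS))
    if page_size == 1 and window_size is None:
        # Fast path: with page_size == 1 a page index equals a token index,
        # and no sliding-window variant of the table has to be produced.
        _fused_metadata_kernel_ps1_no_swa[grid](
            seq_lens, seq_lens.stride(0), seq_len_delta,
            cache_seqlens_int32, cache_seqlens_int32.stride(0),
            cu_seqlens_k, cu_seqlens_k.stride(0),
            page_table, req_to_token, req_pool_indices, B,
            BLOCK_COLS=BLOCK_COLS,
        )
    else:
        # General path: token offset -> page index via a right shift, which
        # is why page_size must be a power of two (see the review thread).
        shift = page_size.bit_length() - 1 if page_size > 1 else 0
        _fused_metadata_kernel_general[grid](
            seq_lens, seq_lens.stride(0), seq_len_delta,
            cache_seqlens_int32, cache_seqlens_int32.stride(0),
            cu_seqlens_k, cu_seqlens_k.stride(0),
            page_table, req_to_token, req_pool_indices, B,
            SHIFT=shift, WINDOW=window_size or 0, BLOCK_COLS=BLOCK_COLS,
        )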

Accuracy Tests

➜  sglang git:(main) ✗ python -m pytest python/sglang/test/attention/test_flashattn_backend.py -v
================================================================================ test session starts =================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
rootdir: /bowen/projects/rl/sglang/python
configfile: pyproject.toml
plugins: anyio-4.12.1, typeguard-4.5.1
collected 7 items

python/sglang/test/attention/test_flashattn_backend.py::TestFlashAttentionBackend::test_forward_decode PASSED                                                                  [ 14%]
python/sglang/test/attention/test_flashattn_backend.py::TestFlashAttentionBackend::test_forward_decode_with_page_size_greater_than_1 PASSED                                    [ 28%]
python/sglang/test/attention/test_flashattn_backend.py::TestFlashAttentionBackend::test_forward_extend PASSED                                                                  [ 42%]
python/sglang/test/attention/test_flashattn_backend.py::TestFlashAttentionBackend::test_forward_extend_with_page_size_greater_than_1 PASSED                                    [ 57%]
python/sglang/test/attention/test_flashattn_backend.py::TestFlashAttentionBackend::test_forward_extend_with_prefix PASSED                                                      [ 71%]
python/sglang/test/attention/test_flashattn_backend.py::TestUpdateDraftDecodeSetExpandMetadata::test_draft_decode_set_expand_metadata PASSED                                   [ 85%]
python/sglang/test/attention/test_flashattn_backend.py::TestUpdateDraftDecodeSetExpandMetadata::test_update_draft_decode_set_expand_metadata_multi_batch PASSED                [100%]

================================================================================== warnings summary ==================================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

<frozen importlib._bootstrap_external>:1297
  <frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.

<frozen importlib._bootstrap_external>:1297
  <frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================================== 7 passed, 4 warnings in 14.97s ===========================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
➜  sglang git:(main) ✗ python -m pytest python/sglang/test/attention/test_normal_decode_set_metadata.py -v
================================================================================ test session starts =================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
rootdir: /bowen/projects/rl/sglang/python
configfile: pyproject.toml
plugins: anyio-4.12.1, typeguard-4.5.1
collected 16 items

python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_batch_size_1 PASSED                                                         [  6%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_max_seq_pages_zero PASSED                                                   [ 12%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_16_medium_batch PASSED                                            [ 18%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_16_small_batch PASSED                                             [ 25%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_16_with_swa PASSED                                                [ 31%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_1_large_batch PASSED                                              [ 37%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_1_medium_batch PASSED                                             [ 43%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_1_small_batch PASSED                                              [ 50%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_1_with_seq_len_delta PASSED                                       [ 56%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_64_medium_batch PASSED                                            [ 62%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_64_small_batch PASSED                                             [ 68%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_64_with_seq_len_delta PASSED                                      [ 75%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_64_with_swa PASSED                                                [ 81%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_page_size_64_with_swa_and_delta PASSED                                      [ 87%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_power_of_two_page_sizes
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_power_of_two_page_sizes PASSED                                              [ 93%]
python/sglang/test/attention/test_normal_decode_set_metadata.py::TestNormalDecodeSetMetadata::test_varied_sequence_lengths PASSED                                              [100%]

================================================================================== warnings summary ==================================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= 16 passed, 2 warnings, 8 subtests passed in 7.72s ==================================================================

Benchmarking and Profiling

Isolated kernel benchmark

We measured the latency of the isolated kernel over 1,000 iterations on NV-H200 using the following configuration: bs 32, page_size 1, max_ctx 8192, max_pool 1024, seq_delta 0. The fused kernel achieves a 4.78x speedup compared to the baseline implementation.
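For reproducibility, a harness along these lines measures such a microbenchmark; this is a generic CUDA-event timing sketch with warmup, not the exact script behind the numbers above:

import torch

def bench_ms(fn, iters=1000, warmup=50):
    # Warm up so compilation/autotuning is excluded from the measurement.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # mean latency per call, in ms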

End-to-end benchmark

The following benchmarks were conducted using sglang.bench_serving with meta-llama/Meta-Llama-3.1-8B-Instruct on NV-H200. Both random-input-len and random-output-len were set to seqlen.

| batch size | seqlen | baseline latency (ms) | PR latency (ms) |
|-----------:|-------:|----------------------:|----------------:|
| 16 | 64 | 251.20 | 249.87 |
| 16 | 128 | 430.46 | 427.80 |
| 16 | 256 | 794.13 | 787.26 |
| 16 | 512 | 1579.02 | 1569.09 |
| 16 | 1024 | 3164.77 | 3147.41 |
| 16 | 2048 | 7014.71 | 6978.49 |
| 32 | 64 | 301.44 | 302.48 |
| 32 | 128 | 513.51 | 504.52 |
| 32 | 256 | 916.67 | 911.05 |
| 32 | 512 | 1761.68 | 1751.12 |
| 32 | 1024 | 3674.88 | 3655.83 |
| 32 | 2048 | 8274.03 | 8239.82 |
| 64 | 64 | 394.45 | 392.46 |
| 64 | 128 | 653.04 | 647.54 |
| 64 | 256 | 1181.10 | 1173.95 |
| 64 | 512 | 2337.13 | 2321.75 |
| 64 | 1024 | 4827.64 | 4813.91 |
| 64 | 2048 | 11689.43 | 11662.33 |
| 128 | 64 | 485.51 | 480.58 |
| 128 | 128 | 916.92 | 913.92 |
| 128 | 256 | 1770.41 | 1761.12 |
| 128 | 512 | 3353.52 | 3337.89 |
| 128 | 1024 | 7480.48 | 7472.00 |
| 128 | 2048 | 19039.06 | 19017.41 |
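For one row of the table (batch size 32, seqlen 256), the invocation would look roughly like the following; the flags shown are sglang.bench_serving's random-dataset options, and any additional flags used for these runs are not recorded here:

python -m sglang.bench_serving --backend sglang \
    --dataset-name random --num-prompts 32 \
    --random-input-len 256 --random-output-len 256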

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the normal_decode_set_metadata function, a critical component in the FlashAttention backend, by refactoring its underlying operations. By consolidating multiple sequential CUDA kernel calls into highly optimized, fused Triton kernels, the change drastically reduces overhead and improves execution speed, particularly beneficial for decode workloads. This optimization directly contributes to faster and more efficient model inference.

Highlights

  • Performance Optimization: Introduced two fused Triton kernels, _fused_metadata_kernel_general and _fused_metadata_kernel_ps1_no_swa, to optimize the normal_decode_set_metadata function.
  • Kernel Fusion: Replaced multiple sequential CUDA kernel operations with these new fused Triton kernels, significantly reducing overhead and improving execution efficiency.
  • Speedup Achieved: Demonstrated a substantial performance improvement of approximately 5.2x speedup on H200 hardware for typical decode workloads.
  • Comprehensive Testing: Added extensive unit tests to validate the correctness and robustness of the new fused Triton kernels across various configurations and edge cases.


Changelog
  • python/sglang/srt/layers/attention/flashattention_backend.py
    • Added two new Triton JIT kernels: _fused_metadata_kernel_general and _fused_metadata_kernel_ps1_no_swa.
    • Refactored normal_decode_set_metadata to dispatch to the appropriate fused Triton kernel based on page_size and SWA usage.
    • Removed the previous sequential PyTorch operations within normal_decode_set_metadata.
  • python/sglang/test/attention/test_normal_decode_set_metadata.py
    • Added a new test file to validate the correctness of the fused Triton kernels in normal_decode_set_metadata.
    • Included a reference PyTorch implementation for comparison.
    • Implemented tests covering various batch sizes, sequence lengths, page sizes (1, 16, 64), and SWA configurations.
Activity
  • Accuracy tests were performed using pytest on test_flashattn_backend.py and test_normal_decode_set_metadata.py, with all 23 collected items passing.
  • Benchmarking results for meta-llama/Meta-Llama-3.1-8B-Instruct across various batch sizes and sequence lengths consistently show latency improvements, confirming the performance benefits of the fused kernels.
  • The author has completed several checklist items, including code formatting, adding unit tests, and providing accuracy and speed benchmark results.

@gemini-code-assist (Bot) left a comment


Code Review

The pull request introduces fused Triton kernels (_fused_metadata_kernel_general and _fused_metadata_kernel_ps1_no_swa) to the normal_decode_set_metadata function in flashattention_backend.py. This change replaces several sequential CUDA operations with optimized kernels, achieving a reported ~5.2x speedup. The normal_decode_set_metadata function now dispatches to either a general kernel or a specialized kernel for page_size=1 without SWA. A new test file, test_normal_decode_set_metadata.py, has been added to verify the correctness of these fused kernels across various page sizes, SWA configurations, batch sizes, sequence lengths, and edge cases. A review comment suggests refactoring the duplicated prefix sum logic found in both new Triton kernels into a shared helper function to improve code maintainability.

Comment on lines +2791 to +2800
# 1. Prefix sum (only one block does it)
if pid_b == 0 and pid_c == 0:
    acc = 0
    for idx in range(B):
        seq = tl.load(seq_lens + idx * seq_lens_stride_0)
        val = (seq + seq_len_delta).to(tl.int32)
        tl.store(cache_seqlens_int32 + idx * cache_seqlens_int32_stride_0, val)
        tl.store(cu_seqlens_k + idx * cu_seqlens_k_stride_0, acc)
        acc += val
    tl.store(cu_seqlens_k + B * cu_seqlens_k_stride_0, acc)
@gemini-code-assist (Contributor) commented (severity: medium):

This prefix sum logic is identical to the one in _fused_metadata_kernel_general (lines 2701-2710). To avoid code duplication and improve maintainability, you could extract this block into a separate triton.jit helper function and call it from both kernels.
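A sketch of the suggested extraction; the helper name is hypothetical, and it relies on Triton's support for calling one @triton.jit function from another:

import triton
import triton.language as tl

@triton.jit
def _store_seqlen_prefix_sum(seq_lens, seq_lens_stride_0, seq_len_delta,
                             cache_seqlens_int32, cache_seqlens_int32_stride_0,
                             cu_seqlens_k, cu_seqlens_k_stride_0, B):
    # Sequential exclusive prefix sum over the B sequence lengths, writing
    # both cache_seqlens_int32 and cu_seqlens_k (cu_seqlens_k[B] = total).
    acc = 0
    for idx in range(B):
        seq = tl.load(seq_lens + idx * seq_lens_stride_0)
        val = (seq + seq_len_delta).to(tl.int32)
        tl.store(cache_seqlens_int32 + idx * cache_seqlens_int32_stride_0, val)
        tl.store(cu_seqlens_k + idx * cu_seqlens_k_stride_0, acc)
        acc += val
    tl.store(cu_seqlens_k + B * cu_seqlens_k_stride_0, acc)

Both kernels would then call this helper under the existing pid_b == 0 and pid_c == 0 guard.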


# Kernel configuration
BLOCK_COLS = 128
shift = (page_size).bit_length() - 1 if page_size > 1 else 0
Collaborator commented:

Should we add a check that page_size is a power of two?

Contributor replied:

We have added it.
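For reference, the check can be a one-line assertion in the Python wrapper; the placement shown is hypothetical:

# A power of two has exactly one set bit, so x & (x - 1) == 0 holds;
# the page_size > 0 term also rejects zero.
assert page_size > 0 and (page_size & (page_size - 1)) == 0, \
    f"page_size must be a power of two, got {page_size}"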

        self._run_test(batch_size=1, max_seq_len=128, page_size=1, has_swa=False)
        self._run_test(batch_size=1, max_seq_len=256, page_size=64, has_swa=False)

    def test_max_seq_pages_zero(self):
Collaborator commented:

What does the "zero" refer to here?

@kinza99 (Contributor) commented Mar 19, 2026:

The "zero" was a misnomer. We've renamed it to test_max_seq_pages_small to better reflect what it tests.

@@ -0,0 +1,414 @@
"""
Collaborator commented:

Please add this test to the CI workflow.

Contributor replied:

Done.

@libowen2121 (Contributor, Author) commented:

Added the isolated kernel benchmarking results.

@BBuf (Collaborator) left a comment:

LGTM

@BBuf (Collaborator) commented Mar 20, 2026:

/tag-and-rerun-ci

@BBuf BBuf merged commit 3bc595a into sgl-project:main Mar 22, 2026
221 of 281 checks passed
OrangeRedeng pushed a commit to OrangeRedeng/sglang that referenced this pull request Mar 22, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
