Skip to content

perf: precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks#21104

Merged
Qiaolin-Yu merged 2 commits intosgl-project:mainfrom
zminglei:zminglei/fa3-scheduler-metadata-python
Apr 10, 2026
Merged

perf: precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks#21104
Qiaolin-Yu merged 2 commits intosgl-project:mainfrom
zminglei:zminglei/fa3-scheduler-metadata-python

Conversation

@zminglei
Copy link
Copy Markdown
Collaborator

@zminglei zminglei commented Mar 21, 2026

Summary

Call get_scheduler_metadata once per batch in decode metadata init (including CUDA graph capture/replay paths) and pass the result to flash_attn_with_kvcache so FA3 skips the internal prepare_varlen_num_blocks kernel on every layer.

This eliminates 63 redundant GPU kernel calls per decode step (64 layers → 1 call via get_scheduler_metadata), matching what vLLM does.

Changes (1 file, Python only)

flashattention_backend.py: Compute and pass scheduler_metadata in all decode paths:

  • init_forward_metadata (non-CUDA-graph decode)
  • init_cuda_graph_state (pre-allocate buffer)
  • init_forward_metadata_capture_cuda_graph (capture path)
  • init_forward_metadata_replay_cuda_graph (replay path)
  • forward_decode (pass to flash_attn_with_kvcache)

This is part 2 of 2: sglang Python changes only.
Requires sgl-kernel with get_scheduler_metadata (#21103) for the optimization to activate (falls back to no-op without it).

Benchmark Results (TP=4, Qwen3-32B, 4×H200, 3 runs averaged, fresh server restart each)

BS=1 Decode-heavy (input=500, output=8K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 131.41 134.83 +2.6%
Mean TPOT (ms) 7.61 7.41 -2.6%
Mean ITL (ms) 7.61 7.41 -2.6%
Mean TTFT (ms) 27.99 26.18 ~same

BS=4 Decode-heavy (input=500, output=8K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 496.63 508.21 +2.3%
Mean TPOT (ms) 8.04 7.86 -2.2%
Mean ITL (ms) 8.04 7.86 -2.2%
Mean TTFT (ms) 117.68 117.51 -0.1% (no regression)

BS=16 Decode-heavy (input=500, output=8K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 1735.09 1768.68 +1.9%
Mean TPOT (ms) 9.19 9.01 -2.0%
Mean TTFT (ms) 295.75 290.28 ~same

BS=8 Balanced (input=2K, output=2K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 934.61 955.25 +2.2%
Mean TPOT (ms) 8.34 8.15 -2.3%
Mean ITL (ms) 8.34 8.15 -2.3%
Mean TTFT (ms) 445.16 449.32 +0.9% (within noise)

BS=8 Prefill-heavy (input=8K, output=500)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 594.83 603.30 +1.4%
Mean TPOT (ms) 10.10 9.92 -1.8%
Mean ITL (ms) 10.10 9.92 -1.8%
Mean TTFT (ms) 1670.04 1669.40 -0.04% (no regression)

Accuracy (GSM8K, 1319 questions, parallel=1319)

Branch Accuracy
Baseline 0.858
Optimized 0.861

No accuracy regression (within normal sampling variance).

How to Reproduce

Server (Qwen3-32B, TP=4):

python -m sglang.launch_server \
  --model-path Qwen/Qwen3-32B \
  --reasoning-parser qwen3 \
  --tp 4 --port 30000

Decode-heavy benchmark:

python -m sglang.bench_serving --backend sglang \
  --num-prompts 4 --dataset-name random-ids \
  --random-input-len 500 --random-output-len 8000 \
  --random-range-ratio 1.0 --port 30000

Balanced benchmark:

python -m sglang.bench_serving --backend sglang \
  --num-prompts 8 --dataset-name random-ids \
  --random-input-len 2000 --random-output-len 2000 \
  --random-range-ratio 1.0 --port 30000

Prefill-heavy benchmark:

python -m sglang.bench_serving --backend sglang \
  --num-prompts 8 --dataset-name random-ids \
  --random-input-len 8000 --random-output-len 500 \
  --random-range-ratio 1.0 --port 30000

GSM8K accuracy:

python benchmark/gsm8k/bench_sglang.py \
  --num-questions 1319 --parallel 1319 --port 30000

Key takeaways

  • Consistent 1.4-2.6% decode throughput improvement across all traffic patterns and batch sizes
  • No TTFT regression — averaged across 3 fresh-server runs for each configuration
  • Gain is purely from decode optimization; prefill path is untouched
  • The fixed ~200µs/step savings means gain scales inversely with TPOT: BS=1 TPOT=7.6ms → 2.6%, BS=8 TPOT=10.1ms → 1.8%
  • Projected ~5% at TP=8 where TPOT is ~4ms

Profile verification

  • prepare_varlen_num_blocks calls reduced from 576 → 72 per profiling window (64 from prefill + 8 from get_scheduler_metadata; all per-layer decode calls eliminated)

main:
image

this PR:
image

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of FlashAttention v3 during the decode phase by optimizing the handling of scheduler metadata. Instead of repeatedly computing this metadata for each layer, it is now precomputed once per batch, drastically reducing the number of GPU kernel invocations. This change results in measurable improvements in decoding speed and throughput, all while maintaining model accuracy and ensuring compatibility with existing kernel versions.

Highlights

  • Performance Optimization: Precomputed FlashAttention v3 (FA3) scheduler metadata once per batch during decode metadata initialization, eliminating 63 redundant GPU kernel calls per decode step.
  • Throughput Improvement: Achieved a ~1.5% increase in output throughput and a ~1.5% reduction in mean time per output token (TPOT).
  • Backward Compatibility: Ensured graceful fallback to the original behavior if the get_scheduler_metadata function is not available in the sgl-kernel, preventing regressions.
  • CUDA Graph Integration: Integrated scheduler metadata precomputation into CUDA graph capture and replay paths, including pre-allocating necessary buffers.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a performance optimization by precomputing scheduler_metadata for FlashAttention v3, which avoids redundant kernel calls during decoding. The changes are well-structured and include support for CUDA graphs, with a fallback for backward compatibility.

My review identifies a couple of areas for improvement:

  1. A potential bug related to sliding window attention in CUDA graph paths due to code duplication.
  2. A minor simplification for better code readability.

Overall, this is a good performance enhancement. Addressing the identified issues will improve the robustness and maintainability of the code.

Comment thread python/sglang/srt/layers/attention/flashattention_backend.py Outdated
Comment thread python/sglang/srt/layers/attention/flashattention_backend.py Outdated
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@zminglei
Copy link
Copy Markdown
Collaborator Author

zminglei commented Mar 25, 2026

/tag-run-ci-label retry

@zminglei zminglei force-pushed the zminglei/fa3-scheduler-metadata-python branch from a5ad936 to 70c1df3 Compare April 7, 2026 16:33
…e_varlen_num_blocks

Call get_scheduler_metadata once per batch in decode metadata init
(including CUDA graph capture/replay paths) and pass the result to
flash_attn_with_kvcache so FA3 skips the internal prepare_varlen_num_blocks
kernel on every layer.

This eliminates 63 redundant GPU kernel calls per decode step (64 layers
to 1 call), saving ~215us/step (~2% decode throughput on TP=4).
@zminglei zminglei force-pushed the zminglei/fa3-scheduler-metadata-python branch from ccc3247 to 72b934d Compare April 7, 2026 20:30
@zminglei
Copy link
Copy Markdown
Collaborator Author

zminglei commented Apr 8, 2026

/rerun-failed-ci again

Without this, the dispatch wrapper defaults to ver=3 (FA3),
causing crashes on Blackwell (B200/sm_100) which uses FA4.
@Qiaolin-Yu Qiaolin-Yu merged commit 6af34b9 into sgl-project:main Apr 10, 2026
659 of 756 checks passed
Fridge003 pushed a commit that referenced this pull request Apr 11, 2026
…e_varlen_num_blocks (#21104)

Co-authored-by: zminglei <zminglei@linkedin.com>
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
…e_varlen_num_blocks (sgl-project#21104)

Co-authored-by: zminglei <zminglei@linkedin.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…e_varlen_num_blocks (sgl-project#21104)

Co-authored-by: zminglei <zminglei@linkedin.com>
YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request May 7, 2026
The scheduler_metadata buffer precomputed in `_compute_scheduler_metadata`
(introduced by PR sgl-project#21104 to avoid per-layer `prepare_varlen_num_blocks`)
can become inconsistent with the `num_splits` the C++ `mha_fwd` kernel
derives from the live `cache_seqlens` once decode advances. The mismatch
triggers an out-of-bounds read in the FA3 split-KV combine kernel and
surfaces as a CUDA illegal-memory-access at
`flash_fwd_combine_launch_template.h:52`.

Reproduces with Qwen3-0.6B + `--enable-dp-attention --dp 8 --tp 8
--chunked-prefill-size 131072` on H200 after ~65 decode steps. Single-GPU
and TP-only paths are unaffected.

Skip the precompute when DP attention is on and let the C++ kernel
recompute its own metadata per layer. PR sgl-project#21104's optimization is
preserved on every other path. PR sgl-project#24235 had previously addressed a
narrower variant on NSA + EAGLE.

Co-authored-by: Cursor <cursoragent@cursor.com>
YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request May 7, 2026
The scheduler_metadata buffer precomputed in `_compute_scheduler_metadata`
(introduced by PR sgl-project#21104 to avoid per-layer `prepare_varlen_num_blocks`)
can become inconsistent with the `num_splits` the C++ `mha_fwd` kernel
derives from the live `cache_seqlens` once decode advances. The mismatch
triggers an out-of-bounds read in the FA3 split-KV combine kernel and
surfaces as a CUDA illegal-memory-access at
`flash_fwd_combine_launch_template.h:52`.

Reproduces with Qwen3-0.6B + `--enable-dp-attention --dp 8 --tp 8
--chunked-prefill-size 131072` on H200 after ~65 decode steps. Single-GPU
and TP-only paths are unaffected.

Skip the precompute when DP attention is on and let the C++ kernel
recompute its own metadata per layer. PR sgl-project#21104's optimization is
preserved on every other path. PR sgl-project#24235 had previously addressed a
narrower variant on NSA + EAGLE.
YAMY1234 added a commit to YAMY1234/sglang that referenced this pull request May 8, 2026
The scheduler_metadata buffer precomputed in `_compute_scheduler_metadata` (introduced by PR sgl-project#21104 to avoid per-layer `prepare_varlen_num_blocks`) can become inconsistent with the `num_splits` the C++ `mha_fwd` kernel derives from live `cache_seqlens` once decode advances. The mismatch triggers an out-of-bounds read in the FA3 split-KV combine kernel and surfaces as a CUDA illegal-memory-access at `flash_fwd_combine_launch_template.h:52`.

Reproduces with Qwen3-0.6B + `--enable-dp-attention --dp 8 --tp 8 --chunked-prefill-size 131072` on H200 after ~65 decode steps. Single-GPU and TP-only paths are unaffected.

Skip the precompute when DP attention is on and leave `scheduler_metadata` unset, so FA3 uses its existing per-layer metadata path. This removes the stale precomputed metadata path for DP attention while preserving PR sgl-project#21104's optimization on non-DP paths.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants