Skip to content

[Perf] precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks#20943

Open
zminglei wants to merge 7 commits intosgl-project:mainfrom
zminglei:zminglei/fa3-scheduler-metadata
Open

[Perf] precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks#20943
zminglei wants to merge 7 commits intosgl-project:mainfrom
zminglei:zminglei/fa3-scheduler-metadata

Conversation

@zminglei
Copy link
Copy Markdown
Collaborator

@zminglei zminglei commented Mar 19, 2026

Summary

  • Expose get_scheduler_metadata from sgl-attn C++ code through sgl_kernel torch op registration
  • Call it once per batch in decode metadata init (including CUDA graph capture/replay paths)
  • Pass the precomputed scheduler_metadata to flash_attn_with_kvcache so FA3 skips the internal prepare_varlen_num_blocks kernel on every layer

This eliminates 63 redundant GPU kernel calls per decode step (64 layers → 1 call via get_scheduler_metadata), matching what vLLM does.

Changes (4 files)

  • sgl-kernel/include/sgl_flash_kernel_ops.h: Declare mha_fwd_get_scheduler_metadata (already compiled in flash_ops.so, just not exposed)
  • sgl-kernel/csrc/flash_extension.cc: Register sgl_kernel.get_scheduler_metadata torch op
  • sgl-kernel/python/sgl_kernel/flash_attn.py: Add Python get_scheduler_metadata() wrapper
  • flashattention_backend.py: Compute and pass scheduler_metadata in all decode paths:
    • init_forward_metadata (non-CUDA-graph decode)
    • init_cuda_graph_state (pre-allocate buffer)
    • init_forward_metadata_capture_cuda_graph (capture path)
    • init_forward_metadata_replay_cuda_graph (replay path)
    • forward_decode (pass to flash_attn_with_kvcache)

Benchmark Results (TP=4, Qwen3-32B, 4×H200, 3 runs averaged, fresh server restart each)

BS=1 Decode-heavy (input=500, output=8K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 131.41 134.83 +2.6%
Mean TPOT (ms) 7.61 7.41 -2.6%
Mean ITL (ms) 7.61 7.41 -2.6%
Mean TTFT (ms) 27.99 26.18 ~same

BS=4 Decode-heavy (input=500, output=8K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 496.63 508.21 +2.3%
Mean TPOT (ms) 8.04 7.86 -2.2%
Mean ITL (ms) 8.04 7.86 -2.2%
Mean TTFT (ms) 117.68 117.51 -0.1% (no regression)

BS=16 Decode-heavy (input=500, output=8K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 1735.09 1768.68 +1.9%
Mean TPOT (ms) 9.19 9.01 -2.0%
Mean TTFT (ms) 295.75 290.28 ~same

BS=8 Balanced (input=2K, output=2K)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 934.61 955.25 +2.2%
Mean TPOT (ms) 8.34 8.15 -2.3%
Mean ITL (ms) 8.34 8.15 -2.3%
Mean TTFT (ms) 445.16 449.32 +0.9% (within noise)

BS=8 Prefill-heavy (input=8K, output=500)

Metric Baseline (avg) Optimized (avg) Delta
Output throughput (tok/s) 594.83 603.30 +1.4%
Mean TPOT (ms) 10.10 9.92 -1.8%
Mean ITL (ms) 10.10 9.92 -1.8%
Mean TTFT (ms) 1670.04 1669.40 -0.04% (no regression)

Accuracy (GSM8K, 1319 questions, parallel=1319)

Branch Accuracy
Baseline 0.858
Optimized 0.853

No regression with speculative decoding.

How to Reproduce

Server (Qwen3-32B, TP=4):

python -m sglang.launch_server   --model-path Qwen/Qwen3-32B   --reasoning-parser qwen3   --enable-piecewise-cuda-graph   --tp 4 --port 30000

Decode-heavy benchmark:

python -m sglang.bench_serving --backend sglang   --num-prompts 4 --dataset-name random-ids   --random-input-len 500 --random-output-len 8000   --random-range-ratio 1.0 --port 30000

Balanced benchmark:

python -m sglang.bench_serving --backend sglang   --num-prompts 8 --dataset-name random-ids   --random-input-len 2000 --random-output-len 2000   --random-range-ratio 1.0 --port 30000

Prefill-heavy benchmark:

python -m sglang.bench_serving --backend sglang   --num-prompts 8 --dataset-name random-ids   --random-input-len 8000 --random-output-len 500   --random-range-ratio 1.0 --port 30000

GSM8K accuracy:

python benchmark/gsm8k/bench_sglang.py   --num-questions 1319 --parallel 1319 --port 30000

Key takeaways

  • Consistent 1.4-2.6% decode throughput improvement across all traffic patterns and batch sizes
  • No TTFT regression — averaged across 3 fresh-server runs for each configuration
  • Gain is purely from decode optimization; prefill path is untouched
  • The fixed ~200µs/step savings means gain scales inversely with TPOT: BS=1 TPOT=7.6ms → 2.6%, BS=8 TPOT=10.1ms → 1.8%
  • Projected ~5% at TP=8 where TPOT is ~4ms

Profile verification

  • prepare_varlen_num_blocks calls reduced from 576 → 72 per profiling window (64 from prefill + 8 from get_scheduler_metadata; all per-layer decode calls eliminated)

main:
image

this PR:
image

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.
  4. After green CI and required approvals, ask Merge Oncalls to merge.

…rlen_num_blocks

Expose get_scheduler_metadata from sgl-attn C++ code through sgl_kernel,
and call it once per batch in init_forward_metadata for decode. The result
is passed to flash_attn_with_kvcache so FA3 skips the internal
prepare_varlen_num_blocks kernel on every layer.

Changes:
- sgl_flash_kernel_ops.h: declare mha_fwd_get_scheduler_metadata
- flash_extension.cc: register sgl_kernel.get_scheduler_metadata torch op
- flash_attn.py: add Python get_scheduler_metadata wrapper
- flashattention_backend.py: compute and pass scheduler_metadata in decode,
  including CUDA graph capture/replay paths with pre-allocated buffer

Benchmark (Qwen3-32B TP=1, concurrency 4, input 1K, output 8K):
- Baseline: 177.31 tok/s, TPOT 22.55ms
- Optimized: 195.29 tok/s, TPOT 20.44ms
- +10.1% throughput, -9.4% TPOT

Profile: prepare_varlen_num_blocks calls reduced from 576 to 72 per
profiling window (64 prefill + 8 decode steps, vs 64*9=576 before).
All per-layer decode calls eliminated.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization for FlashAttention v3 (FA3) by streamlining the process of generating scheduler metadata during decoding. Instead of repeatedly invoking the prepare_varlen_num_blocks kernel for each attention layer, the necessary metadata is now efficiently precomputed once per batch. This change substantially reduces redundant GPU kernel invocations, leading to a measurable increase in decode throughput across various workloads without negatively impacting Time To First Token (TTFT).

Highlights

  • FlashAttention v3 (FA3) Optimization: Exposed the get_scheduler_metadata C++ function as a sgl_kernel PyTorch operation to enable precomputation of FA3 scheduler metadata.
  • Reduced GPU Kernel Calls: Implemented precomputation of FA3 scheduler metadata once per batch during decode metadata initialization, eliminating redundant per-layer prepare_varlen_num_blocks kernel calls.
  • Performance Improvement: Achieved a consistent 1.4-2.6% improvement in decode throughput across various batch sizes and traffic patterns, with no regression in Time To First Token (TTFT).
  • CUDA Graph Integration: Integrated the scheduler metadata precomputation into both non-CUDA graph and CUDA graph capture/replay paths for comprehensive optimization.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@zminglei zminglei changed the title perf: precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks [Perf] precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks Mar 19, 2026
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

Gemini is experiencing higher than usual traffic and was unable to create the review. Please try again in a few hours by commenting /gemini review.

@zminglei
Copy link
Copy Markdown
Collaborator Author

/tag-and-rerun-ci

@zminglei
Copy link
Copy Markdown
Collaborator Author

zminglei commented Mar 20, 2026

/rerun-failed-ci again

@Qiaolin-Yu Qiaolin-Yu self-assigned this Mar 20, 2026
@zminglei zminglei force-pushed the zminglei/fa3-scheduler-metadata branch from 00699e5 to 2d91fd7 Compare March 20, 2026 05:12
Copy link
Copy Markdown
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

- Complete Python wrapper: expose all C++ op parameters (cu_seqlens_k,
  seqused_q, leftpad_k, attention_chunk)
- Simplify has_softcap logic: single getattr call
- Extract _compute_scheduler_metadata helper method to deduplicate code
  across init_forward_metadata, capture, and replay paths
- Include window_size for SWA in all paths via the helper
@zminglei
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a performance optimization by pre-computing FlashAttention v3 scheduler metadata once per batch, avoiding redundant kernel calls in each layer during the decode step. The changes involve exposing a C++ function through a torch op, adding a Python wrapper, and integrating this pre-computation into all decode paths, including those with CUDA graphs. The implementation looks solid and the performance gains are well-documented. I have one suggestion to refactor duplicated code for updating the scheduler metadata buffer in CUDA graph paths to improve maintainability.

Comment thread python/sglang/srt/layers/attention/flashattention_backend.py
scheduler_metadata is only consumed by non-SWA layers (SWA layers skip
it via the is_swa_layer guard in forward_decode). Computing it with SWA
window_size produces metadata with wrong tile scheduling for the non-SWA
layers that actually use it, causing CUDA illegal address errors on models
with mixed SWA/non-SWA layers (e.g. GPT-OSS).
Copy link
Copy Markdown
Contributor

@jasperjiaguo jasperjiaguo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, minor naming

" int sm_margin = 0"
") -> Tensor");

m.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we want to explicitly keep the mha in the name? and in the comment add it does not work for mla

@jasperjiaguo
Copy link
Copy Markdown
Contributor

/rerun-failed-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants