perf: precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks (#21104)
Conversation
Summary of Changes

This pull request improves FlashAttention v3 decode performance by optimizing how scheduler metadata is handled: instead of recomputing it for every layer, it is precomputed once per batch, sharply reducing the number of GPU kernel invocations. The change yields measurable decode speed and throughput improvements while maintaining model accuracy and compatibility with existing kernel versions.
Code Review
This pull request introduces a performance optimization by precomputing scheduler_metadata for FlashAttention v3, which avoids redundant kernel calls during decoding. The changes are well-structured and include support for CUDA graphs, with a fallback for backward compatibility.
My review identifies a couple of areas for improvement:
- A potential bug related to sliding window attention in CUDA graph paths due to code duplication.
- A minor simplification for better code readability.
Overall, this is a good performance enhancement. Addressing the identified issues will improve the robustness and maintainability of the code.
perf: precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks

Call get_scheduler_metadata once per batch in decode metadata init (including CUDA graph capture/replay paths) and pass the result to flash_attn_with_kvcache so FA3 skips the internal prepare_varlen_num_blocks kernel on every layer. This eliminates 63 redundant GPU kernel calls per decode step (64 layers → 1 call), saving ~215 µs/step (~2% decode throughput gain at TP=4).
Without this, the dispatch wrapper defaults to ver=3 (FA3), causing crashes on Blackwell (B200/sm_100) which uses FA4.
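The comment above can be sketched as a capability-based dispatch (a hedged illustration; `pick_fa_version` and the exact mapping are assumptions for this sketch, not sglang's actual wrapper, and only the Hopper/Blackwell cases from the comment are modeled):

```python
# Hedged sketch: choose the FlashAttention major version from the GPU
# compute capability instead of hard-defaulting to ver=3. Only the two
# architectures named above are handled; the real dispatcher covers more.

def pick_fa_version(compute_capability: tuple) -> int:
    """Map (major, minor) compute capability to an FA major version.

    sm_90 (Hopper, e.g. H200) -> FA3; sm_100 (Blackwell, e.g. B200) -> FA4.
    """
    major, _minor = compute_capability
    if major >= 10:  # Blackwell (sm_100) and newer require FA4
        return 4
    return 3         # Hopper (sm_90) uses FA3

# e.g. with torch: pick_fa_version(torch.cuda.get_device_capability())
```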
perf: precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks (#21104) Co-authored-by: zminglei <zminglei@linkedin.com>
The scheduler_metadata buffer precomputed in `_compute_scheduler_metadata` (introduced by PR sgl-project#21104 to avoid per-layer `prepare_varlen_num_blocks`) can become inconsistent with the `num_splits` the C++ `mha_fwd` kernel derives from the live `cache_seqlens` once decode advances. The mismatch triggers an out-of-bounds read in the FA3 split-KV combine kernel and surfaces as a CUDA illegal-memory-access at `flash_fwd_combine_launch_template.h:52`. Reproduces with Qwen3-0.6B + `--enable-dp-attention --dp 8 --tp 8 --chunked-prefill-size 131072` on H200 after ~65 decode steps. Single-GPU and TP-only paths are unaffected. Skip the precompute when DP attention is on and let the C++ kernel recompute its own metadata per layer. PR sgl-project#21104's optimization is preserved on every other path. PR sgl-project#24235 had previously addressed a narrower variant on NSA + EAGLE. Co-authored-by: Cursor <cursoragent@cursor.com>
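The guard described in the commit message above can be sketched like this (a minimal sketch with stand-in names; `DecodeMetadata`, `init_decode_metadata`, and the stub are illustrative, not the actual sglang code):

```python
# Hedged sketch: skip the per-batch scheduler_metadata precompute when DP
# attention is enabled, so FA3 falls back to its internal per-layer
# prepare_varlen_num_blocks path, which always agrees with the num_splits
# the C++ kernel derives from the live cache_seqlens.

from dataclasses import dataclass
from typing import Optional

def get_scheduler_metadata_stub(cache_seqlens):
    # Stand-in for sgl-kernel's get_scheduler_metadata; returns an opaque token.
    return ("precomputed", tuple(cache_seqlens))

@dataclass
class DecodeMetadata:
    cache_seqlens: list
    scheduler_metadata: Optional[tuple] = None

def init_decode_metadata(cache_seqlens, enable_dp_attention: bool) -> DecodeMetadata:
    md = DecodeMetadata(cache_seqlens)
    if enable_dp_attention:
        # Leave scheduler_metadata unset: the kernel recomputes it per layer,
        # avoiding the stale-metadata / num_splits mismatch described above.
        return md
    md.scheduler_metadata = get_scheduler_metadata_stub(cache_seqlens)
    return md
```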
Summary
Call `get_scheduler_metadata` once per batch in decode metadata init (including CUDA graph capture/replay paths) and pass the result to `flash_attn_with_kvcache` so FA3 skips the internal `prepare_varlen_num_blocks` kernel on every layer. This eliminates 63 redundant GPU kernel calls per decode step (64 layers → 1 call via `get_scheduler_metadata`), matching what vLLM does.

Changes (1 file, Python only)

`flashattention_backend.py`: compute and pass `scheduler_metadata` in all decode paths:
- `init_forward_metadata` (non-CUDA-graph decode)
- `init_cuda_graph_state` (pre-allocate buffer)
- `init_forward_metadata_capture_cuda_graph` (capture path)
- `init_forward_metadata_replay_cuda_graph` (replay path)
- `forward_decode` (pass to `flash_attn_with_kvcache`)

This is part 2 of 2: sglang Python changes only.
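The once-per-batch pattern described above can be sketched with stand-in functions (a hedged illustration; the real calls are sgl-kernel's `get_scheduler_metadata` and FA3's `flash_attn_with_kvcache`, and the counter here stands in for GPU kernel launches):

```python
# Hedged sketch: compute scheduler metadata once per decode step and reuse
# it across all layers, instead of letting each layer's attention call run
# its own prepare_varlen_num_blocks kernel internally.

NUM_LAYERS = 64
calls = {"prepare_varlen_num_blocks": 0}

def get_scheduler_metadata_stub(cache_seqlens):
    calls["prepare_varlen_num_blocks"] += 1  # models one GPU kernel launch
    return ("metadata", tuple(cache_seqlens))

def attention_layer(cache_seqlens, scheduler_metadata=None):
    # FA3 only runs its internal metadata kernel when none is supplied.
    if scheduler_metadata is None:
        scheduler_metadata = get_scheduler_metadata_stub(cache_seqlens)
    return scheduler_metadata

def decode_step(cache_seqlens, precompute: bool) -> int:
    """Return the number of metadata kernel launches for one decode step."""
    calls["prepare_varlen_num_blocks"] = 0
    md = get_scheduler_metadata_stub(cache_seqlens) if precompute else None
    for _ in range(NUM_LAYERS):
        attention_layer(cache_seqlens, md)
    return calls["prepare_varlen_num_blocks"]

# Before this PR: 64 launches per decode step; after: 1.
```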
Requires sgl-kernel with `get_scheduler_metadata` (#21103) for the optimization to activate (falls back to a no-op without it).

Benchmark Results (TP=4, Qwen3-32B, 4×H200, 3 runs averaged, fresh server restart each)
BS=1 Decode-heavy (input=500, output=8K)
BS=4 Decode-heavy (input=500, output=8K)
BS=16 Decode-heavy (input=500, output=8K)
BS=8 Balanced (input=2K, output=2K)
BS=8 Prefill-heavy (input=8K, output=500)
Accuracy (GSM8K, 1319 questions, parallel=1319)
No accuracy regression (within normal sampling variance).
How to Reproduce
Server (Qwen3-32B, TP=4):
Decode-heavy benchmark:
Balanced benchmark:
Prefill-heavy benchmark:
GSM8K accuracy:
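The original command blocks did not survive extraction. A hedged sketch of representative commands (the exact flag names are assumptions based on common sglang usage and may differ across versions; verify against your installation before running):

```shell
# Assumed commands, reconstructed from the section headings above.

# Server (Qwen3-32B, TP=4)
python3 -m sglang.launch_server --model-path Qwen/Qwen3-32B --tp 4

# Decode-heavy benchmark (input=500, output=8K); the balanced and
# prefill-heavy runs change only the input/output lengths
python3 -m sglang.bench_serving --backend sglang --dataset-name random \
  --random-input-len 500 --random-output-len 8000

# GSM8K accuracy (script lives in the sglang repo)
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 1319
```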
Key takeaways
Profile verification
`prepare_varlen_num_blocks` calls reduced from 576 → 72 per profiling window (64 from prefill + 8 from `get_scheduler_metadata`; all per-layer decode calls eliminated)

main:

this PR:
