[Perf] precompute FA3 scheduler_metadata to eliminate per-layer prepare_varlen_num_blocks #20943
zminglei wants to merge 7 commits into sgl-project:main
Conversation
…rlen_num_blocks

Expose get_scheduler_metadata from sgl-attn C++ code through sgl_kernel, and call it once per batch in init_forward_metadata for decode. The result is passed to flash_attn_with_kvcache so FA3 skips the internal prepare_varlen_num_blocks kernel on every layer.

Changes:
- sgl_flash_kernel_ops.h: declare mha_fwd_get_scheduler_metadata
- flash_extension.cc: register sgl_kernel.get_scheduler_metadata torch op
- flash_attn.py: add Python get_scheduler_metadata wrapper
- flashattention_backend.py: compute and pass scheduler_metadata in decode, including CUDA graph capture/replay paths with a pre-allocated buffer

Benchmark (Qwen3-32B TP=1, concurrency 4, input 1K, output 8K):
- Baseline: 177.31 tok/s, TPOT 22.55 ms
- Optimized: 195.29 tok/s, TPOT 20.44 ms
- +10.1% throughput, -9.4% TPOT

Profile: prepare_varlen_num_blocks calls reduced from 576 to 72 per profiling window (64 prefill + 8 decode steps, vs 64*9 = 576 before). All per-layer decode calls eliminated.
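A rough sketch of the idea follows. This is illustrative only: the parameter names loosely follow FA3's upstream get_scheduler_metadata API, and the forward-batch attribute names are assumptions rather than the exact code added in this PR.

```python
# Illustrative sketch only: parameter names follow FA3's upstream
# get_scheduler_metadata API; the forward_batch attributes are assumptions.
from sgl_kernel.flash_attn import get_scheduler_metadata


def build_decode_scheduler_metadata(forward_batch, num_q_heads, num_kv_heads, head_dim, page_size):
    # Called once per batch from init_forward_metadata (decode), not per layer.
    return get_scheduler_metadata(
        batch_size=forward_batch.batch_size,
        max_seqlen_q=1,  # decode: a single query token per request
        max_seqlen_k=forward_batch.max_seq_len,
        num_heads_q=num_q_heads,
        num_heads_kv=num_kv_heads,
        headdim=head_dim,
        cache_seqlens=forward_batch.seq_lens_int32,
        qkv_dtype=forward_batch.dtype,
        page_size=page_size,
        causal=True,
    )


# Every layer then passes the same tensor, so FA3 skips its internal
# prepare_varlen_num_blocks kernel:
#   flash_attn_with_kvcache(..., scheduler_metadata=precomputed_metadata)
```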
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant performance optimization for FlashAttention v3 (FA3) by streamlining the process of generating scheduler metadata during decoding.
Warning: Gemini is experiencing higher than usual traffic and was unable to create the review. Please try again in a few hours.
/tag-and-rerun-ci
/rerun-failed-ci again
Force-pushed from 00699e5 to 2d91fd7
- Complete Python wrapper: expose all C++ op parameters (cu_seqlens_k, seqused_q, leftpad_k, attention_chunk)
- Simplify has_softcap logic: single getattr call
- Extract _compute_scheduler_metadata helper method to deduplicate code across init_forward_metadata, capture, and replay paths
- Include window_size for SWA in all paths via the helper
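A minimal sketch of that refactor, assuming the class, method, and buffer names below (they are illustrative, not necessarily the identifiers in flashattention_backend.py):

```python
# Minimal sketch of the refactor described above; class, method, and buffer
# names are assumptions, not the exact flashattention_backend.py identifiers.
from sgl_kernel.flash_attn import get_scheduler_metadata  # wrapper added by this PR


class FlashAttentionBackendSketch:
    def _compute_scheduler_metadata(self, batch_size, cache_seqlens, max_seqlen_k):
        # Single place that builds FA3 scheduler metadata, so the non-graph,
        # capture, and replay paths stay consistent (window_size included for SWA).
        return get_scheduler_metadata(
            batch_size=batch_size,
            max_seqlen_q=1,  # decode: one new token per request
            max_seqlen_k=max_seqlen_k,
            num_heads_q=self.num_q_heads,
            num_heads_kv=self.num_kv_heads,
            headdim=self.head_dim,
            cache_seqlens=cache_seqlens,
            page_size=self.page_size,
            causal=True,
            window_size=self.window_size,  # (left, right) tuple; (-1, -1) if no SWA
        )

    def init_forward_metadata_replay_cuda_graph(self, batch_size, seq_lens, max_seq_len):
        # Replay cannot allocate new tensors, so refresh the buffer that was
        # pre-allocated in init_cuda_graph_state in place.
        self.decode_scheduler_metadata_buf.copy_(
            self._compute_scheduler_metadata(batch_size, seq_lens, max_seq_len)
        )
```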
/gemini review |
Code Review
This pull request introduces a performance optimization by pre-computing FlashAttention v3 scheduler metadata once per batch, avoiding redundant kernel calls in each layer during the decode step. The changes involve exposing a C++ function through a torch op, adding a Python wrapper, and integrating this pre-computation into all decode paths, including those with CUDA graphs. The implementation looks solid and the performance gains are well-documented. I have one suggestion to refactor duplicated code for updating the scheduler metadata buffer in CUDA graph paths to improve maintainability.
scheduler_metadata is only consumed by non-SWA layers (SWA layers skip it via the is_swa_layer guard in forward_decode). Computing it with SWA window_size produces metadata with wrong tile scheduling for the non-SWA layers that actually use it, causing CUDA illegal address errors on models with mixed SWA/non-SWA layers (e.g. GPT-OSS).
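For context, a minimal sketch of that guard on the consuming side, assuming the names below (not the exact backend code):

```python
# Sketch of the consuming side; is_swa_layer and the metadata fields are
# assumed names, not necessarily the exact flashattention_backend.py code.
from sgl_kernel.flash_attn import flash_attn_with_kvcache


def decode_attention(q, k_cache, v_cache, metadata, is_swa_layer, window_size):
    return flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        cache_seqlens=metadata.cache_seqlens_int32,
        causal=True,
        window_size=window_size if is_swa_layer else (-1, -1),
        # SWA layers skip the precomputed metadata entirely, so it should be
        # built to match the full-attention layers that actually read it.
        scheduler_metadata=None if is_swa_layer else metadata.scheduler_metadata,
    )
```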
| " int sm_margin = 0" | ||
| ") -> Tensor"); | ||
| m.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata)); |
nit: do we want to explicitly keep the mha in the name? And in the comment, add that it does not work for MLA.
/rerun-failed-ci |
Summary
- Expose get_scheduler_metadata from sgl-attn C++ code through sgl_kernel torch op registration
- Pass scheduler_metadata to flash_attn_with_kvcache so FA3 skips the internal prepare_varlen_num_blocks kernel on every layer

This eliminates 63 redundant GPU kernel calls per decode step (64 layers → 1 call via get_scheduler_metadata), matching what vLLM does.

Changes (4 files)

- sgl-kernel/include/sgl_flash_kernel_ops.h: Declare mha_fwd_get_scheduler_metadata (already compiled in flash_ops.so, just not exposed)
- sgl-kernel/csrc/flash_extension.cc: Register sgl_kernel.get_scheduler_metadata torch op
- sgl-kernel/python/sgl_kernel/flash_attn.py: Add Python get_scheduler_metadata() wrapper
- flashattention_backend.py: Compute and pass scheduler_metadata in all decode paths:
  - init_forward_metadata (non-CUDA-graph decode)
  - init_cuda_graph_state (pre-allocate buffer)
  - init_forward_metadata_capture_cuda_graph (capture path)
  - init_forward_metadata_replay_cuda_graph (replay path)
  - forward_decode (pass to flash_attn_with_kvcache)

Benchmark Results (TP=4, Qwen3-32B, 4×H200, 3 runs averaged, fresh server restart each)
BS=1 Decode-heavy (input=500, output=8K)
BS=4 Decode-heavy (input=500, output=8K)
BS=16 Decode-heavy (input=500, output=8K)
BS=8 Balanced (input=2K, output=2K)
BS=8 Prefill-heavy (input=8K, output=500)
Accuracy (GSM8K, 1319 questions, parallel=1319)
No regression with speculative decoding.
How to Reproduce
Server (Qwen3-32B, TP=4):
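A plausible server launch for this setup; the flags are assumptions, not necessarily the exact command used:

```bash
# Illustrative launch command; exact flags from the original runs are assumptions.
python3 -m sglang.launch_server \
    --model-path Qwen/Qwen3-32B \
    --tp 4
```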
Decode-heavy benchmark:
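A representative bench_serving invocation for the decode-heavy setting (input=500, output=8K); concurrency and prompt counts per batch size are assumptions:

```bash
# Illustrative; adjust --max-concurrency / --num-prompts for BS=1/4/16.
python3 -m sglang.bench_serving \
    --backend sglang \
    --dataset-name random \
    --random-input-len 500 \
    --random-output-len 8000 \
    --random-range-ratio 1 \
    --num-prompts 16 \
    --max-concurrency 16
```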
Balanced benchmark:
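Same invocation with the balanced request shape (illustrative; 2K taken as 2048):

```bash
# BS=8 balanced: only the request shape and concurrency change.
python3 -m sglang.bench_serving \
    --backend sglang \
    --dataset-name random \
    --random-input-len 2048 \
    --random-output-len 2048 \
    --random-range-ratio 1 \
    --num-prompts 8 \
    --max-concurrency 8
```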
Prefill-heavy benchmark:
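And the prefill-heavy variant (illustrative; 8K taken as 8192):

```bash
# BS=8 prefill-heavy: long input, short output.
python3 -m sglang.bench_serving \
    --backend sglang \
    --dataset-name random \
    --random-input-len 8192 \
    --random-output-len 500 \
    --random-range-ratio 1 \
    --num-prompts 8 \
    --max-concurrency 8
```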
GSM8K accuracy:
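The accuracy numbers above presumably come from sglang's in-repo GSM8K script; a plausible invocation (path and flags are assumptions):

```bash
# Illustrative; matches the 1319-question, parallel=1319 setting above.
python3 benchmark/gsm8k/bench_sglang.py \
    --num-questions 1319 \
    --parallel 1319
```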
Key takeaways
Profile verification
prepare_varlen_num_blocks calls reduced from 576 → 72 per profiling window (64 from prefill + 8 from get_scheduler_metadata; all per-layer decode calls eliminated)

main:

this PR:

Review Process
Trigger CI via the comment commands (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.