perf: eliminate attention DtoD copy by passing pre-allocated output to FA #21985
Merged
Qiaolin-Yu merged 10 commits into sgl-project:main on Apr 24, 2026
Conversation
Contributor
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
Force-pushed from eb89f27 to 80a122c
Contributor (Author)
/tag-and-rerun-ci
Force-pushed from bdc4cab to dbc2b99
Force-pushed from dbc2b99 to 28cbd21
jasperjiaguo added a commit to jasperjiaguo/sglang that referenced this pull request on Apr 20, 2026

…l-project#21971 compat)

PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
Force-pushed from 9d6583a to d62806b
Force-pushed from d62806b to cc92c16
Force-pushed from cc92c16 to 8c810fd
Force-pushed from 8c810fd to 4739b5e
Force-pushed from 4739b5e to a91cbfc
…o FA

In unified_attention_with_output, the attention backend allocates a new output tensor internally, then .copy_() copies it into the pre-allocated output buffer. This causes a Memcpy DtoD (~14us) per attention layer.

Fix: pass the pre-allocated output tensor through to flash_attn via the new out= parameter, so FA3 writes directly into it. Skip the .copy_() when the returned tensor already points to the output buffer.

Changes:
- radix_attention.py: pass output=output, skip copy if same data_ptr
- flashattention_backend.py: accept output in forward_extend and forward_decode, reshape and pass as out= to flash_attn_with_kvcache
- sgl-kernel/flash_attn.py: add out= param to flash_attn_with_kvcache, pass through to sgl_kernel.fwd

Applies to all workloads (generation, embedding, decode).

Profile (Qwen3-0.6B FP8, H200, 7k tokens):
- DtoD events: 33 (408us) -> 5 (12us)
- Large DtoD (>10us): 28 -> 0

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
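As a rough illustration of the backend-side part of this commit, the sketch below reshapes a pre-allocated buffer once and hands it to the kvcache kernel. Function and argument names (forward_decode_sketch, fa_with_kvcache) are placeholders, not the actual SGLang signatures.

```python
import torch
from typing import Optional

def forward_decode_sketch(q, k_cache, v_cache,
                          output: Optional[torch.Tensor],
                          num_heads: int, head_dim: int,
                          fa_with_kvcache):
    # `fa_with_kvcache` stands in for the sgl-kernel flash_attn_with_kvcache
    # wrapper after it gained an out= keyword; the argument list is simplified.
    fa_out = None
    if output is not None:
        # Single reshape of the pre-allocated buffer; FA3 then writes in place.
        fa_out = output.view(-1, num_heads, head_dim)
    o = fa_with_kvcache(q=q, k_cache=k_cache, v_cache=v_cache, out=fa_out)
    # Hand back the flat [num_tokens, num_heads * head_dim] layout the caller expects.
    return o.view(-1, num_heads * head_dim)
```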
The fwd op schema declares out as Tensor? (optional input) instead of Tensor(a!)? (mutating output alias). This causes PyTorch dispatch to create a copy of the output tensor even when the C++ mha_fwd writes directly into the provided out buffer, resulting in a Memcpy DtoD (~14us) per attention layer per forward pass.

Fix: change the schema from Tensor? to Tensor(a!)? for the out parameter, and mark the first return as Tensor(a!) to indicate it aliases out.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
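The schema change can be pictured with the before/after strings below. The operator name and argument list are paraphrased; only the alias annotations mirror the commit, and the real registration lives in C++ (flash_extension.cc), not Python.

```python
# Illustrative schema strings only; argument lists are abbreviated.

SCHEMA_BEFORE = "fwd(Tensor q, Tensor k, Tensor v, Tensor? out) -> Tensor"
# `out` is just an optional input: even though the kernel writes into it,
# dispatch does not know the return aliases it, so a defensive copy appears.

SCHEMA_AFTER = "fwd(Tensor q, Tensor k, Tensor v, Tensor(a!)? out) -> Tensor(a!)"
# `Tensor(a!)?` marks `out` as a mutable alias, and the return is declared to
# alias it, so the tensor handed back is the caller's own buffer.
```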
Pass the pre-allocated output tensor through to sgl_kernel.fwd in flash_attn_varlen_func, matching the existing support in flash_attn_with_kvcache. This allows callers (e.g. the skip_kv_cache embedding path) to write FA3 output directly into the pre-allocated buffer, eliminating a 14us DtoD memcpy per decoder layer.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… paths

Add out=_fa_out to the two existing flash_attn_varlen_func calls in forward_extend (chunked prefix and MHA extend paths), and add an out=None parameter to flash_attn_varlen_func in sgl-kernel so the pre-allocated output buffer is passed through to sgl_kernel.fwd. This eliminates ~14us DtoD memcpy per decoder layer for all FA paths, not just the kvcache path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
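A hedged sketch of the wrapper-level plumbing this commit describes: kernel_fwd stands in for sgl_kernel.fwd and the argument list is abbreviated; the point is only that out= is forwarded rather than dropped.

```python
import torch
from typing import Optional

def flash_attn_varlen_sketch(q, k, v,
                             cu_seqlens_q, cu_seqlens_k,
                             max_seqlen_q: int, max_seqlen_k: int,
                             out: Optional[torch.Tensor] = None,
                             kernel_fwd=None):
    # `kernel_fwd` is a placeholder for the underlying kernel entry point;
    # its real signature is longer. Forwarding out= lets FA3 write straight
    # into the caller's pre-allocated buffer.
    if out is not None:
        assert out.shape == q.shape and out.device == q.device
    return kernel_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k,
                      max_seqlen_q, max_seqlen_k, out=out)
```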
- Extract the _fa_out reshape from 5 inline sites to 1 assignment at the top of each function (forward_extend, forward_decode)
- Restore the numel assert in unified_attention_with_output for the fallback copy path (safety check when FA does not write to out)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…put=

The base AttentionBackend.forward() passes **kwargs to forward_extend() and forward_decode(). When output= is passed from radix_attention, backends like FlashInfer that do not explicitly accept output= fail with TypeError. Adding **kwargs to the base signatures allows all backends to accept and ignore unknown kwargs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The FA backend imports flash_attn_varlen_func from sglang.jit_kernel.flash_attention, which wraps flash_attention_v3, which wraps sgl_kernel.flash_attn. Both wrapper layers need to accept and forward the out= parameter for the pre-allocated output to reach the underlying sgl_kernel.fwd call. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of passing output= through kwargs (which breaks backends without **kwargs, like FlashInfer), store the pre-allocated output on forward_batch._attn_output in radix_attention. Only the FA backend reads it; other backends are completely unaffected and need no signature changes.

Also adds out=None to the jit_kernel flash_attn wrappers (lint fix).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
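A minimal sketch of the producer/consumer handoff described above, using illustrative names (ForwardBatchSketch, producer_side, consumer_side) rather than the real SGLang classes:

```python
import torch
from typing import Optional

class ForwardBatchSketch:
    # Stand-in for the batch object that carries the buffer between layers.
    _attn_output: Optional[torch.Tensor] = None

def producer_side(q, k, v, output: torch.Tensor, forward_batch, attn_backend):
    # radix_attention stashes the pre-allocated buffer on the batch object;
    # backends that never read the attribute are unaffected.
    forward_batch._attn_output = output
    return attn_backend.forward(q, k, v, forward_batch)

def consumer_side(forward_batch, num_heads: int, head_dim: int):
    # Only the FA backend picks the buffer back up and reshapes it for out=.
    out = getattr(forward_batch, "_attn_output", None)
    return None if out is None else out.view(-1, num_heads, head_dim)
```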
During CUDA graph capture, output has max_num_tokens rows but query is sliced to real_num_tokens. The FA kernel validates out.size(0) == q.size(0), so _attn_output must also be sliced to avoid shape mismatch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
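The slicing rule can be sketched as below; the function name, argument layout, and buffer shape are assumptions for illustration.

```python
import torch
from typing import Optional

def fa_out_for_capture(attn_output: Optional[torch.Tensor],
                       q: torch.Tensor,
                       num_heads: int, head_dim: int) -> Optional[torch.Tensor]:
    # During CUDA graph capture the persistent buffer covers max_num_tokens
    # rows while q is sliced to the tokens actually present; FA3 checks
    # out.size(0) == q.size(0), so the buffer is sliced to match.
    if attn_output is None:
        return None
    real_num_tokens = q.size(0)
    return attn_output[:real_num_tokens].view(real_num_tokens, num_heads, head_dim)
```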
…l-project#21971 compat)

PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out, so the DtoD copy elimination from sgl-project#21985 did not cover it.
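For the embedding-mode path, the call-site fix amounts to forwarding the buffer, roughly as in this sketch; varlen_func and the abbreviated argument list are placeholders rather than the exact call in forward_extend.

```python
import torch

def skip_kv_cache_extend_sketch(q, k, v, cu_seqlens, max_seqlen,
                                fa_out: torch.Tensor, varlen_func):
    # `varlen_func` stands in for flash_attn_varlen_func. Supplying out=
    # lets the kernel write into the pre-allocated buffer instead of
    # allocating a fresh tensor that would then be copied back.
    return varlen_func(q, k, v,
                       cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                       max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                       out=fa_out)
```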
Contributor (Author)
/rerun-failed-ci

Contributor (Author)
/rerun-failed-ci
Summary
Eliminate a ~14µs DtoD memcpy per decoder layer by passing the pre-allocated output tensor directly to Flash Attention, rather than letting FA allocate internally and then copying.
Root Cause
unified_attention_with_output pre-allocates an output tensor, calls attn_backend.forward(), which invokes FA3 without an out= parameter, then does output.copy_(ret), producing a 29MB DtoD memcpy per layer (~14µs × 28 layers = ~392µs per forward pass). vLLM avoids this by passing out=output to FA3 directly.
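The pattern can be condensed into a short sketch. This is a minimal illustration, not the exact SGLang code: run_attention, fa_call, and the tensor layouts are assumptions, and the fallback branch is the copy path that remains for backends that do not honor out=.

```python
import torch

def run_attention(q, k, v, output: torch.Tensor, fa_call):
    """Minimal sketch of the copy-elimination pattern; names are illustrative.

    q:       [num_tokens, num_heads, head_dim]
    output:  pre-allocated [num_tokens, num_heads * head_dim] buffer
    fa_call: stand-in for the backend's FA3 entry point; if it honors out=,
             its return value aliases `output` and the copy below is skipped.
    """
    fa_out = output.view(q.shape[0], q.shape[1], q.shape[2])  # reshape only, no copy
    ret = fa_call(q, k, v, out=fa_out)
    if ret.data_ptr() != output.data_ptr():
        # Fallback: the backend ignored out= and allocated internally,
        # so the DtoD copy this PR removes still happens here.
        assert ret.numel() == output.numel()
        output.copy_(ret.reshape(output.shape))
    return output
```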
Changes
- radix_attention.py: pass output=output to attn_backend.forward(), skip copy_() when ret.data_ptr() == output.data_ptr()
- flashattention_backend.py: accept an output param in forward_extend and forward_decode, reshape to _fa_out, and pass out=_fa_out to flash_attn_with_kvcache
- sgl-kernel/flash_attn.py: add an out=None parameter to both flash_attn_with_kvcache and flash_attn_varlen_func, pass through to sgl_kernel.fwd
- sgl-kernel/flash_extension.cc: fix the op schema from Tensor? out to Tensor(a!)? out and the return type from Tensor to Tensor(a!), telling PyTorch dispatch that the return aliases the out input and preventing a defensive copy

Profile Results (7k token FP8 embedding, H200)
Benchmark Results (Qwen3-0.6B FP8, H200, production traffic distribution)
Test plan
🤖 Generated with Claude Code