perf: eliminate attention DtoD copy by passing pre-allocated output to FA #21985

Merged
Qiaolin-Yu merged 10 commits into sgl-project:main from jasperjiaguo:jiaguo/eliminate-attn-dtod-copy on Apr 24, 2026

Conversation

@jasperjiaguo
Contributor

@jasperjiaguo jasperjiaguo commented Apr 3, 2026

Summary

Eliminate a ~14µs DtoD memcpy per decoder layer by passing the pre-allocated output tensor directly to Flash Attention, rather than letting FA allocate internally and then copying.

Root Cause

unified_attention_with_output pre-allocates an output tensor, calls attn_backend.forward() which invokes FA3 without an out= parameter, then does output.copy_(ret) — producing a 29MB DtoD memcpy per layer (~14µs × 28 layers = ~392µs per forward pass). vLLM avoids this by passing out=output to FA3 directly.

Changes

  1. radix_attention.py: Pass output=output to attn_backend.forward(), and skip copy_() when ret.data_ptr() == output.data_ptr() (see the sketch after this list)
  2. flashattention_backend.py: Accept output param in forward_extend and forward_decode, reshape to _fa_out and pass out=_fa_out to flash_attn_with_kvcache
  3. sgl-kernel/flash_attn.py: Add out=None parameter to both flash_attn_with_kvcache and flash_attn_varlen_func, pass through to sgl_kernel.fwd
  4. sgl-kernel/flash_extension.cc: Fix op schema from Tensor? out to Tensor(a!)? out and return type from Tensor to Tensor(a!) — tells PyTorch dispatch that the return aliases the out input, preventing a defensive copy
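
The core change, as a minimal sketch using the names from the summary above (the exact signatures, and how the buffer is actually threaded through to the backend, differ in the real code):

```python
import torch

def unified_attention_with_output(q, k, v, layer, forward_batch, attn_backend):
    # Pre-allocate the output buffer once, as before.
    output = torch.empty_like(q)

    # Hand the buffer to the backend so FA3 writes into it via out=.
    ret = attn_backend.forward(q, k, v, layer, forward_batch, output=output)

    # If FA3 wrote into our buffer, the returned tensor aliases it and the
    # per-layer DtoD copy can be skipped entirely.
    if ret.data_ptr() != output.data_ptr():
        assert ret.numel() == output.numel()  # safety check on the fallback path
        output.copy_(ret)  # backends that still allocate internally keep working
    return output
```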

Profile Results (7k token FP8 embedding, H200)

| Metric | Before | After |
|---|---|---|
| DtoD events per forward | 33 (29 large) | 5 (1 large) |
| Per-layer DtoD copies | 28 | 0 |

Benchmark Results (Qwen3-0.6B FP8, H200, production traffic distribution)

| Config | Items/sec | vs Baseline |
|---|---|---|
| Baseline (main) | 30.77 | |
| With PR #21734 + #21971 | 37.29 | +21.2% |
| + This PR (DtoD eliminated) | 37.83 | +22.9% |

Test plan

  • Correctness: embedding outputs match baseline (cosine similarity > 0.999)
  • Profile: DtoD events reduced from 33 to 5 (28 per-layer copies eliminated)
  • Benchmark: 37.83 items/sec (+1.5% over without DtoD fix)
  • CI tests pass
  • Non-embedding workloads not regressed

🤖 Generated with Claude Code


@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch 4 times, most recently from eb89f27 to 80a122c on April 4, 2026 00:44
@jasperjiaguo
Contributor Author

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Apr 6, 2026
@jasperjiaguo jasperjiaguo changed the title perf: eliminate attention DtoD copy by passing pre-allocated output to FA [WIP]perf: eliminate attention DtoD copy by passing pre-allocated output to FA Apr 6, 2026
@jasperjiaguo jasperjiaguo changed the title [WIP]perf: eliminate attention DtoD copy by passing pre-allocated output to FA [WIP] perf: eliminate attention DtoD copy by passing pre-allocated output to FA Apr 6, 2026
@jasperjiaguo jasperjiaguo changed the title [WIP] perf: eliminate attention DtoD copy by passing pre-allocated output to FA perf: eliminate attention DtoD copy by passing pre-allocated output to FA Apr 7, 2026
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch 4 times, most recently from bdc4cab to dbc2b99 on April 7, 2026 18:23
@Qiaolin-Yu Qiaolin-Yu self-assigned this Apr 7, 2026
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch from dbc2b99 to 28cbd21 on April 7, 2026 21:54
jasperjiaguo added a commit to jasperjiaguo/sglang that referenced this pull request Apr 20, 2026
…l-project#21971 compat)

PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses
flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out,
so the DtoD copy elimination from sgl-project#21985 did not cover it.
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch from 9d6583a to d62806b on April 20, 2026 21:39
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch from d62806b to cc92c16 on April 21, 2026 00:07
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch from cc92c16 to 8c810fd on April 21, 2026 17:15
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch from 8c810fd to 4739b5e on April 21, 2026 19:39
@jasperjiaguo jasperjiaguo force-pushed the jiaguo/eliminate-attn-dtod-copy branch from 4739b5e to a91cbfc on April 21, 2026 22:24
jasperjiaguo and others added 10 commits April 23, 2026 14:54
…o FA

In unified_attention_with_output, the attention backend allocates a new
output tensor internally, then .copy_() copies it into the pre-allocated
output buffer. This causes a Memcpy DtoD (~14us) per attention layer.

Fix: pass the pre-allocated output tensor through to flash_attn via the
new out= parameter, so FA3 writes directly into it. Skip the .copy_()
when the returned tensor already points to the output buffer.

Changes:
- radix_attention.py: pass output=output, skip copy if same data_ptr
- flashattention_backend.py: accept output in forward_extend and
  forward_decode, reshape and pass as out= to flash_attn_with_kvcache
- sgl-kernel/flash_attn.py: add out= param to flash_attn_with_kvcache,
  pass through to sgl_kernel.fwd

Applies to all workloads (generation, embedding, decode).

Profile (Qwen3-0.6B FP8, H200, 7k tokens):
  DtoD events: 33 (408us) -> 5 (12us)
  Large DtoD (>10us): 28 -> 0

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The fwd op schema declares out as Tensor? (optional input) instead of
Tensor(a!)? (mutating output alias). This causes PyTorch dispatch to
create a copy of the output tensor even when the C++ mha_fwd writes
directly into the provided out buffer, resulting in a Memcpy DtoD
(~14us) per attention layer per forward pass.

Fix: change schema from Tensor? to Tensor(a!)? for the out parameter,
and mark the first return as Tensor(a!) to indicate it aliases out.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
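
For illustration only, the aliasing distinction can be written out with torch.library.define in Python; the real op is registered in sgl-kernel's flash_extension.cc via TORCH_LIBRARY, and the demo namespaces and argument lists below are made up:

```python
import torch

# Before: `out` is a plain optional input. Dispatch cannot tell that the
# returned tensor aliases it, so it may insert a defensive copy.
torch.library.define(
    "demo_before::fwd",
    "(Tensor q, Tensor k, Tensor v, Tensor? out) -> Tensor",
)

# After: `Tensor(a!)?` marks `out` as a mutable alias, and the `Tensor(a!)`
# return type says the result is that same buffer, so no copy is needed.
torch.library.define(
    "demo_after::fwd",
    "(Tensor q, Tensor k, Tensor v, Tensor(a!)? out) -> Tensor(a!)",
)
```
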
Pass the pre-allocated output tensor through to sgl_kernel.fwd in
flash_attn_varlen_func, matching the existing support in
flash_attn_with_kvcache. This allows callers (e.g. the skip_kv_cache
embedding path) to write FA3 output directly into the pre-allocated
buffer, eliminating a 14us DtoD memcpy per decoder layer.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… paths

Add out=_fa_out to the two existing flash_attn_varlen_func calls in
forward_extend (chunked prefix and MHA extend paths), and add out=None
parameter to flash_attn_varlen_func in sgl-kernel so the pre-allocated
output buffer is passed through to sgl_kernel.fwd.

This eliminates ~14us DtoD memcpy per decoder layer for all FA paths,
not just the kvcache path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Extract _fa_out reshape from 5 inline sites to 1 assignment at the
  top of each function (forward_extend, forward_decode)
- Restore numel assert in unified_attention_with_output for the
  fallback copy path (safety check when FA does not write to out)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…put=

The base AttentionBackend.forward() passes **kwargs to forward_extend()
and forward_decode(). When output= is passed from radix_attention,
backends like FlashInfer that do not explicitly accept output= fail
with TypeError. Adding **kwargs to the base signatures allows all
backends to accept and ignore unknown kwargs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The FA backend imports flash_attn_varlen_func from
sglang.jit_kernel.flash_attention, which wraps flash_attention_v3,
which wraps sgl_kernel.flash_attn. Both wrapper layers need to accept
and forward the out= parameter for the pre-allocated output to reach
the underlying sgl_kernel.fwd call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
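
A sketch of that pass-through chain with illustrative names (only the shape of the wrappers matters, not the exact module paths):

```python
def _kernel_fwd(q, k, v, out=None):
    # Stand-in for the innermost sgl_kernel call: writes into `out` when given.
    ...

def flash_attention_v3(q, k, v, *, out=None):
    # Middle wrapper: must forward out=, or the buffer is silently dropped.
    return _kernel_fwd(q, k, v, out=out)

def flash_attn_varlen_func(q, k, v, *, out=None):
    # The wrapper the FA backend imports: same rule applies here.
    return flash_attention_v3(q, k, v, out=out)
```
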
Instead of passing output= through kwargs (which breaks backends without
**kwargs like FlashInfer), store the pre-allocated output on
forward_batch._attn_output in radix_attention. Only FA backend reads it.
Other backends are completely unaffected — no signature changes needed.

Also adds out=None to jit_kernel flash_attn wrappers (lint fix).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
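
A minimal sketch of this side-channel approach, assuming the _attn_output attribute name from the commit message (the surrounding function names and the reshape layout are illustrative):

```python
import torch

def attend_with_preallocated_output(q, k, v, layer, forward_batch, attn_backend):
    # Caller side (radix_attention): stash the buffer on the batch object instead
    # of threading an output= kwarg through every backend's signature.
    output = torch.empty_like(q)
    forward_batch._attn_output = output
    ret = attn_backend.forward(q, k, v, layer, forward_batch)
    if ret.data_ptr() != output.data_ptr():
        output.copy_(ret)  # backends that ignore _attn_output behave as before
    return output

def fa_out_buffer(forward_batch, num_heads, head_dim):
    # Backend side (FlashAttention only): use the stashed buffer, if present,
    # reshaped to the (tokens, heads, head_dim) layout FA expects for out=.
    buf = getattr(forward_batch, "_attn_output", None)
    return None if buf is None else buf.view(-1, num_heads, head_dim)
```
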
During CUDA graph capture, output has max_num_tokens rows but query is
sliced to real_num_tokens. The FA kernel validates out.size(0) == q.size(0),
so _attn_output must also be sliced to avoid shape mismatch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
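
A sketch of that capture-time adjustment (variable names illustrative; only the slicing matters):

```python
def match_out_rows_to_query(fa_out, q):
    # During CUDA graph capture the pre-allocated buffer has max_num_tokens rows
    # while q is already sliced to real_num_tokens. FA validates
    # out.size(0) == q.size(0), so slice the buffer the same way (a view, no copy).
    if fa_out is not None and fa_out.size(0) != q.size(0):
        fa_out = fa_out[: q.size(0)]
    return fa_out
```
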
…l-project#21971 compat)

PR sgl-project#21971 added a new fa_skip_kv_cache path in forward_extend that uses
flash_attn_varlen_func for embedding mode. That path was missing out=_fa_out,
so the DtoD copy elimination from sgl-project#21985 did not cover it.
@jasperjiaguo
Contributor Author

/rerun-failed-ci

@jasperjiaguo
Contributor Author

/rerun-failed-ci
