Skip to content

[Perf][3/n] Eliminate GPU<->CPU syncs in attention impls#41434

Merged
vllm-bot merged 1 commit into
mainfrom
fix-gpucpu-syncs3
May 8, 2026
Merged

[Perf][3/n] Eliminate GPU<->CPU syncs in attention impls#41434
vllm-bot merged 1 commit into
mainfrom
fix-gpucpu-syncs3

Conversation

@njhill

@njhill njhill commented May 1, 2026

Copy link
Copy Markdown
Member

Unnecessary gpu/cpu syncs in attention implementations, found via #40561.

TurboQuant benchmark

Each scenario runs vLLM with --tensor-parallel-size 1 --distributed-executor-backend uni (UniProcExecutor) on a single NVIDIA GB200 GPU. Model: Qwen/Qwen3-0.6B. Each side (without / with change) is the mean ± population std across 3 timed runs sharing one server process; each run uses its own seed (1, 2, 3) and is preceded by a fresh warmup batch. Δ = relative change of with-mean vs. without-mean (✓ = improvement, ✗ = regression).

TurboQuant c=32, 256 in / 2048 out

VLLM_ATTENTION_BACKEND=TURBOQUANT --kv-cache-dtype turboquant_k8v4, 384 prompts/run, 24 warmups

Per-side metadata:

Without change:

  • runs: 3, prompts/run: 384, completed/run: 384 ±0, failed/run: 0 ±0
  • max-concurrency: 32, request-rate: inf
  • duration: 104.8s..106.3s (mean 105.7s)
  • total in/out tokens per run: 98304/786432

With change:

  • runs: 3, prompts/run: 384, completed/run: 384 ±0, failed/run: 0 ±0
  • max-concurrency: 32, request-rate: inf
  • duration: 100.6s..100.9s (mean 100.8s)
  • total in/out tokens per run: 98304/786432
Metric Without (n=3) With (n=3) Δ (mean)
Output throughput (tok/s) 7441.51 ±46.79 7801.99 ±9.30 +4.84% ✓
Total throughput (tok/s) 8371.70 ±52.64 8777.23 ±10.47 +4.84% ✓
Request throughput (req/s) 3.6335 ±0.0228 3.8096 ±0.0045 +4.84% ✓
Mean TTFT (ms) 72.81 ±9.97 72.76 ±9.56 -0.07% ✓
P50 TTFT (ms) 69.08 ±8.35 61.54 ±8.83 -10.92% ✓
P90 TTFT (ms) 87.47 ±19.42 83.30 ±5.47 -4.76% ✓
P99 TTFT (ms) 96.89 ±23.49 177.06 ±133.78 +82.74% ✗
Mean TPOT (ms) 4.267 ±0.022 4.068 ±0.006 -4.66% ✓
P50 TPOT (ms) 4.258 ±0.017 4.063 ±0.001 -4.57% ✓
P90 TPOT (ms) 4.297 ±0.043 4.075 ±0.008 -5.17% ✓
P99 TPOT (ms) 4.324 ±0.063 4.129 ±0.076 -4.52% ✓
Mean ITL (ms) 4.270 ±0.021 4.073 ±0.006 -4.62% ✓
Mean E2EL (ms) 8807.06 ±55.16 8399.65 ±10.01 -4.63% ✓
P99 E2EL (ms) 8946.49 ±144.85 8606.55 ±147.68 -3.80% ✓
Duration (s) 105.69 ±0.66 100.80 ±0.12 -4.62% ✓

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This repository is configured for manual code reviews. Comment @claude review to trigger a review and subscribe this PR to future pushes, or @claude review once for a one-time review.

Tip: disable this comment in your organization's Code Review settings.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a series of optimizations across various attention backends to minimize CPU-GPU synchronization. The changes focus on using CPU-resident metadata for sequence length calculations, replacing synchronizing operations such as torch.nonzero and torch.bincount with asynchronous equivalents, and utilizing slice-based assignments to avoid implicit synchronizations. Furthermore, the async_tensor_h2d utility was enhanced to facilitate non-blocking host-to-device transfers. I have no feedback to provide.

@njhill

njhill commented May 1, 2026

Copy link
Copy Markdown
Member Author

@claude review

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance-focused PR eliminating GPU↔CPU syncs across 7 attention backends; changes look correct and well-documented, but the breadth of critical paths touched (including subtle rewrites of torch.nonzero, torch.bincount, and mask-mod construction) warrants human review.

Extended reasoning...

Overview

This PR eliminates unnecessary GPU↔CPU synchronization in attention metadata builders and impls across FlashInfer, FlexAttention, Mamba, Tree, Triton, and TurboQuant backends, plus shared utils.py, buffer_utils.py, and penalties.py. It also renames async_tensor_h2d's target_device parameter to device and adds a module-level PIN_MEMORY constant. The recurring patterns are: (a) tensor[0] = xtensor[:1] = x / .fill_() to avoid scalar-assignment sync, (b) precomputing max() / max_seq_len / max_query_len on CPU instead of .max().item(), (c) building list-shaped tensors via pinned async_tensor_h2d rather than torch.tensor(.., device=cuda), and (d) replacing data-dependent ops (torch.nonzero, torch.bincount, repeat_interleave of GPU tensors) with sync-free equivalents.

Security risks

None. This is a pure performance optimization — no auth, crypto, permissions, network, or input-handling code is touched. All changes are local to GPU kernel orchestration.

Level of scrutiny

Medium-high. The PR is performance-only and the bug hunting system found no issues, but it touches the hot path of every supported attention backend on a v1/nvidia-labeled PR. The non-mechanical changes deserve a human eye:

  • unique_static_unsorted rewrite in flex_attention.py: routes non-kept entries to a garbage column N then slices it off — equivalence depends on scatter_ correctly handling duplicates by last-write-wins, which is fine for unique dest_pos but the construction relies on cumsum-1 producing unique indices for kept entries plus the constant N for non-kept. Looks correct, but is subtle.
  • torch.bincountscatter_add_ in utils.py: the rewrite is functionally equivalent assuming request_ids values are bounded by num_reqs, which the surrounding code guarantees via bucketize.
  • mamba_attn's switch from compute_num_computed_tokens().cpu() to seq_lens_cpu_upper_bound[…] - prefill_query_lens_cpu relies on the documented contract that the upper bound is precise for prefill rows; verified against the field docstring at vllm/v1/attention/backend.py:405-408.
  • tree_attn's precomputed max_query_len_prefill/decode and max_seq_len_prefill/decode: needs verification that these flow correctly into the cached prefill_metadata/decode_metadata properties, which they do.
  • The mask-mod simplification in flex (torch.where(is_valid, .., False)is_valid & ..) changes evaluation semantics — the inner mask_mod is now always invoked with possibly-out-of-range logical indices. Likely fine because mask_mod implementations are pure index comparisons, but worth confirming for non-trivial mask_mods.

Other factors

  • Maintainer njhill explicitly requested review.
  • This is part of a series ([Perf][3/n]) — prior PRs in the series presumably established the patterns being applied here.
  • Each change is well-commented in-line, explaining the why (which sync was forced and why the new form avoids it). That makes review tractable.
  • I have no prior reviews on this PR.

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label May 1, 2026
Signed-off-by: Nick Hill <nickhill123@gmail.com>
@njhill njhill force-pushed the fix-gpucpu-syncs3 branch from ce0658f to ab31ebf Compare May 7, 2026 02:17

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't find any bugs, but this is a broad performance refactor across 7 attention backends with subtle correctness implications (replacing torch.bincount with scatter_add_, using seq_lens_cpu_upper_bound as a max-seq-len proxy in mamba/tree/turboquant, slice-assignment to dodge scalar-write syncs), so I'd like a human to give it a look despite the explicit perf wins shown in the description.

Extended reasoning...

Overview

This PR eliminates several GPU↔CPU synchronization points across attention backends and a couple of supporting modules:

  • vllm/utils/torch_utils.py: Adds a module-level PIN_MEMORY constant and renames async_tensor_h2ds target_device parameter to device with a default pin_memory=PIN_MEMORY.
  • vllm/v1/attention/backends/flashinfer.py: Replaces a scalar-write (paged_kv_indptr_prefill_gpu[0] = 0) with a slice assignment.
  • vllm/v1/attention/backends/flex_attention.py: Reworks _offsets_to_doc_ids_tensor to build on CPU and async-upload, plumbs a new query_start_loc_cpu field through FlexAttentionMetadata, and rewrites unique_static_unsorted to use a scatter rather than torch.nonzero.
  • vllm/v1/attention/backends/mamba_attn.py: Derives prefill context lengths from CPU upper-bound seq_lens and switches list→tensor uploads to async_tensor_h2d.
  • vllm/v1/attention/backends/tree_attn.py: Precomputes max_query_len/max_seq_len for prefill/decode sub-batches on CPU in the builder, eliminating per-forward .max().item() syncs.
  • vllm/v1/attention/backends/triton_attn.py: Switches compute_mm_prefix_range_tensor to async H2D.
  • vllm/v1/attention/backends/turboquant_attn.py: Adds CPU copies of query_start_loc/seq_lens to metadata to avoid .tolist()/.max() syncs in the prefill path.
  • vllm/v1/attention/backends/utils.py: Replaces torch.bincount with scatter_add_ in make_kv_sharing_fast_prefill_common_attn_metadata and pushes index tensors to device up front in make_local_attention_virtual_batches.
  • vllm/v1/worker/gpu/{buffer_utils,sample/penalties}.py: Updates calls to use the renamed async_tensor_h2d keyword.

Security risks

None — this is a pure-performance refactor. No auth, crypto, network, or input-validation surface is touched.

Level of scrutiny

Higher than a typical perf change. Although each individual diff is mechanical, several rely on non-obvious equivalences (the seq_lens_cpu_upper_bound substitution being precise for prefills and merely safely-optimistic for decodes; scatter_add_ semantics matching torch.bincount with minlength=; repeat_interleave on CPU vs GPU producing identical orderings) that benefit from a human familiar with the v1 attention stack confirming no edge case slips through.

Other factors

  • A vLLM committer (njhill) explicitly tagged me to review — the maintainers are aware this needs eyes.
  • The PR is labelled ready and includes a benchmark showing ~5% throughput improvement on TurboQuant.
  • The bug-hunting system found no defects, and the comments throughout the diff explain why each rewrite preserves correctness, which is a good sign.

prefill_start : num_reqs + 1
]
paged_kv_indptr_prefill_gpu[0] = 0
# Assign to slice to avoid cpu sync.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot of real black magic in this pr

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cuda_tensor[0] = 0 uses copy_ which does a sync, cuda_tensor[:1] = 0 uses fill_ which doesn't :)

Comment on lines 415 to 416
decode_max_query_len = int(num_decode_tokens.max().item())
total_num_decode_tokens = int(num_decode_tokens.sum().item())

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these be avoided?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly but it would require more significant rework I think.

For now I am opening a series of PRs with "low hanging" fixes. Remaining syncs can be wrapped in the gpu_sync_allowed() context manager when #40561 is merged and we'll at least know where they are and can decide if/when to put in additional work to address them.

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA May 7, 2026
@vllm-bot vllm-bot merged commit 989c176 into main May 8, 2026
71 of 75 checks passed
@vllm-bot vllm-bot deleted the fix-gpucpu-syncs3 branch May 8, 2026 02:44
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 8, 2026
haosdent added a commit to haosdent/vllm that referenced this pull request May 8, 2026
The test patches `torch.fx.experimental.symbolic_shapes.make_symbol`
in the parent process and counts via a `multiprocessing.Value`. In V1
the actual compile runs inside an `EngineCore` subprocess that vLLM
spawns whenever CUDA is initialized in the parent (via
`_maybe_force_spawn`), so the monkey-patch never sees the compile path
and the counter stays at 0. This is a structural test-infra issue,
not a regression: CI flagged it on the build for vllm-project#41434, but the same
failure reproduces on its parent commit and is unrelated to that PR's
attention-impl changes.

Replace the brittle torch-internal monkey-patch with the existing
`compilation_counter.expect(...)` pattern already used by
`test_aot_counters_on_save_and_load`. Force
`VLLM_ENABLE_V1_MULTIPROCESSING=0` so the singleton counter is
incremented in the same process that runs the assertions; the cache
code path itself is identical in-process vs subprocess. Add
`cleanup_dist_env_and_memory()` between the two `LLM(...)` instances
and lower `gpu_memory_utilization` to 0.1 to leave headroom on the
second instantiation.

The activation-registry reset is preserved: without it, GPT-2's
`gelu_new` op leaves `disabled_custom_ops` mutated, the AOT cache
hash shifts between phases, and `VLLM_FORCE_AOT_LOAD=1` raises
FileNotFoundError.

Signed-off-by: haosdent <haosdent@gmail.com>
libinta pushed a commit to libinta/vllm that referenced this pull request May 8, 2026
…t#41434)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>
JisoLya pushed a commit to JisoLya/vllm that referenced this pull request May 10, 2026
…t#41434)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: JisoLya <523420504@qq.com>
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…t#41434)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants