Cache packed sequence metadata to reduce D2H syncs across layers #4243
Summary of Changes (Gemini Code Assist)
This pull request caches packed sequence metadata to avoid redundant device-to-host transfers, and the synchronizations they trigger, within a single forward pass. Once the sequence length information has been computed for the first layer, subsequent layers retrieve it from a cache, which speeds up training and inference for models that use packing.
Code Review
This pull request introduces a caching mechanism to reduce redundant D2H copies and synchronizations across layers, which is a significant performance improvement. The implementation is sound and correctly uses object identity for cache invalidation. A clear_packed_caches function is also added for manual memory management. My only suggestion is to improve the type hints for the new cache dictionaries for better code clarity and maintainability.
# Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers
_PACKED_INFO_CACHE: dict = {}

# Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
_SDPA_MASK_CACHE: dict = {}

# Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
_XFORMERS_BLOCK_MASK_CACHE: dict = {}
The type hints for the new cache dictionaries are dict. For better code clarity and type safety, consider using more specific types from the typing module. You'll need to add Dict and Any to your imports from typing.
Using Dict[torch.device, Dict[str, Any]] would be more descriptive. For even better documentation of the cache entry structure, you could consider using TypedDict.
Suggested change:
- # Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers
- _PACKED_INFO_CACHE: dict = {}
- # Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
- _SDPA_MASK_CACHE: dict = {}
- # Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
- _XFORMERS_BLOCK_MASK_CACHE: dict = {}
+ # Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers
+ _PACKED_INFO_CACHE: "Dict[torch.device, Dict[str, Any]]" = {}
+ # Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
+ _SDPA_MASK_CACHE: "Dict[torch.device, Dict[str, Any]]" = {}
+ # Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
+ _XFORMERS_BLOCK_MASK_CACHE: "Dict[torch.device, Dict[str, Any]]" = {}
Verified on NVIDIA B200 (178GB), Torch 2.9.1, TRL 0.25.1, SDPA backend (no flash_attn). Ran 61-step SFT with packing on
Results
Llama-3.2-1B-Instruct (16 layers)
Qwen3-0.6B (28 layers)
Correctness
Identity check verification
Instrumented
Also ran 8 unit tests against the caching logic directly: cache hit/miss on same/different tensor objects, bounded cache size (1 entry per device),
Minor notes (non-blocking)
LGTM.
…lothai#4243)
* packing optimization with cache to reduce D2H copy
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* cache per device to avoid race condition for multi-gpu
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add cache freeing up func
---------
Co-authored-by: ruixiangw <ruixiangw@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ruixiang <wangruixiang07@outlook.com>
Replacement for #4173 due to Studio rebasing
Added per-forward-pass caching to eliminate redundant D2H copies and cudaStreamSynchronize calls across layers.

When packing (padding-free) is enabled, three functions are called on every layer of the model during the forward pass:
- get_packed_info_from_kwargs: calls lengths.max().item(), which triggers a D2H copy + sync
- build_sdpa_packed_attention_mask (SDPA backend): calls seq_lengths.sum().item() and seq_lengths.tolist(), which trigger 2 D2H copies + syncs
- build_xformers_block_causal_mask (XFormers backend): calls seq_lengths.to("cpu"), which triggers a D2H copy + sync

For a model with N layers, each function therefore repeats its D2H synchronization on all N layers in every forward pass. All but the first are unnecessary, because the packed sequence metadata (seq_lengths) is identical across all layers within the same forward pass.

Solution
Cache the output of each function using a Python object identity (is) comparison on the seq_lengths tensor. Since the same seq_lengths tensor object is passed to all layers within a single forward pass, subsequent layers hit the cache and skip the D2H operations entirely. A new batch produces a new seq_lengths tensor object, which naturally invalidates the cache.

This reduces D2H synchronizations per forward pass in:
- get_packed_info_from_kwargs
- build_sdpa_packed_attention_mask
- build_xformers_block_causal_mask

Nsys profiling traces:
Performance
With the caching strategy, cudaStreamSynchronize appears only in the first layer and is absent from all subsequent layers. We measured roughly a 43.3% speedup for the forward pass, 5.8% for the backward pass, and 14.3% per batch for Qwen3 14B QLoRA SFT.