Skip to content

Cache packed sequence metadata to reduce D2H syncs across layers#4243

Merged
danielhanchen merged 5 commits into
mainfrom
dh/recover-4173-packing-cache
Mar 12, 2026
Merged

Cache packed sequence metadata to reduce D2H syncs across layers#4243
danielhanchen merged 5 commits into
mainfrom
dh/recover-4173-packing-cache

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Replacement for #4173 due to Studio rebasing

Added per-forward-pass caching to eliminate redundant D2H copies and cudaStreamSynchronize calls across layers.

When packing (padding-free) is enabled, three functions are called on every layer of the model during the forward pass:

  • get_packed_info_from_kwargs: calls lengths.max().item() — triggers D2H copy + sync
  • build_sdpa_packed_attention_mask (SDPA backend): calls seq_lengths.sum().item() and seq_lengths.tolist() — triggers 2 D2H copies + syncs
  • build_xformers_block_causal_mask (XFormers backend): calls seq_lengths.to("cpu") — triggers D2H copy + sync

For a model with N layers, this results in N unnecessary D2H synchronizations per function, even though the packed sequence metadata (seq_lengths) is identical across all layers within the same forward pass.

Solution

Cache the output of each function using Python object identity (is) comparison on the seq_lengths tensor. Since the same seq_lengths tensor object is passed to all layers within a single forward pass, subsequent layers hit the cache and skip the D2H operations entirely. A new batch produces a new seq_lengths tensor object, which naturally invalidates the cache.

This reduces D2H synchronizations per forward pass:

Function Before After
get_packed_info_from_kwargs N 1
build_sdpa_packed_attention_mask 2N 2
build_xformers_block_causal_mask N 1

Nsys profiling traces:

  • Without this PR:
image
  • With this PR:
image

Performance

With the caching strategy, CudaStreamSync only appears in the first layer and in the following layers it disappears. We achieve around 43.3% speedup for forward, 5.8% speedup for backward, 14.3% speedup for each batch for Qwen3 14B QLoRA SFT.

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial performance enhancement by intelligently caching packed sequence metadata. By preventing redundant data transfers between device and host within a single forward pass, it drastically cuts down on synchronization overhead. This optimization ensures that once sequence length information is computed for the first layer, subsequent layers can retrieve it from a cache, leading to significant speedups in model training and inference, particularly for models utilizing packing.

Highlights

  • Performance Optimization: Implemented per-forward-pass caching for packed sequence metadata to significantly reduce redundant Device-to-Host (D2H) copies and cudaStreamSynchronize calls across model layers. This addresses a performance bottleneck where seq_lengths metadata was repeatedly copied and synchronized for each layer, despite being identical.
  • Caching Mechanism: Introduced three new per-device caches (_PACKED_INFO_CACHE, _SDPA_MASK_CACHE, _XFORMERS_BLOCK_MASK_CACHE) that store the results of get_packed_info_from_kwargs, build_sdpa_packed_attention_mask, and build_xformers_block_causal_mask respectively. The cache invalidation is handled naturally by Python object identity, as a new batch generates a new seq_lengths tensor object.
  • Impact: Achieved substantial speedups: approximately 43.3% for forward pass, 5.8% for backward pass, and 14.3% per batch for Qwen3 14B QLoRA SFT, by reducing D2H synchronizations from N (number of layers) to 1 or 2 for key functions.
Changelog
  • unsloth/utils/packing.py
    • Added three new global dictionaries (_PACKED_INFO_CACHE, _SDPA_MASK_CACHE, _XFORMERS_BLOCK_MASK_CACHE) for caching packed sequence metadata per device.
    • Modified get_packed_info_from_kwargs to incorporate caching logic, checking for existing results based on seq_lengths object identity and device.
    • Updated build_xformers_block_causal_mask to utilize a new caching mechanism for block causal masks, using seq_lengths object identity, device, and sliding window parameters.
    • Implemented caching in build_sdpa_packed_attention_mask to store and retrieve attention masks based on seq_lengths object identity, device, and mask parameters.
    • Introduced a new function clear_packed_caches to allow explicit clearing of all newly added caches.
    • Exported clear_packed_caches in the module's __all__ list for external access.
Activity
  • No human activity has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a caching mechanism to reduce redundant D2H copies and synchronizations across layers, which is a significant performance improvement. The implementation is sound and correctly uses object identity for cache invalidation. A clear_packed_caches function is also added for manual memory management. My only suggestion is to improve the type hints for the new cache dictionaries for better code clarity and maintainability.

Comment thread unsloth/utils/packing.py
Comment on lines +39 to +46
# Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers
_PACKED_INFO_CACHE: dict = {}

# Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
_SDPA_MASK_CACHE: dict = {}

# Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
_XFORMERS_BLOCK_MASK_CACHE: dict = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hints for the new cache dictionaries are dict. For better code clarity and type safety, consider using more specific types from the typing module. You'll need to add Dict and Any to your imports from typing.

Using Dict[torch.device, Dict[str, Any]] would be more descriptive. For even better documentation of the cache entry structure, you could consider using TypedDict.

Suggested change
# Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers
_PACKED_INFO_CACHE: dict = {}
# Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
_SDPA_MASK_CACHE: dict = {}
# Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
_XFORMERS_BLOCK_MASK_CACHE: dict = {}
# Cache per device for get_packed_info_from_kwargs to avoid repeated D2H sync across layers
_PACKED_INFO_CACHE: "Dict[torch.device, Dict[str, Any]]" = {}
# Cache per device for build_sdpa_packed_attention_mask to avoid repeated D2H sync across layers
_SDPA_MASK_CACHE: "Dict[torch.device, Dict[str, Any]]" = {}
# Cache per device for build_xformers_block_causal_mask to avoid repeated D2H sync across layers
_XFORMERS_BLOCK_MASK_CACHE: "Dict[torch.device, Dict[str, Any]]" = {}

@danielhanchen

Copy link
Copy Markdown
Member Author

Verified on NVIDIA B200 (178GB), Torch 2.9.1, TRL 0.25.1, SDPA backend (no flash_attn). Ran 61-step SFT with packing on yahma/alpaca-cleaned, batch_size=2, grad_accum=3, seed=3407, max_length=2048, 4bit LoRA r=16.

Results

Llama-3.2-1B-Instruct (16 layers)

Method Train Runtime (s) Samples/s Steps/s Peak Mem (GB) Speedup
Baseline (main) 105.32 3.48 0.58 9.41 -
PR 93.18 3.93 0.66 9.45 13.0%

Qwen3-0.6B (28 layers)

Method Train Runtime (s) Samples/s Steps/s Peak Mem (GB) Speedup
Baseline (main) 131.03 2.79 0.47 4.85 -
PR 111.60 3.28 0.55 4.88 17.4%

Correctness

  • All 61 loss values match at full float precision (max diff = 0.00e+00) for both models
  • All 61 grad norm values match at full float precision (max diff = 0.00e+00) for both models
  • Inference outputs are identical (Fibonacci continuation test)
  • This is expected since the PR returns the exact same cached Python objects on layers 2+, so the computation path is bit-identical

Identity check verification

Instrumented get_packed_info_from_kwargs across all 16 Llama layers and confirmed all 16 calls receive the same packed_seq_lengths tensor object (same id()). The kwargs dict is passed by reference through the layer loop in llama.py, so is-based identity caching is sound. 1 cache miss (layer 1) + 15 cache hits (layers 2-16). More layers = more D2H syncs avoided, which lines up with Qwen3 (28 layers, 17.4%) outperforming Llama (16 layers, 13.0%) in speedup.

Also ran 8 unit tests against the caching logic directly: cache hit/miss on same/different tensor objects, bounded cache size (1 entry per device), clear_packed_caches(), torch.zeros vs torch.empty+assign equivalence, cached SDPA mask not mutated in-place by scaled_dot_product_attention, and non-packed input returning None. All passed.

Minor notes (non-blocking)

  • clear_packed_caches() is defined but never wired into any lifecycle hook. Fine since each cache is bounded to 1 entry per device, so memory overhead is negligible.
  • clear_packed_caches() does not clear the pre-existing _XFORMERS_MASK_CACHE LRU cache (the 32-entry OrderedDict). Not a problem in practice but worth noting for completeness.
  • No unit tests for cache hit/miss logic in the PR itself. The existing test suite in tests/utils/test_packing.py does not cover caching behavior.

LGTM.

@danielhanchen danielhanchen merged commit 12f8525 into main Mar 12, 2026
4 checks passed
@danielhanchen danielhanchen deleted the dh/recover-4173-packing-cache branch March 12, 2026 10:37
Stanley00 pushed a commit to stanley-fork/unsloth that referenced this pull request Mar 12, 2026
…lothai#4243)

* packing optimziation with cache to reduce D2H copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cache per device to avoid race condition for multi-gpu

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add cache freeing up func

---------

Co-authored-by: ruixiangw <ruixiangw@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ruixiang <wangruixiang07@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants