Skip to content

Cache packed sequence metadata to reduce D2H syncs across layers#4173

Closed
ruixiang63 wants to merge 3447 commits into
unslothai:mainfrom
ruixiang63:pakcing_cache_optimization
Closed

Cache packed sequence metadata to reduce D2H syncs across layers#4173
ruixiang63 wants to merge 3447 commits into
unslothai:mainfrom
ruixiang63:pakcing_cache_optimization

Conversation

@ruixiang63

@ruixiang63 ruixiang63 commented Mar 6, 2026

Copy link
Copy Markdown
Contributor

This PR has been merged through another PR: #4243

Added per-forward-pass caching to eliminate redundant D2H copies and cudaStreamSynchronize calls across layers.

When packing (padding-free) is enabled, three functions are called on every layer of the model during the forward pass:

  • get_packed_info_from_kwargs: calls lengths.max().item() — triggers D2H copy + sync
  • build_sdpa_packed_attention_mask (SDPA backend): calls seq_lengths.sum().item() and seq_lengths.tolist() — triggers 2 D2H copies + syncs
  • build_xformers_block_causal_mask (XFormers backend): calls seq_lengths.to("cpu") — triggers D2H copy + sync

For a model with N layers, this results in N unnecessary D2H synchronizations per function, even though the packed sequence metadata (seq_lengths) is identical across all layers within the same forward pass.

Solution

Cache the output of each function using Python object identity (is) comparison on the seq_lengths tensor. Since the same seq_lengths tensor object is passed to all layers within a single forward pass, subsequent layers hit the cache and skip the D2H operations entirely. A new batch produces a new seq_lengths tensor object, which naturally invalidates the cache.

This reduces D2H synchronizations per forward pass:

Function Before After
get_packed_info_from_kwargs N 1
build_sdpa_packed_attention_mask 2N 2
build_xformers_block_causal_mask N 1

Nsys profiling traces:

  • Without this PR:
image
  • With this PR:
image

Performance

With the caching strategy, CudaStreamSync only appears in the first layer and in the following layers it disappears. We achieve around 43.3% speedup for forward, 5.8% speedup for backward, 14.3% speedup for each batch for Qwen3 14B QLoRA SFT.

Fizza-Mukhtar and others added 30 commits December 28, 2025 21:23
* Guard optional trl.experimental.openenv usage in RL patches

* Simplify optional trl.openenv import handling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…3790)

* Fix is_contiguous() method call and remove duplicate imports

- Fix bug in rope_embedding.py where is_contiguous was used without
  parentheses, causing the method object (always truthy) to be evaluated
  instead of calling the method. This fixes issue unslothai#3781 where fast rope
  backpropagation was broken for zero strided/non-contiguous tensors.

- Remove duplicate `import torch` in rl.py (lines 20 and 25)
- Remove duplicate `import functools` and `import types` in vision.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix Boolean value of Tensor ambiguity error in mistral.py

Replace `or` operator with explicit `is None` check when getting
n_items from kwargs. The `or` operator fails when the value is a
Tensor because Python cannot determine the boolean value of a
multi-element tensor.

Fixes unslothai#3766

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Update rope_embedding.py

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…lothai#3794)

Add "corda" as an allowed value for the init_lora_weights parameter
in FastLanguageModel.get_peft_model() and FastBaseModel.get_peft_model().

This enables users to use CorDA (Correlation-aware Decomposed Adaptation)
initialization from PEFT, which provides an alternative LoRA initialization
strategy for improved finetuning performance.

Fixes unslothai#3693

Signed-off-by: majiayu000 <1835304752@qq.com>
…lothai#3811)

* Fix correctness bugs in rl.py, rl_replacements.py, and vision.py

1. rl_replacements.py (lines 864, 870): Fixed undefined `nanmin`/`nanmax`
   functions by using `.nan_to_num(nan=inf/-inf).min()/.max()` pattern.
   PyTorch doesn't have torch.nanmin/nanmax, so we replace NaN values
   before computing min/max.

2. vision.py (line 150): Fixed bug where code checked for "input" key
   but then accessed kwargs["input_ids"] instead of kwargs["input"].

3. vision.py (line 159): Fixed bug where literal string "key" was used
   instead of the variable `key` when accessing kwargs.

4. rl.py (lines 903, 905): Fixed non-existent `MathError` exception
   by replacing with `ValueError`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1. cohere.py:347-348 - Fixed wrong variable names in QK normalization.
   Used `Q`/`K` but variables were named `Qn`/`Kn`. This caused NameError
   when `use_qk_norm=True` (e.g., c4ai-command-r-plus models).

2. cohere.py:482 - Fixed wrong object reference in inference loop.
   Used `self.mlp` but should be `decoder_layer.mlp` since we're
   iterating through decoder layers. Caused AttributeError during inference.

3. falcon_h1.py:459,461 - Fixed wrong attribute names in inference path.
   Used `post_attention_layernorm` and `mlp` but Falcon H1 uses
   `pre_ff_layernorm` and `feed_forward`. Caused AttributeError during generation.

4. qwen3_moe.py:210 - Fixed wrong module path with incorrect capitalization.
   Used `transformers.models.Qwen3Moe` but should be `transformers.models.qwen3_moe`.
   Caused AttributeError when patching rotary embeddings.

5. qwen3_moe.py:239 - Fixed wrong model_patcher class.
   Used `FastQwen3Model` but should be `FastQwen3MoeModel` for MoE models.
   Caused incorrect patching for Qwen3 MoE models.

6. hf_hub.py:21-22 - Fixed floor division and missing return for billion values.
   Used `//` instead of `/` for millions, and had no return for values >= 1B.
   Caused incorrect formatting and None return for large numbers.

7. save.py:550 - Fixed self-assignment that did nothing.
   `sharded_ram_usage = sharded_ram_usage` should be `= max_shard_size`.
   Caused integer shard sizes to be ignored.

8. rl.py:562-567 - Fixed orphan string not included in length_check.
   The elif branch for max_seq_length validation was a standalone string
   expression, not concatenated to length_check. Caused silent skip of
   the max_seq_length > model_max_seq_length warning.

9. granite.py:49-52 - Fixed wrong model name and version in error message.
   Said "Gemma2" and "4.42.3" but should be "Granite" and "4.45.0".
…tmul

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
FIX: weight tying for LoRA embeddings and lm_head
Gemma3 models have a large vocabulary (262144 tokens) which causes
training loss to explode when using int8 embedding quantization.

This fix auto-detects Gemma3 models and switches from int8-int4
(phone-deployment) to int4 weight-only QAT for stable training.
…lity

Fix Gemma3 QAT training instability with int8-int4 scheme
When users load a model with fast_inference=False but then try to use
vLLM-style arguments with fast_generate, they previously got confusing
errors. This adds a wrapper that detects common mistakes and provides
helpful guidance:

- Using sampling_params: explains to use HF generate args instead
- Using lora_request: explains LoRA weights are already merged
- Passing text strings: shows how to tokenize input first

Changes:
- Add make_fast_generate_wrapper to _utils.py
- Apply wrapper in llama.py when fast_inference=False
- Apply wrapper in vision.py when fast_inference=False
…apper-helpful-errors

Add helpful error messages for fast_generate when fast_inference=False
…curl

Make llama.cpp CURL dependency optional when building from source
rolandtannous and others added 9 commits March 3, 2026 06:34
…support (unslothai#4138)

* fix: update GGUF save paths to use ~/.unsloth/llama.cpp with Windows support

* fix: quote LLAMA_CPP_DEFAULT_DIR in fallback shell commands to handle paths with spaces

* refactor: deduplicate platform-specific build instructions in quantization error message

* chore: remove accidentally committed PR description file

* Fix import safety and f-string bugs in save.py

- H4: Add defensive try/except for LLAMA_CPP_DEFAULT_DIR and IS_WINDOWS imports
  with fallback defaults, so save.py works even if zoo PR unslothai#526 is not merged yet
- H5: Fix Kaggle error path using plain "Error: {e}" instead of f"Error: {e}",
  so the actual exception is shown to users

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixup mapper issues and resolve properly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix broken wandb import crashing unsloth startup

When wandb is installed but broken (e.g., wandb < 0.19.11 with
protobuf >= 6.0), the import chain unsloth -> trl -> transformers ->
is_wandb_available() -> import wandb crashes with:

  ImportError: cannot import name 'Imports' from
  'wandb.proto.wandb_telemetry_pb2'

This happens because transformers' is_wandb_available() has no
try/except around `import wandb`. The error propagates up and kills
`from unsloth import FastLanguageModel` even though wandb is optional.

Add disable_broken_wandb() following the same pattern as
disable_torchcodec_if_broken(). It proactively tries importing wandb
during early init, and if the import fails, patches
is_wandb_available() to return False and sets WANDB_DISABLED=true.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…slothai#4148)

trl/trainer/callbacks.py imports is_wandb_available from
accelerate.utils, not from transformers. The original fix in unslothai#4147
only patched the transformers version, so `from trl import GRPOTrainer`
still crashed via the callbacks.py -> accelerate -> wandb path.

Must patch both the source module (accelerate.utils.imports) AND the
re-export namespace (accelerate.utils) since Python's
`from accelerate.utils import X` reads from the latter, which holds
its own cached reference.
@ruixiang63 ruixiang63 requested a review from mmathew23 as a code owner March 6, 2026 21:23
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a caching mechanism to optimize the performance of models utilizing packed sequences. By intelligently caching metadata related to sequence lengths and attention masks, it drastically reduces the number of costly Device-to-Host synchronizations that previously occurred in every layer of a multi-layer model. This optimization leads to substantial speedups in both forward and backward passes, improving overall training and inference efficiency.

Highlights

  • D2H Synchronization Reduction: Implemented per-forward-pass caching for get_packed_info_from_kwargs, build_sdpa_packed_attention_mask, and build_xformers_block_causal_mask to eliminate redundant Device-to-Host (D2H) copies and cudaStreamSynchronize calls across model layers.
  • Caching Strategy: Caches are invalidated naturally by new seq_lengths tensor objects for each new batch, ensuring correctness while maximizing cache hits for subsequent layers within the same forward pass by using Python object identity (is) comparison.
  • Performance Improvement: Achieved significant speedups, including approximately 43.3% for forward pass, 5.8% for backward pass, and 14.3% per batch for Qwen3 14B QLoRA SFT.
Changelog
  • unsloth/utils/packing.py
    • Introduced global cache dictionaries for packed info, SDPA masks, and XFormers block masks.
    • Modified get_packed_info_from_kwargs to check and utilize the _PACKED_INFO_CACHE, and to store results.
    • Updated cu_seqlens initialization from torch.empty to torch.zeros in get_packed_info_from_kwargs.
    • Modified build_xformers_block_causal_mask to incorporate caching logic using _XFORMERS_BLOCK_MASK_CACHE.
    • Modified build_sdpa_packed_attention_mask to integrate caching logic using _SDPA_MASK_CACHE.
Activity
  • Pull request created by ichbinhandsome.
  • Initial code changes implemented to introduce caching for packed sequence metadata.
  • Performance benchmarks provided in the PR description demonstrating significant speedups.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ab927312f7

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/utils/packing.py Outdated

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a caching mechanism to optimize performance by reducing D2H synchronizations. However, the current implementation has critical thread-safety issues due to the use of global state without proper synchronization, which can lead to race conditions and data corruption in concurrent environments. Specifically, cache updates are not atomic, and one caching function ignores the target device, potentially causing crashes in multi-GPU or pipeline-parallel configurations. It is crucial to address these issues by ensuring thread safety and including all relevant parameters in the cache keys, possibly by making cache updates atomic through storing entries as tuples.

Comment thread unsloth/utils/packing.py Outdated
Comment thread unsloth/utils/packing.py Outdated
Comment thread unsloth/utils/packing.py Outdated
sstamenk and others added 9 commits March 7, 2026 01:33
* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
)

* Fix gpt temporary patch for grpo to happen after compile

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for trl 0.28 and above

Remove sync/reload weights calls , remove vllm.LLM instantiation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor loss computation to include completion_mask

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for trl 0.28 and above

Remove sync/reload weights calls , remove vllm.LLM instantiation

* patch rpc in openenv for newer trl

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pluesclues <136766175+pluesclues@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.4 → v0.15.5](astral-sh/ruff-pre-commit@v0.15.4...v0.15.5)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@ruixiang63 ruixiang63 force-pushed the pakcing_cache_optimization branch from 43ace06 to 85cebe7 Compare March 11, 2026 19:07

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2711dbc81a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/utils/packing.py
@ruixiang63

Copy link
Copy Markdown
Contributor Author

The change from this PR was merged through another PR due to rebase issues: #4243
Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.