Skip to content

Fix DPOTrainer collators to truncate sequences before padding#5305

Merged
albertvillanova merged 3 commits into
huggingface:mainfrom
albertvillanova:fix-5304
Mar 18, 2026
Merged

Fix DPOTrainer collators to truncate sequences before padding#5305
albertvillanova merged 3 commits into
huggingface:mainfrom
albertvillanova:fix-5304

Conversation

@albertvillanova

@albertvillanova albertvillanova commented Mar 18, 2026

Copy link
Copy Markdown
Member

Fix DPOTrainer collators to truncate sequences before padding when max_length is set.

This PR adds support for truncating input sequences to a maximum length in both text and vision preference data collators in DPOTrainer. This helps prevent oversized tensors when handling very long sequences and allows for configurable truncation behavior. The changes also propagate these new options from the trainer arguments to the collators.

Fix #5304.

Motivation

Previously, DataCollatorForPreference and DataCollatorForVisionPreference padded every batch to the longest raw sequence in the dataset. _truncate_inputs then sliced the padded tensors down to max_length inside the forward pass. This caused the collator to allocate tensors up to N× larger than necessary, where N = max_raw_seq_len / max_length, leading to unexpected OOM errors for users who had set max_length precisely to stay within their GPU memory budget.

Solution

The fix moves truncation to the earliest possible point: right after concatenation in the collator, before any tensor allocation; so padding is always bounded by max_length.

_truncate_inputs in the forward pass is left unchanged. It remains the correct truncation path for user-supplied custom collators and acts as a no-op (cheap slice on already-bounded tensors) for the built-in collators.

Changes

Sequence truncation support:

  • Added max_length and truncation_mode parameters to DataCollatorForPreference and DataCollatorForVisionPreference, allowing input sequences to be truncated to a specified maximum length. For text, both "keep_start" and "keep_end" truncation modes are supported; for vision, only "keep_start" is allowed.
  • Implemented logic in the collators' torch_call methods to perform truncation on input IDs, attention masks, and related fields according to the specified max_length and truncation_mode.
    • DataCollatorForPreference (text): After building prompt_chosen_ids / prompt_rejected_ids as plain Python lists, slice them to max_length using keep_start ([:max_length]) or keep_end ([-max_length:]) before converting to tensors. The completion masks receive the same slice. chosen_attention_mask / rejected_attention_mask are derived after truncation, so their lengths are always consistent.
    • DataCollatorForVisionPreference (vision): After flush_left assembles the final [2*batch_size, seq_len] tensors, slice input_ids, attention_mask, completion_mask, and (where present) token_type_ids / mm_token_type_ids to [:, :max_length]. Only keep_start applies here; keep_end is already rejected upstream with a ValueError before the collator is constructed, because image tokens in the prompt would be silently dropped. pixel_values and other spatial image tensors are not touched.

Trainer integration:

  • Updated the trainer initialization to pass max_length and truncation_mode arguments from the trainer configuration to the appropriate data collator, ensuring the new functionality is available during training.

Note

Medium Risk
Changes batching/truncation behavior in DPOTrainer collators, which can affect training inputs and loss if slicing is incorrect, but scope is limited to data collation and argument plumbing.

Overview
Prevents oversized batch tensors in DPOTrainer when max_length is set by truncating sequences inside the collators before padding/tensor allocation.

DataCollatorForPreference now accepts max_length and truncation_mode (keep_start/keep_end) and truncates concatenated prompt+completion ids and completion_mask prior to building attention masks and padding.

DataCollatorForVisionPreference now accepts max_length and slices the post-flush_left tensors (input_ids, attention_mask, completion_mask, and optional token-type ids) to [:, :max_length]; trainer wiring passes max_length/truncation_mode into the text collator and max_length into the vision collator (with keep_end rejected for vision datasets).

Written by Cursor Bugbot for commit 99c194e. This will update automatically on new commits. Configure here.

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec

Copy link
Copy Markdown
Member

I think it's very very unlikely we hit OOM at collation time. Even with seq_len=1M and batch_size=8, that’s only ~8M elements: ~30 MB in int32.

In practice, the memory peak is in the forward/backward pass due to the vocab dimension. For example, with a typical vocab_size of 150k and seq_len=512, logits alone are ~1.2 GB in (bf16). To reach 1.2 GB at collation time, you’d need seq_len ≈ 40M tokens, which is unrealistic.

@albertvillanova

albertvillanova commented Mar 18, 2026

Copy link
Copy Markdown
Member Author

@qgallouedec, I agree that the OMM argument was indeed exaggerated. The collation overhead is negligible by comparison with the forward/backward pass.

That said, there are three narrower motivations that still stand:

  1. Consistency with the existing SFT VLM collator: DataCollatorForVisionLanguageModeling in SFT already truncates in the collator. Diverging from that pattern without a technical reason makes the codebase harder to reason about.

    • See:
      # Truncate if necessary
      if self.max_length is not None:
      input_ids = input_ids[:, : self.max_length]
      attention_mask = attention_mask[:, : self.max_length]
      completion_mask = completion_mask[:, : self.max_length]
      if "token_type_ids" in processed_prompts:
      token_type_ids = token_type_ids[:, : self.max_length]
      if "mm_token_type_ids" in processed_prompts:
      mm_token_type_ids = mm_token_type_ids[:, : self.max_length]
  2. Semantic contract of the collator: after the fix, the tensors emitted by the built-in collators are guaranteed to respect max_length. Before the fix, the collator output silently violates max_length until the forward pass corrects it, which is confusing to anyone inspecting intermediate batch shapes or reasoning about memory.

  3. Code simplification: The current keep_end path in _truncate_inputs operates on an already-padded batch tensor and requires a flush_right → slice → flush_left sequence to correctly extract the trailing tokens from each right-padded sequence. Moving truncation to the collator, where the concatenated sequences are still plain Python lists, reduces keep_end to a trivial ids[-max_length:] slice, with no batch-level bookkeeping needed.

    • Please note that what I actually had in mind with this PR is that it lays the groundwork for a follow-up where _truncate_inputs gets removed entirely. Once both built-in collators truncate internally, the only reason _truncate_inputs still exists is as a silent safety net for custom collators: which is arguably worse than no safety net, because it hides the fact that the collator isn't doing its job. The follow-up would make the contract explicit (custom collators must truncate before padding), replace the silent fix-up with a single shape assertion at the system boundary to catch violations loudly and early, and then delete _truncate_inputs along with the flush_right/flush_left batch-tensor dance that lives inside it. Net result: less code in the trainer, no hidden behavior, and the keep_end complexity vanishes because truncating a plain list before padding is just ids[-max_length:]. This PR is a necessary precondition for that: you can't safely remove _truncate_inputs until the built-in collators already handle it themselves.

@qgallouedec qgallouedec left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I agree, those are good reasons to move the truncation to the collator.

I remember trying to do that when I refactored DPO in #3906, but for some reason I can't recall, I decided to do it before the forward step instead of in the collator. Sorry, that's not very helpful as feedback. 😅

@albertvillanova albertvillanova merged commit ee96845 into huggingface:main Mar 18, 2026
12 checks passed
qgallouedec added a commit that referenced this pull request Mar 18, 2026
commit 3972d66
Author: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date:   Wed Mar 18 22:26:44 2026 +0100

    Suggest the `Json()` type for tool calling dataset format (#5307)

    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 5c6e915
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Mar 18 14:55:19 2026 -0600

    Update `RewardFunc` type annotation to allow `None`values in reward list (#5297)

commit ee96845
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Wed Mar 18 17:03:54 2026 +0100

    Fix DPOTrainer collators to truncate sequences before padding (#5305)

commit 435c2ae
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Mar 18 08:09:42 2026 -0600

    Add guidance to avoid `hasattr` and `getattr` with defaults in `AGENTS.md` (#5294)

commit 26ce6a3
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Mar 18 00:44:12 2026 -0600

    Apply docstyle (#5296)

commit 52cd0cc
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:31:26 2026 +0100

    Fix UNEXPECTED lm_head.weight warning when loading a CausalLM as a reward model (#5295)

commit 7b42fc4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:29:11 2026 +0100

    Prevent corruption of DPO VLM training if "keep_end" truncation_mode (#5286)

commit 3acb8e8
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:27:10 2026 +0100

    Support max_length in DPO VLM training (#5284)

commit ee339a0
Author: Carlos Miguel Patiño <carlos.patino@huggingface.co>
Date:   Tue Mar 17 14:01:44 2026 +0100

    [GKD] Buffer Implementation for Distillation Trainer (#5137)

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

commit d46131f
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 15:27:19 2026 +0100

    Remove custom get_train/eval_dataloader from OnlineDPO (#5291)

commit 85cf8f4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 15:24:24 2026 +0100

    Remove TrainingArguments import from experimental trainers (#5290)

commit 91e3da0
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Mar 16 07:19:51 2026 -0600

    Fix `accuracy_reward` crash when called from non-main thread (#5281)

commit 4996631
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 07:44:28 2026 +0100

    Fix support for model_init_kwargs in MiniLLM when passed as CLI JSON string (#5274)

commit 5fceaa7
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 07:43:34 2026 +0100

    Simplify structured outputs logic across vLLM versions in scripts/vllm_serve (#5273)

commit 406d406
Author: casinca <47400729+casinca@users.noreply.github.com>
Date:   Sat Mar 14 04:12:49 2026 +0100

    feat(`grpo_trainer.py`): Variational Sequence-Level Soft Policy Optimization (VESPO) (#5199)

commit d0ac7ef
Author: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Date:   Sat Mar 14 02:53:33 2026 +0100

    Allow nullable logprobs in vLLM serve responses  (#5203)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>

commit c0eabc4
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Mar 13 18:19:15 2026 -0600

    Change default `vllm_mode` to `"colocate"` and add v0→v1 migration guide (#5255)

    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit 6c0fccd
Author: Mario Šaško <mariosasko777@gmail.com>
Date:   Sat Mar 14 00:19:38 2026 +0100

    35% faster packing + rename `bfd-requeue` to `bfd_split` (#5189)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DPO sequences longer than max_length cause over-allocation and potential OOM during collation

3 participants