Fix DPOTrainer collators to truncate sequences before padding#5305
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
I think it's very very unlikely we hit OOM at collation time. Even with In practice, the memory peak is in the forward/backward pass due to the vocab dimension. For example, with a typical vocab_size of 150k and |
|
@qgallouedec, I agree that the OMM argument was indeed exaggerated. The collation overhead is negligible by comparison with the forward/backward pass. That said, there are three narrower motivations that still stand:
|
qgallouedec
left a comment
There was a problem hiding this comment.
Okay, I agree, those are good reasons to move the truncation to the collator.
I remember trying to do that when I refactored DPO in #3906, but for some reason I can't recall, I decided to do it before the forward step instead of in the collator. Sorry, that's not very helpful as feedback. 😅
commit 3972d66 Author: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Wed Mar 18 22:26:44 2026 +0100 Suggest the `Json()` type for tool calling dataset format (#5307) Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> commit 5c6e915 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Mar 18 14:55:19 2026 -0600 Update `RewardFunc` type annotation to allow `None`values in reward list (#5297) commit ee96845 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed Mar 18 17:03:54 2026 +0100 Fix DPOTrainer collators to truncate sequences before padding (#5305) commit 435c2ae Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Mar 18 08:09:42 2026 -0600 Add guidance to avoid `hasattr` and `getattr` with defaults in `AGENTS.md` (#5294) commit 26ce6a3 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Wed Mar 18 00:44:12 2026 -0600 Apply docstyle (#5296) commit 52cd0cc Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue Mar 17 15:31:26 2026 +0100 Fix UNEXPECTED lm_head.weight warning when loading a CausalLM as a reward model (#5295) commit 7b42fc4 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue Mar 17 15:29:11 2026 +0100 Prevent corruption of DPO VLM training if "keep_end" truncation_mode (#5286) commit 3acb8e8 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue Mar 17 15:27:10 2026 +0100 Support max_length in DPO VLM training (#5284) commit ee339a0 Author: Carlos Miguel Patiño <carlos.patino@huggingface.co> Date: Tue Mar 17 14:01:44 2026 +0100 [GKD] Buffer Implementation for Distillation Trainer (#5137) Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> commit d46131f Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon Mar 16 15:27:19 2026 +0100 Remove custom get_train/eval_dataloader from OnlineDPO (#5291) commit 85cf8f4 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon Mar 16 15:24:24 2026 +0100 Remove TrainingArguments import from experimental trainers (#5290) commit 91e3da0 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Mon Mar 16 07:19:51 2026 -0600 Fix `accuracy_reward` crash when called from non-main thread (#5281) commit 4996631 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon Mar 16 07:44:28 2026 +0100 Fix support for model_init_kwargs in MiniLLM when passed as CLI JSON string (#5274) commit 5fceaa7 Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon Mar 16 07:43:34 2026 +0100 Simplify structured outputs logic across vLLM versions in scripts/vllm_serve (#5273) commit 406d406 Author: casinca <47400729+casinca@users.noreply.github.com> Date: Sat Mar 14 04:12:49 2026 +0100 feat(`grpo_trainer.py`): Variational Sequence-Level Soft Policy Optimization (VESPO) (#5199) commit d0ac7ef Author: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> Date: Sat Mar 14 02:53:33 2026 +0100 Allow nullable logprobs in vLLM serve responses (#5203) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com> commit c0eabc4 Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Date: Fri Mar 13 18:19:15 2026 -0600 Change default `vllm_mode` to `"colocate"` and add v0→v1 migration guide (#5255) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> commit 6c0fccd Author: Mario Šaško <mariosasko777@gmail.com> Date: Sat Mar 14 00:19:38 2026 +0100 35% faster packing + rename `bfd-requeue` to `bfd_split` (#5189) Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Fix
DPOTrainercollators to truncate sequences before padding whenmax_lengthis set.This PR adds support for truncating input sequences to a maximum length in both text and vision preference data collators in
DPOTrainer. This helps prevent oversized tensors when handling very long sequences and allows for configurable truncation behavior. The changes also propagate these new options from the trainer arguments to the collators.Fix #5304.
Motivation
Previously,
DataCollatorForPreferenceandDataCollatorForVisionPreferencepadded every batch to the longest raw sequence in the dataset._truncate_inputsthen sliced the padded tensors down tomax_lengthinside the forward pass. This caused the collator to allocate tensors up to N× larger than necessary, where N = max_raw_seq_len / max_length, leading to unexpected OOM errors for users who had setmax_lengthprecisely to stay within their GPU memory budget.Solution
The fix moves truncation to the earliest possible point: right after concatenation in the collator, before any tensor allocation; so padding is always bounded by
max_length._truncate_inputsin the forward pass is left unchanged. It remains the correct truncation path for user-supplied custom collators and acts as a no-op (cheap slice on already-bounded tensors) for the built-in collators.Changes
Sequence truncation support:
max_lengthandtruncation_modeparameters toDataCollatorForPreferenceandDataCollatorForVisionPreference, allowing input sequences to be truncated to a specified maximum length. For text, both"keep_start"and"keep_end"truncation modes are supported; for vision, only"keep_start"is allowed.torch_callmethods to perform truncation on input IDs, attention masks, and related fields according to the specifiedmax_lengthandtruncation_mode.DataCollatorForPreference(text): After buildingprompt_chosen_ids/prompt_rejected_idsas plain Python lists, slice them tomax_lengthusingkeep_start([:max_length]) orkeep_end([-max_length:]) before converting to tensors. The completion masks receive the same slice.chosen_attention_mask/rejected_attention_maskare derived after truncation, so their lengths are always consistent.DataCollatorForVisionPreference(vision): Afterflush_leftassembles the final[2*batch_size, seq_len]tensors, sliceinput_ids,attention_mask,completion_mask, and (where present)token_type_ids/mm_token_type_idsto[:, :max_length]. Onlykeep_startapplies here;keep_endis already rejected upstream with aValueErrorbefore the collator is constructed, because image tokens in the prompt would be silently dropped.pixel_valuesand other spatial image tensors are not touched.Trainer integration:
max_lengthandtruncation_modearguments from the trainer configuration to the appropriate data collator, ensuring the new functionality is available during training.Note
Medium Risk
Changes batching/truncation behavior in
DPOTrainercollators, which can affect training inputs and loss if slicing is incorrect, but scope is limited to data collation and argument plumbing.Overview
Prevents oversized batch tensors in
DPOTrainerwhenmax_lengthis set by truncating sequences inside the collators before padding/tensor allocation.DataCollatorForPreferencenow acceptsmax_lengthandtruncation_mode(keep_start/keep_end) and truncates concatenatedprompt+completionids andcompletion_maskprior to building attention masks and padding.DataCollatorForVisionPreferencenow acceptsmax_lengthand slices the post-flush_lefttensors (input_ids,attention_mask,completion_mask, and optional token-type ids) to[:, :max_length]; trainer wiring passesmax_length/truncation_modeinto the text collator andmax_lengthinto the vision collator (withkeep_endrejected for vision datasets).Written by Cursor Bugbot for commit 99c194e. This will update automatically on new commits. Configure here.