
DPO sequences longer than max_length cause over-allocation and potential OOM during collation #5304

@albertvillanova

Description

When max_length is set, DPOTrainer truncates sequences inside the forward pass (_truncate_inputs), after the data collator has already padded the batch. Because DataCollatorForPreference and DataCollatorForVisionPreference have no knowledge of max_length, they pad every batch to the length of its longest untruncated sequence.

If the dataset contains sequences longer than max_length, which is the common case for real-world preference datasets, the collator allocates tensors of shape [2 * batch_size, max_raw_seq_len], where max_raw_seq_len is the longest untruncated prompt + completion in the batch. _truncate_inputs then immediately discards every column beyond max_length. The over-allocation factor is max_raw_seq_len / max_length, which can easily reach 4 to 8.
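To make the arithmetic concrete, here is a small sketch that computes the padded shape the collator would allocate for one batch and the resulting over-allocation factor. The helpers padded_shape and over_allocation_factor are hypothetical and only model the lengths described above; they are not part of TRL.

```python
# Sketch (hypothetical helpers, not TRL code): size of the padded batch when the
# collator pads to the longest untruncated sequence, versus what survives truncation.

def padded_shape(prompt_lens, chosen_lens, rejected_lens):
    """Shape of the [2 * batch_size, max_raw_seq_len] tensor the collator allocates."""
    chosen = [p + c for p, c in zip(prompt_lens, chosen_lens)]
    rejected = [p + r for p, r in zip(prompt_lens, rejected_lens)]
    return (2 * len(prompt_lens), max(chosen + rejected))


def over_allocation_factor(shape, max_length):
    """Ratio of allocated columns to the columns kept after truncating to max_length."""
    _, max_raw_seq_len = shape
    return max_raw_seq_len / min(max_raw_seq_len, max_length)


shape = padded_shape(prompt_lens=[300, 500], chosen_lens=[700, 3596], rejected_lens=[400, 2000])
print(shape)                                # (4, 4096)
print(over_allocation_factor(shape, 1024))  # 4.0
```

A single long chosen completion in the batch is enough to make all four rows 4096 columns wide, although only the first 1024 columns are ever used.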

In practice this causes:

  • Significantly higher GPU memory usage than the user expects from setting max_length
  • OOM crashes that would not occur with the equivalent SFTTrainer setup on the same dataset and hardware

Root cause

SFTTrainer stores the full concatenated sequence (input_ids) in the dataset and calls truncate_dataset() at preprocessing time, so the collator never sees a sequence longer than max_length.

DPOTrainer stores sequences as three separate columns (prompt_ids, chosen_ids, rejected_ids). Truncation must be applied to the combined sequences prompt_ids + chosen_ids and prompt_ids + rejected_ids, which only exist after concatenation inside the collator. truncate_dataset() cannot be used directly here: it would truncate each column independently, which is wrong. For example, with a max_length of 1024, a 512-token prompt and an 800-token completion would not be truncated at all, because each column fits within the limit on its own, even though their concatenation is 1312 tokens long.
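The difference between per-column and combined truncation can be sketched in a few lines. Both helpers below are hypothetical illustrations, not TRL functions: truncate_columns_independently models cutting each column on its own, and truncate_pair_keep_start bounds prompt + completion by giving the completion whatever budget the prompt leaves. Keeping the start of the sequence is just one possible truncation policy.

```python
# Hypothetical helpers for illustration only; neither is part of TRL.

def truncate_columns_independently(example, max_length):
    """Per-column truncation: each token list is cut to max_length on its own."""
    return {key: ids[:max_length] for key, ids in example.items()}


def truncate_pair_keep_start(prompt_ids, completion_ids, max_length):
    """Combined truncation: keep the first max_length tokens of prompt + completion."""
    prompt = prompt_ids[:max_length]
    budget = max_length - len(prompt)  # tokens left for the completion (never negative)
    return prompt, completion_ids[:budget]


example = {"prompt_ids": [1] * 512, "chosen_ids": [2] * 800, "rejected_ids": [3] * 300}

per_column = truncate_columns_independently(example, max_length=1024)
print(len(per_column["prompt_ids"]) + len(per_column["chosen_ids"]))  # 1312

prompt, chosen = truncate_pair_keep_start(example["prompt_ids"], example["chosen_ids"], 1024)
print(len(prompt) + len(chosen))  # 1024
```

Per-column truncation leaves the 1312-token pair untouched, while the combined version respects max_length.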

Expected behavior

The padded batch tensor should be bounded by max_length, matching user expectations and the behaviour of SFTTrainer.
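One way to meet this expectation is for the collator itself to truncate right after concatenation and before padding. The sketch below is a hypothetical, pure-Python stand-in (collate_preference is not the real DataCollatorForPreference, and it returns lists instead of tensors); it stacks chosen rows before rejected rows to mirror the [2 * batch_size, ...] layout described above and keeps the start of each sequence.

```python
# Hypothetical collator sketch: concatenate, truncate to max_length, then pad.
# Not TRL's implementation; labels/completion masks are omitted for brevity.

def collate_preference(examples, max_length, pad_token_id=0):
    sequences = []
    for key in ("chosen_ids", "rejected_ids"):  # chosen rows first, then rejected
        for ex in examples:
            seq = ex["prompt_ids"] + ex[key]
            sequences.append(seq[:max_length])  # truncate before padding (keep start)
    width = max(len(s) for s in sequences)  # never exceeds max_length
    input_ids = [s + [pad_token_id] * (width - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (width - len(s)) for s in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


examples = [
    {"prompt_ids": [1] * 512, "chosen_ids": [2] * 800, "rejected_ids": [3] * 300},
    {"prompt_ids": [1] * 10, "chosen_ids": [2] * 20, "rejected_ids": [3] * 5},
]
batch = collate_preference(examples, max_length=1024)
print(len(batch["input_ids"]), len(batch["input_ids"][0]))  # 4 1024
```

Padding to the longest truncated sequence, rather than always to max_length, keeps the benefit of dynamic padding: a batch of short examples stays short, while no batch can grow past max_length.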
