Description
When max_length is set, DPOTrainer truncates sequences inside the forward pass (_truncate_inputs), after the data collator has already padded the batch. Because DataCollatorForPreference and DataCollatorForVisionPreference have no knowledge of max_length, they pad every batch to the longest raw sequence in that batch.
If a batch contains a sequence longer than max_length (the common case for real-world preference datasets), the collator allocates tensors of shape [2 * batch_size, max_raw_seq_len], where max_raw_seq_len is the longest raw sequence in the batch, only for _truncate_inputs to immediately discard everything beyond column max_length. The wasted factor equals max_raw_seq_len / max_length, which can easily reach 4x to 8x.
In practice this causes:
- Significantly higher GPU memory usage than the user expects from setting max_length
- OOM crashes that would not occur with the equivalent SFTTrainer setup on the same dataset and hardware
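To make the wasted factor concrete, here is a small back-of-the-envelope calculation. The batch size and sequence lengths are hypothetical numbers chosen only for illustration, and padded_elements is a helper defined here, not part of TRL.

```python
def padded_elements(batch_size: int, seq_len: int) -> int:
    """Token slots in a [2 * batch_size, seq_len] preference batch (chosen + rejected)."""
    return 2 * batch_size * seq_len


# Hypothetical values, for illustration only.
batch_size = 8
max_length = 1024
max_raw_seq_len = 6144  # longest prompt + completion in the batch

# What the collator allocates today versus what survives truncation.
allocated = padded_elements(batch_size, max_raw_seq_len)
kept = padded_elements(batch_size, min(max_raw_seq_len, max_length))
wasted_factor = allocated / kept

print(f"allocated={allocated} kept={kept} wasted_factor={wasted_factor:.1f}x")
```

This counts only the input_ids slots; attention masks and any other padded tensors of the same shape scale by the same factor.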
Root cause
SFTTrainer stores the full concatenated sequence (input_ids) in the dataset and calls truncate_dataset() at preprocessing time, so the collator never sees a sequence longer than max_length.
DPOTrainer stores sequences as three separate columns (prompt_ids, chosen_ids, rejected_ids). Truncation must happen on the combined sequence prompt_ids + chosen_ids (and likewise prompt_ids + rejected_ids), which only exists after concatenation inside the collator. truncate_dataset() cannot be used directly here: it would truncate each column independently, which is wrong. For example, with max_length = 1024, a 512-token prompt and an 800-token completion are each shorter than 1024, so neither column would be truncated, yet the combined sequence is 1,312 tokens long.
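The difference between per-column and combined truncation can be sketched in plain Python. Both functions below are hypothetical helpers written for this illustration: the first models the per-column behaviour this issue attributes to truncate_dataset(), and is not the library's code; the second assumes truncation keeps the start of the sequence.

```python
def truncate_columns_independently(example, max_length):
    # Models per-column truncation: every column is cut on its own.
    return {key: ids[:max_length] for key, ids in example.items()}


def truncate_combined(example, max_length):
    # Truncates the concatenated prompt + completion sequences instead.
    prompt = example["prompt_ids"]
    return {
        "chosen_input_ids": (prompt + example["chosen_ids"])[:max_length],
        "rejected_input_ids": (prompt + example["rejected_ids"])[:max_length],
    }


# The example from the text: a 512-token prompt and an 800-token chosen completion.
example = {
    "prompt_ids": list(range(512)),
    "chosen_ids": list(range(800)),
    "rejected_ids": list(range(300)),
}
max_length = 1024

independent = truncate_columns_independently(example, max_length)
# No column exceeds max_length, so nothing was cut and the pair is still too long.
independent_total = len(independent["prompt_ids"]) + len(independent["chosen_ids"])

combined = truncate_combined(example, max_length)
chosen_len = len(combined["chosen_input_ids"])
rejected_len = len(combined["rejected_input_ids"])

print(independent_total, chosen_len, rejected_len)
```

Only the combined variant bounds what the collator would later have to pad.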
Expected behavior
The padded batch tensor should be bounded by max_length, matching user expectations and the behavior of SFTTrainer.
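One way to get this behavior is to truncate each concatenated sequence before padding. The following is a minimal sketch in plain Python lists, not the actual DataCollatorForPreference implementation: collate_with_max_length and its parameters are hypothetical, and it assumes right padding, keep-start truncation, and chosen rows stacked before rejected rows.

```python
def collate_with_max_length(examples, pad_token_id, max_length=None):
    """Concatenate, truncate to max_length, then pad to the longest remaining row."""
    sequences = []
    # Chosen rows first, then rejected rows: 2 * batch_size rows in total.
    for key in ("chosen_ids", "rejected_ids"):
        for ex in examples:
            seq = ex["prompt_ids"] + ex[key]
            if max_length is not None:
                seq = seq[:max_length]  # truncate BEFORE padding
            sequences.append(seq)

    width = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_token_id] * (width - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


examples = [
    {"prompt_ids": [1, 2, 3], "chosen_ids": [4] * 10, "rejected_ids": [5] * 2},
    {"prompt_ids": [1], "chosen_ids": [4], "rejected_ids": [5, 5]},
]

bounded = collate_with_max_length(examples, pad_token_id=0, max_length=8)
unbounded = collate_with_max_length(examples, pad_token_id=0)

print(len(bounded["input_ids"][0]), len(unbounded["input_ids"][0]))
```

With max_length set, the padded width never exceeds max_length; without it, the width follows the longest raw sequence in the batch, which is the situation described above.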