When max_length is set and a VLM that returns sequence-aligned side-inputs is used, DPOTrainer crashes during the model forward pass with a shape mismatch error:
> trainer.train()
tests/test_dpo_trainer.py:1239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv/lib/python3.10/site-packages/transformers/trainer.py:1424: in train
return inner_training_loop(
.venv/lib/python3.10/site-packages/transformers/trainer.py:1506: in _inner_training_loop
self._run_epoch(
.venv/lib/python3.10/site-packages/transformers/trainer.py:1734: in _run_epoch
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
trl/trainer/dpo_trainer.py:1449: in training_step
return super().training_step(*args, **kwargs)
.venv/lib/python3.10/site-packages/transformers/trainer.py:1906: in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
trl/trainer/dpo_trainer.py:1444: in compute_loss
return self._compute_loss(model, inputs, return_outputs)
trl/trainer/dpo_trainer.py:1147: in _compute_loss
outputs = model(**model_kwargs)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1776: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1787: in _call_impl
return forward_call(*args, **kwargs)
.venv/lib/python3.10/site-packages/accelerate/utils/operations.py:823: in forward
return model_forward(*args, **kwargs)
.venv/lib/python3.10/site-packages/accelerate/utils/operations.py:811: in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py:44: in decorate_autocast
return func(*args, **kwargs)
.venv/lib/python3.10/site-packages/transformers/utils/generic.py:843: in wrapper
output = func(self, *args, **kwargs)
.venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1575: in forward
outputs = self.model(
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1776: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1787: in _call_impl
return forward_call(*args, **kwargs)
.venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1377: in forward
position_ids = self.compute_3d_position_ids(
.venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1292: in compute_3d_position_ids
position_ids, rope_deltas = self.get_rope_index(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Qwen2_5_VLModel(
(visual): Qwen2_5_VisionTransformerPretrainedModel(
(patch_embed): Qwen2_5_VisionPatchEmbed(
...e-06)
)
)
(norm): Qwen2_5_VLRMSNorm((16,), eps=1e-06)
(rotary_emb): Qwen2_5_VLRotaryEmbedding()
)
)
input_ids = tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13,
151645, 198, 151644, 8...5,
198, 151644, 77091, 198, 82440, 13, 151645, 198, 151643,
151643]], device='cuda:0')
mm_token_type_ids = tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0... 1, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
device='cuda:0')
image_grid_thw = tensor([[1, 6, 4],
[1, 4, 4],
[1, 6, 4],
[1, 4, 4]], device='cuda:0')
video_grid_thw = None
second_per_grid_ts = <list_iterator object at 0x7fe92934e6e0>
attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1..., 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], device='cuda:0')
kwargs = {}, spatial_merge_size = 2, tokens_per_second = 2
mrope_position_deltas = []
position_ids = tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], device='cuda:0')
grid_iters = {1: <tuple_iterator object at 0x7fe92934f6d0>, 2: None}
batch_idx = 0
current_input_ids = tensor([151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13,
151645, 198, 151644, 872...1616,
537, 387, 8036, 518, 1156, 30, 151645, 198, 151644,
77091], device='cuda:0')
def get_rope_index(
self,
input_ids: torch.LongTensor,
mm_token_type_ids: torch.IntTensor,
image_grid_thw: torch.LongTensor | None = None,
video_grid_thw: torch.LongTensor | None = None,
second_per_grid_ts: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
spatial_merge_size = self.config.vision_config.spatial_merge_size
tokens_per_second = self.config.vision_config.tokens_per_second
mrope_position_deltas = []
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
grid_iters = {
1: iter(image_grid_thw) if image_grid_thw is not None else None,
2: iter(video_grid_thw) if video_grid_thw is not None else None,
}
second_per_grid_ts = (
iter(second_per_grid_ts) if second_per_grid_ts is not None else iter([1] * input_ids.shape[1])
)
for batch_idx, current_input_ids in enumerate(input_ids):
input_token_type = mm_token_type_ids[batch_idx]
if attention_mask is not None:
current_input_ids = current_input_ids[attention_mask[batch_idx].bool()]
> input_token_type = input_token_type[attention_mask[batch_idx].bool()]
E IndexError: The shape of the mask [37] at index 0 does not match the shape of the indexed tensor [43] at index 0
DPOTrainer crashes when `max_length` is set with VLMs.
When `max_length` is set and a VLM that returns sequence-aligned side-inputs is used, DPOTrainer crashes during the model forward pass with a shape mismatch error:
See related discussion: #5279 (comment)
Problem
Stacktrace:
Root cause
`_truncate_inputs` truncates `input_ids`, `attention_mask`, and `completion_mask` to `max_length`. However, in both `compute_ref_log_probs` and `_compute_loss`, `token_type_ids` and `mm_token_type_ids` are copied from the batch into `model_kwargs` without truncation. The model then receives `input_ids` of length `max_length` alongside side-inputs still at their original full length, causing a shape mismatch in the positional encoding computation.
Why CI didn't catch it
All existing VLM tests set `max_length=None`, which makes `_truncate_inputs` a no-op and never exercises the truncation path. Users who carefully pick a `max_length` large enough to preserve all image tokens (a valid and documented use case) hit this crash.