DPOTrainer crashes when max_length is set with VLMs: IndexError #5283

@albertvillanova

Description

DPOTrainer crashes when max_length is set with VLMs.

When max_length is set and the VLM takes sequence-aligned side-inputs (such as mm_token_type_ids), DPOTrainer crashes during the model forward pass with a shape mismatch error:

IndexError: The shape of the mask [37] at index 0 does not match the shape of the indexed tensor [43] at index 0

See related discussion: #5279 (comment)

mm_token_type_ids not truncated when max_length is set

Medium Severity

When max_length is set, _truncate_inputs truncates input_ids, attention_mask, and completion_mask but not mm_token_type_ids. Both compute_ref_log_probs and _compute_loss then pass the truncated input_ids to the model alongside the original full-length mm_token_type_ids from inputs, causing a shape mismatch that crashes the forward pass.

Problem

Stacktrace:

  >       trainer.train()
  
  tests/test_dpo_trainer.py:1239: 
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  .venv/lib/python3.10/site-packages/transformers/trainer.py:1424: in train
      return inner_training_loop(
  .venv/lib/python3.10/site-packages/transformers/trainer.py:1506: in _inner_training_loop
      self._run_epoch(
  .venv/lib/python3.10/site-packages/transformers/trainer.py:1734: in _run_epoch
      tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  trl/trainer/dpo_trainer.py:1449: in training_step
      return super().training_step(*args, **kwargs)
  .venv/lib/python3.10/site-packages/transformers/trainer.py:1906: in training_step
      loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  trl/trainer/dpo_trainer.py:1444: in compute_loss
      return self._compute_loss(model, inputs, return_outputs)
  trl/trainer/dpo_trainer.py:1147: in _compute_loss
      outputs = model(**model_kwargs)
  .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1776: in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
  .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1787: in _call_impl
      return forward_call(*args, **kwargs)
  .venv/lib/python3.10/site-packages/accelerate/utils/operations.py:823: in forward
      return model_forward(*args, **kwargs)
  .venv/lib/python3.10/site-packages/accelerate/utils/operations.py:811: in __call__
      return convert_to_fp32(self.model_forward(*args, **kwargs))
  .venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py:44: in decorate_autocast
      return func(*args, **kwargs)
  .venv/lib/python3.10/site-packages/transformers/utils/generic.py:843: in wrapper
      output = func(self, *args, **kwargs)
  .venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1575: in forward
      outputs = self.model(
  .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1776: in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
  .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1787: in _call_impl
      return forward_call(*args, **kwargs)
  .venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1377: in forward
      position_ids = self.compute_3d_position_ids(
  .venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1292: in compute_3d_position_ids
      position_ids, rope_deltas = self.get_rope_index(
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  
  self = Qwen2_5_VLModel(
    (visual): Qwen2_5_VisionTransformerPretrainedModel(
      (patch_embed): Qwen2_5_VisionPatchEmbed(
    ...e-06)
        )
      )
      (norm): Qwen2_5_VLRMSNorm((16,), eps=1e-06)
      (rotary_emb): Qwen2_5_VLRotaryEmbedding()
    )
  )
  input_ids = tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
           151645,    198, 151644,    8...5,
              198, 151644,  77091,    198,  82440,     13, 151645,    198, 151643,
           151643]], device='cuda:0')
  mm_token_type_ids = tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0... 1, 1, 1, 1, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
         device='cuda:0')
  image_grid_thw = tensor([[1, 6, 4],
          [1, 4, 4],
          [1, 6, 4],
          [1, 4, 4]], device='cuda:0')
  video_grid_thw = None
  second_per_grid_ts = <list_iterator object at 0x7fe92934e6e0>
  attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1..., 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], device='cuda:0')
  kwargs = {}, spatial_merge_size = 2, tokens_per_second = 2
  mrope_position_deltas = []
  position_ids = tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], device='cuda:0')
  grid_iters = {1: <tuple_iterator object at 0x7fe92934f6d0>, 2: None}
  batch_idx = 0
  current_input_ids = tensor([151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
          151645,    198, 151644,    872...1616,
             537,    387,   8036,    518,   1156,     30, 151645,    198, 151644,
           77091], device='cuda:0')
  
      def get_rope_index(
          self,
          input_ids: torch.LongTensor,
          mm_token_type_ids: torch.IntTensor,
          image_grid_thw: torch.LongTensor | None = None,
          video_grid_thw: torch.LongTensor | None = None,
          second_per_grid_ts: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          **kwargs,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          spatial_merge_size = self.config.vision_config.spatial_merge_size
          tokens_per_second = self.config.vision_config.tokens_per_second
      
          mrope_position_deltas = []
          position_ids = torch.zeros(
              3,
              input_ids.shape[0],
              input_ids.shape[1],
              dtype=input_ids.dtype,
              device=input_ids.device,
          )
          grid_iters = {
              1: iter(image_grid_thw) if image_grid_thw is not None else None,
              2: iter(video_grid_thw) if video_grid_thw is not None else None,
          }
          second_per_grid_ts = (
              iter(second_per_grid_ts) if second_per_grid_ts is not None else iter([1] * input_ids.shape[1])
          )
          for batch_idx, current_input_ids in enumerate(input_ids):
              input_token_type = mm_token_type_ids[batch_idx]
              if attention_mask is not None:
                  current_input_ids = current_input_ids[attention_mask[batch_idx].bool()]
  >               input_token_type = input_token_type[attention_mask[batch_idx].bool()]
  E               IndexError: The shape of the mask [37] at index 0 does not match the shape of the indexed tensor [43] at index 0

Root cause

_truncate_inputs truncates input_ids, attention_mask, and completion_mask to max_length. However, in both compute_ref_log_probs and _compute_loss, token_type_ids and mm_token_type_ids are copied from the batch into model_kwargs without truncation. The model then receives input_ids of length max_length alongside side-inputs still at their original full length, causing a shape mismatch in the positional encoding computation.
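
One possible direction, sketched below under stated assumptions, is to truncate every sequence-aligned entry together, so that side-inputs can never drift out of sync with input_ids. The helper name `truncate_sequence_aligned` and the constant `SEQUENCE_ALIGNED_KEYS` are hypothetical and are not part of TRL. Plain nested lists stand in for tensors so the example runs without extra dependencies, and it only cuts from the right, which may differ from what _truncate_inputs actually does.

```python
# Hypothetical sketch, not TRL's implementation: plain lists stand in for tensors.
SEQUENCE_ALIGNED_KEYS = (
    "input_ids",
    "attention_mask",
    "completion_mask",
    "token_type_ids",
    "mm_token_type_ids",
)


def truncate_sequence_aligned(batch, max_length, keys=SEQUENCE_ALIGNED_KEYS):
    """Return a shallow copy of `batch` with each sequence-aligned entry cut to `max_length`."""
    truncated = dict(batch)
    if max_length is None:
        # No truncation requested: leave everything at its original length.
        return truncated
    for key in keys:
        if truncated.get(key) is not None:
            # Cut every row of this entry to the same length as input_ids.
            truncated[key] = [row[:max_length] for row in truncated[key]]
    return truncated


# Shapes chosen to mirror the error message above: 43 tokens, truncated to 37.
batch = {
    "input_ids": [list(range(43))],
    "attention_mask": [[1] * 43],
    "mm_token_type_ids": [[0] * 15 + [1] * 6 + [0] * 22],
    "image_grid_thw": [[1, 6, 4]],  # not sequence-aligned, so it must stay untouched
}
out = truncate_sequence_aligned(batch, max_length=37)
lengths = {k: len(v[0]) for k, v in out.items() if k in SEQUENCE_ALIGNED_KEYS}
print(lengths)
```

Keeping the list of sequence-aligned keys in one place means a newly added side-input only has to be registered once, rather than in both compute_ref_log_probs and _compute_loss.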

Why CI didn't catch it

All existing VLM tests set max_length=None, which makes _truncate_inputs a no-op, so the truncation path is never exercised. Users who set max_length hit this crash even when they carefully pick a value large enough to preserve all image tokens (a valid and documented use case).
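
The coverage gap can be illustrated with a small, self-contained simulation. Both functions below are hypothetical: `buggy_truncate` only mimics the behavior described in this issue (three keys truncated, side-inputs left alone) and is not the TRL code, and `is_aligned` is an illustrative check, with plain lists standing in for tensors.

```python
# Hypothetical simulation of the behavior described in this issue, not TRL code.
def buggy_truncate(batch, max_length):
    """Truncate only input_ids, attention_mask and completion_mask."""
    out = dict(batch)
    if max_length is None:
        # No-op path: the only one exercised when tests set max_length=None.
        return out
    for key in ("input_ids", "attention_mask", "completion_mask"):
        if key in out:
            out[key] = [row[:max_length] for row in out[key]]
    return out


def is_aligned(batch, keys=("input_ids", "attention_mask", "mm_token_type_ids")):
    """Return True if all listed entries share one sequence length."""
    lengths = {len(row) for key in keys if key in batch for row in batch[key]}
    return len(lengths) <= 1


batch = {
    "input_ids": [list(range(43))],
    "attention_mask": [[1] * 43],
    "mm_token_type_ids": [[0] * 43],
}

print(is_aligned(buggy_truncate(batch, None)))  # no-op path: lengths still agree
print(is_aligned(buggy_truncate(batch, 37)))    # truncation path: 37 vs 43
```

A regression test along these lines would need at least one VLM case with a max_length shorter than the tokenized sequence, so that the truncation path actually runs.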
