
GRPOTrainer tool_mask can become longer than completion_ids after tool-call retokenization #5144

@MichalMraz

Description


Reproduction

Summary

In `GRPOTrainer`, a tool-call round retokenizes prompt + completion + tool output. The completion portion of that retokenized sequence can end up shorter than the completion produced before the tool call.
In `_tool_call_loop`, `tool_mask` (and `logprobs`, when present) is only ever extended and never truncated. In this case its length therefore diverges from that of `completion_ids`.
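To make the arithmetic concrete, here is a simplified model of the bookkeeping. It uses the token ids from the reproduction script below. It assumes, based on the script's comment about a negative `tool_length`, that the loop computes `tool_length = len(prompt_completion_tool) - len(prompt) - len(completion)` and pads the mask with `[0] * tool_length`. It also assumes the mask starts as all 1s over the completion. The function name is hypothetical, and this is not TRL's actual code. Under these assumptions it yields the same 4-vs-6 lengths printed in the debug output.

```python
def extend_after_tool_round(prompt, completion, tool_mask, prompt_completion_tool, post_tool):
    """Hypothetical, simplified model of one tool round (not TRL's implementation).

    Returns the new completion ids and the new tool mask.
    """
    # Tokens the retokenized sequence adds beyond the original prompt + completion.
    tool_length = len(prompt_completion_tool) - len(prompt) - len(completion)
    # In Python, [0] * n is [] for n <= 0, so a negative tool_length is silently
    # ignored and the mask can only grow.
    new_mask = tool_mask + [0] * tool_length + [1] * len(post_tool)
    # The completion, however, is rebuilt from the retokenized sequence and can shrink.
    new_completion = prompt_completion_tool[len(prompt):] + post_tool
    return new_completion, new_mask


# First sample of the reproduction below (the one that issues the tool call).
prompt = [100, 101]
completion = [900, 901, 902, 903]
tool_mask = [1, 1, 1, 1]                       # assumed: all completion tokens trainable
prompt_completion_tool = [100, 101, 201, 202]  # shorter than prompt + completion
post_tool = [600, 601]

new_completion, new_mask = extend_after_tool_round(
    prompt, completion, tool_mask, prompt_completion_tool, post_tool
)
print(len(new_completion), len(new_mask))  # prints: 4 6
```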

This later fails in `_compute_loss`, where the two masks are multiplied:
mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]
with:
RuntimeError: The size of tensor a (...) must match the size of tensor b (...) at non-singleton dimension 1

I will open a PR with a fix and a minimal regression test that reproduces this path inside `trainer.train()`.
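One possible shape of a fix is sketched below under the same simplified model. It is an assumption about the approach, not the contents of the PR. When `tool_length` is negative, the sketch truncates the existing mask to the length of the retokenized completion prefix before appending the post-tool entries. `logprobs` would need the same truncation. The function name is hypothetical.

```python
def extend_after_tool_round_fixed(prompt, completion, tool_mask, prompt_completion_tool, post_tool):
    """Hypothetical fixed variant of the simplified model: the mask stays aligned with the ids."""
    tool_length = len(prompt_completion_tool) - len(prompt) - len(completion)
    if tool_length < 0:
        # Retokenization shortened the completion: drop trailing mask entries that no
        # longer correspond to any token (logprobs would need the same treatment).
        tool_mask = tool_mask[: len(completion) + tool_length]
        tool_length = 0
    new_mask = tool_mask + [0] * tool_length + [1] * len(post_tool)
    new_completion = prompt_completion_tool[len(prompt):] + post_tool
    # Invariant a regression test can assert on (holds if tool_mask started aligned).
    assert len(new_mask) == len(new_completion)
    return new_completion, new_mask


# Same inputs as the failing sample in the reproduction.
new_completion, new_mask = extend_after_tool_round_fixed(
    [100, 101], [900, 901, 902, 903], [1, 1, 1, 1], [100, 101, 201, 202], [600, 601]
)
print(len(new_completion), len(new_mask))  # prints: 4 4
```

Keeping the leading mask entries assumes the shortened prefix lines up with the start of the old completion. A real fix may handle that alignment differently.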

Reproduction

#!/usr/bin/env python
"""
Reproduce the GRPO tool-mask shape mismatch *inside* `GRPOTrainer.train()`.

This script intentionally forces a tool-call round where:
- `completion_ids` is effectively shortened by re-tokenization logic, but
- `tool_mask` is only extended (not truncated),
which leads to a mismatch at:
`trl/trainer/grpo_trainer.py`, `_compute_loss`, line with
`completion_mask * inputs["tool_mask"]`.
"""

from __future__ import annotations

import types
from typing import Any

import torch
import torch.distributed.fsdp as fsdp


# Compatibility shims for this environment.
if not hasattr(fsdp, "FSDPModule"):
    fsdp.FSDPModule = fsdp.FullyShardedDataParallel
if not hasattr(torch.backends.mps, "is_macos_or_newer"):
    torch.backends.mps.is_macos_or_newer = lambda major, minor: False

from datasets import load_dataset

import trl.trainer.grpo_trainer as grpo_mod
from trl import GRPOConfig, GRPOTrainer


def multiply_tool(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        Product of `a` and `b`.
    """
    return a * b


def constant_reward(completions: list[Any], **kwargs) -> list[float]:
    del kwargs
    return [0.0] * len(completions)


def main() -> None:
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train[:3]")

    args = GRPOConfig(
        output_dir="/tmp/trl_bug_repro_train",
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=128,
        max_steps=1,
        logging_steps=1,
        report_to="none",
        use_cpu=True,
    )

    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
        reward_funcs=constant_reward,
        args=args,
        train_dataset=dataset,
        tools=[multiply_tool],
    )

    # Keep references so we can restore global/module state.
    original_generate_single_turn = trainer._generate_single_turn
    original_parse_response = grpo_mod.parse_response
    original_tool_call_loop = trainer._tool_call_loop
    original_get_per_token_logps_and_entropies = trainer._get_per_token_logps_and_entropies

    state = {"turn": 0}

    def fake_generate_single_turn(self, prompts):
        del prompts
        state["turn"] += 1
        if state["turn"] == 1:
            # First generation batch (size 3).
            prompt_ids = [[100, 101], [110, 111], [120, 121]]
            completion_ids = [
                [900, 901, 902, 903],  # will be interpreted as a tool call
                [700, 701, 702],
                [710, 711, 712],
            ]
            return prompt_ids, completion_ids, None, {}
        if state["turn"] == 2:
            # Tool-round generation for the single tool-calling sample.
            # Intentionally short prompt+completion+tool sequence to induce
            # negative `tool_length` in `_tool_call_loop`.
            prompt_ids = [[100, 101, 201, 202]]
            completion_ids = [[600, 601]]
            return prompt_ids, completion_ids, None, {}
        raise RuntimeError(f"Unexpected fake generation turn: {state['turn']}")

    def fake_parse_response(processing_class, ids):
        del processing_class
        if ids and ids[0] == 900:
            return {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "multiply_tool",
                            "arguments": {"a": 3, "b": 4},
                        },
                    }
                ],
            }
        return {"role": "assistant", "content": "ok"}

    def wrapped_tool_call_loop(*args, **kwargs):
        out = original_tool_call_loop(*args, **kwargs)
        tool_mask, _completions, completion_ids, _logprobs, *_ = out
        print("DEBUG completion lengths:", [len(x) for x in completion_ids])
        print("DEBUG tool_mask lengths :", [len(x) for x in tool_mask])
        return out

    def fake_get_per_token_logps_and_entropies(
        self,
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        batch_size=None,
        compute_entropy=False,
        **kwargs,
    ):
        # Return synthetic tensors to keep this script focused on mask-shape behavior.
        del model, attention_mask, batch_size, kwargs
        batch = input_ids.shape[0]
        shape = (batch, logits_to_keep)
        logps = torch.zeros(shape, dtype=torch.float32, device=input_ids.device, requires_grad=True)
        entropies = torch.zeros(shape, dtype=torch.float32, device=input_ids.device) if compute_entropy else None
        return logps, entropies

    trainer._generate_single_turn = types.MethodType(fake_generate_single_turn, trainer)
    trainer._tool_call_loop = wrapped_tool_call_loop
    trainer._get_per_token_logps_and_entropies = types.MethodType(fake_get_per_token_logps_and_entropies, trainer)
    grpo_mod.parse_response = fake_parse_response

    try:
        trainer.train()
        print("\nTraining finished without tool_mask/completion_mask shape mismatch.")
    finally:
        trainer._generate_single_turn = original_generate_single_turn
        trainer._tool_call_loop = original_tool_call_loop
        trainer._get_per_token_logps_and_entropies = original_get_per_token_logps_and_entropies
        grpo_mod.parse_response = original_parse_response


if __name__ == "__main__":
    main()

Output:

DEBUG completion lengths: [4, 3, 3]
DEBUG tool_mask lengths : [6, 3, 3]
Traceback (most recent call last):
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/bug_replication_train.py", line 165, in <module>
    main()
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/bug_replication_train.py", line 155, in main
    trainer.train()
  File "/Users/michalmraz/opt/anaconda3/envs/mmraz_trl_venv/lib/python3.12/site-packages/transformers/trainer.py", line 2174, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/opt/anaconda3/envs/mmraz_trl_venv/lib/python3.12/site-packages/transformers/trainer.py", line 2536, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/trainer/grpo_trainer.py", line 1088, in training_step
    output = super().training_step(model, inputs, num_items_in_batch)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/opt/anaconda3/envs/mmraz_trl_venv/lib/python3.12/site-packages/transformers/trainer.py", line 3809, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/extras/profiling.py", line 202, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/trainer/grpo_trainer.py", line 2002, in compute_loss
    return self._compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/trainer/grpo_trainer.py", line 2033, in _compute_loss
    mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]
                                                             ~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (4) must match the size of tensor b (6) at non-singleton dimension 1
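The final error comes from PyTorch's NumPy-style broadcasting. Shapes are compared dimension by dimension from the right, and two sizes are compatible only if they are equal or one of them is 1. For the three samples here, the padded `completion_mask` is 4 wide and `tool_mask` is 6 wide, so dimension 1 cannot broadcast. Below is a pure-Python sketch of that rule, not PyTorch's implementation. Its error text is modeled on the traceback above.

```python
def broadcast_shape(shape_a, shape_b):
    """Sketch of the NumPy/PyTorch broadcasting rule for two shapes."""
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise RuntimeError(
                f"The size of tensor a ({x}) must match the size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(result)


print(broadcast_shape((3, 4), (3, 1)))  # prints: (3, 4)
try:
    broadcast_shape((3, 4), (3, 6))  # padded completion_mask vs padded tool_mask
except RuntimeError as err:
    print(err)
```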

System Info

  • Platform: macOS-15.7.4-arm64-arm-64bit
  • Python version: 3.12.11
  • TRL version: 0.29.0.dev0+f0dd05f
  • PyTorch version: 2.6.0
  • accelerator(s): MPS
  • Transformers version: 5.0.0
  • Accelerate version: 1.12.0
  • Accelerate config: not found
  • Datasets version: 4.5.0
  • HF Hub version: 1.4.1
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: not installed
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
