Reproduction
Summary
In GRPOTrainer, a tool-call round can retokenize prompt+completion+tool so that the resulting completion portion is shorter than the previous completion length.
In _tool_call_loop, tool_mask (and optionally logprobs) was extended but not truncated in this case, so its length diverged from that of completion_ids.
This later fails in _compute_loss at:
mask = completion_mask * inputs["tool_mask"]
with:
RuntimeError: The size of tensor a (...) must match the size of tensor b (...) at non-singleton dimension 1
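The "non-singleton dimension" wording comes from the broadcasting rule for elementwise operations: shapes are aligned from the right, and each pair of sizes must be equal unless one of them is 1. A minimal pure-Python sketch of that rule follows; `broadcast_conflicts` is a hypothetical helper written for illustration, not part of TRL or PyTorch.

```python
def broadcast_conflicts(shape_a, shape_b):
    """Return the dimensions at which two shapes cannot be broadcast together.

    Hypothetical helper for illustration only; it mirrors the usual
    broadcasting rule (sizes must match unless one of them is 1).
    """
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both shapes align from the right.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    return [dim for dim, (x, y) in enumerate(zip(a, b)) if x != y and x != 1 and y != 1]


# completion_mask padded to 4 tokens vs. tool_mask padded to 6 tokens.
print(broadcast_conflicts((3, 4), (3, 6)))  # [1]
# A size-1 dimension would broadcast without conflict.
print(broadcast_conflicts((3, 4), (3, 1)))  # []
```

With a batch of 3 and padded lengths 4 and 6, the only conflict is at dimension 1, which is the dimension named in the error.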
I will open a PR with a bugfix and a minimal regression test that reproduces this path inside trainer.train().
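One way the missing truncation step could look is sketched below on plain Python lists. This is an assumption about the shape of a fix, not the actual patch: `truncate_to_completion` is a hypothetical helper, the mask values are made up, and whether dropping the tail is the right semantic choice for `tool_mask` and `logprobs` is left to the PR.

```python
def truncate_to_completion(completion_ids, tool_mask, logprobs=None):
    """Cut each per-sample side list down to the length of its completion.

    Hypothetical helper: it only truncates; entries that are already short
    enough are returned unchanged.
    """
    new_mask = [mask[: len(ids)] for ids, mask in zip(completion_ids, tool_mask)]
    new_logprobs = None
    if logprobs is not None:
        new_logprobs = [lp[: len(ids)] for ids, lp in zip(completion_ids, logprobs)]
    return new_mask, new_logprobs


# Lengths taken from the failing run: completions [4, 3, 3], tool_mask [6, 3, 3].
completion_ids = [[900, 901, 902, 903], [700, 701, 702], [710, 711, 712]]
tool_mask = [[1, 1, 1, 1, 0, 0], [1, 1, 1], [1, 1, 1]]  # mask values are illustrative

new_mask, new_logprobs = truncate_to_completion(completion_ids, tool_mask)
print([len(m) for m in new_mask])  # [4, 3, 3]
```

After truncation the per-sample lengths agree again, so padding both lists into tensors would yield matching shapes.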
Reproduction
#!/usr/bin/env python
"""
Reproduce the GRPO tool-mask shape mismatch *inside* `GRPOTrainer.train()`.

This script intentionally forces a tool-call round where:
- `completion_ids` is effectively shortened by re-tokenization logic, but
- `tool_mask` is only extended (not truncated),

which leads to a mismatch at:
`trl/trainer/grpo_trainer.py`, `_compute_loss`, line with
`completion_mask * inputs["tool_mask"]`.
"""

from __future__ import annotations

import types
from typing import Any

import torch
import torch.distributed.fsdp as fsdp

# Compatibility shims for this environment.
if not hasattr(fsdp, "FSDPModule"):
    fsdp.FSDPModule = fsdp.FullyShardedDataParallel
if not hasattr(torch.backends.mps, "is_macos_or_newer"):
    torch.backends.mps.is_macos_or_newer = lambda major, minor: False

from datasets import load_dataset

import trl.trainer.grpo_trainer as grpo_mod
from trl import GRPOConfig, GRPOTrainer


def multiply_tool(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        Product of `a` and `b`.
    """
    return a * b


def constant_reward(completions: list[Any], **kwargs) -> list[float]:
    del kwargs
    return [0.0] * len(completions)


def main() -> None:
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train[:3]")

    args = GRPOConfig(
        output_dir="/tmp/trl_bug_repro_train",
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=128,
        max_steps=1,
        logging_steps=1,
        report_to="none",
        use_cpu=True,
    )

    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
        reward_funcs=constant_reward,
        args=args,
        train_dataset=dataset,
        tools=[multiply_tool],
    )

    # Keep references so we can restore global/module state.
    original_generate_single_turn = trainer._generate_single_turn
    original_parse_response = grpo_mod.parse_response
    original_tool_call_loop = trainer._tool_call_loop
    original_get_per_token_logps_and_entropies = trainer._get_per_token_logps_and_entropies

    state = {"turn": 0}

    def fake_generate_single_turn(self, prompts):
        del prompts
        state["turn"] += 1
        if state["turn"] == 1:
            # First generation batch (size 3).
            prompt_ids = [[100, 101], [110, 111], [120, 121]]
            completion_ids = [
                [900, 901, 902, 903],  # will be interpreted as a tool call
                [700, 701, 702],
                [710, 711, 712],
            ]
            return prompt_ids, completion_ids, None, {}
        if state["turn"] == 2:
            # Tool-round generation for the single tool-calling sample.
            # Intentionally short prompt+completion+tool sequence to induce
            # negative `tool_length` in `_tool_call_loop`.
            prompt_ids = [[100, 101, 201, 202]]
            completion_ids = [[600, 601]]
            return prompt_ids, completion_ids, None, {}
        raise RuntimeError(f"Unexpected fake generation turn: {state['turn']}")

    def fake_parse_response(processing_class, ids):
        del processing_class
        if ids and ids[0] == 900:
            return {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "multiply_tool",
                            "arguments": {"a": 3, "b": 4},
                        },
                    }
                ],
            }
        return {"role": "assistant", "content": "ok"}

    def wrapped_tool_call_loop(*args, **kwargs):
        out = original_tool_call_loop(*args, **kwargs)
        tool_mask, _completions, completion_ids, _logprobs, *_ = out
        print("DEBUG completion lengths:", [len(x) for x in completion_ids])
        print("DEBUG tool_mask lengths :", [len(x) for x in tool_mask])
        return out

    def fake_get_per_token_logps_and_entropies(
        self,
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        batch_size=None,
        compute_entropy=False,
        **kwargs,
    ):
        # Return synthetic tensors to keep this script focused on mask-shape behavior.
        del model, attention_mask, batch_size, kwargs
        batch = input_ids.shape[0]
        shape = (batch, logits_to_keep)
        logps = torch.zeros(shape, dtype=torch.float32, device=input_ids.device, requires_grad=True)
        entropies = torch.zeros(shape, dtype=torch.float32, device=input_ids.device) if compute_entropy else None
        return logps, entropies

    trainer._generate_single_turn = types.MethodType(fake_generate_single_turn, trainer)
    trainer._tool_call_loop = wrapped_tool_call_loop
    trainer._get_per_token_logps_and_entropies = types.MethodType(fake_get_per_token_logps_and_entropies, trainer)
    grpo_mod.parse_response = fake_parse_response

    try:
        trainer.train()
        print("\nTraining finished without tool_mask/completion_mask shape mismatch.")
    finally:
        trainer._generate_single_turn = original_generate_single_turn
        trainer._tool_call_loop = original_tool_call_loop
        trainer._get_per_token_logps_and_entropies = original_get_per_token_logps_and_entropies
        grpo_mod.parse_response = original_parse_response


if __name__ == "__main__":
    main()
outputs:
DEBUG completion lengths: [4, 3, 3]
DEBUG tool_mask lengths : [6, 3, 3]
Traceback (most recent call last):
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/bug_replication_train.py", line 165, in <module>
    main()
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/bug_replication_train.py", line 155, in main
    trainer.train()
  File "/Users/michalmraz/opt/anaconda3/envs/mmraz_trl_venv/lib/python3.12/site-packages/transformers/trainer.py", line 2174, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/opt/anaconda3/envs/mmraz_trl_venv/lib/python3.12/site-packages/transformers/trainer.py", line 2536, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/trainer/grpo_trainer.py", line 1088, in training_step
    output = super().training_step(model, inputs, num_items_in_batch)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/opt/anaconda3/envs/mmraz_trl_venv/lib/python3.12/site-packages/transformers/trainer.py", line 3809, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/extras/profiling.py", line 202, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/trainer/grpo_trainer.py", line 2002, in compute_loss
    return self._compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/michalmraz/code/mmraz-trl/mmraz-trl/trl/trainer/grpo_trainer.py", line 2033, in _compute_loss
    mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]
                                                             ~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (4) must match the size of tensor b (6) at non-singleton dimension 1
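The DEBUG lengths printed before the traceback already show the divergence, well before the masks are multiplied. A regression test could check that invariant directly on the per-sample lists. The sketch below assumes such a check; `check_mask_lengths` is a hypothetical helper, not an existing TRL function.

```python
def check_mask_lengths(completion_ids, tool_mask):
    """Raise a descriptive error if any sample's tool_mask length differs
    from its completion length (hypothetical helper for illustration)."""
    mismatched = [
        (i, len(ids), len(mask))
        for i, (ids, mask) in enumerate(zip(completion_ids, tool_mask))
        if len(ids) != len(mask)
    ]
    if mismatched:
        details = ", ".join(
            f"sample {i}: completion={c}, tool_mask={m}" for i, c, m in mismatched
        )
        raise ValueError(f"tool_mask length mismatch ({details})")


# Dummy lists with the lengths from the DEBUG output: [4, 3, 3] vs. [6, 3, 3].
completion_ids = [[0] * 4, [0] * 3, [0] * 3]
tool_mask = [[1] * 6, [1] * 3, [1] * 3]

try:
    check_mask_lengths(completion_ids, tool_mask)
except ValueError as err:
    message = str(err)
    print(message)
```

Failing at this point names the offending sample, which the broadcast error raised later in `_compute_loss` does not.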
System Info
- Platform: macOS-15.7.4-arm64-arm-64bit
- Python version: 3.12.11
- TRL version: 0.29.0.dev0+f0dd05f
- PyTorch version: 2.6.0
- accelerator(s): MPS
- Transformers version: 5.0.0
- Accelerate version: 1.12.0
- Accelerate config: not found
- Datasets version: 4.5.0
- HF Hub version: 1.4.1
- bitsandbytes version: not installed
- DeepSpeed version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: not installed
- vLLM version: not installed
Checklist