Reproduction
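The code example in the docs for "GRPO with replay buffer" has several bugs:
- It imports GRPOWithReplayBufferTrainer but never uses it (the trainer is built with GRPOTrainer instead).
- It uses GRPOWithReplayBufferConfig but never imports it.
- The code is not executable as written: torch and GRPOTrainer are also never imported, and output_dir=self.tmp_dir refers to self outside of any class.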
Below is the code example given in the doc:
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
from datasets import load_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
# Guarantee that some rewards have 0 std
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # simulate some None rewards
    else:
        return torch.rand(len(completions)).tolist()
training_args = GRPOWithReplayBufferConfig(
    output_dir=self.tmp_dir,
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
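The import problems listed above can be confirmed without installing TRL or running the snippet, by parsing it with Python's standard ast module. The following is a minimal sketch, not part of TRL: check_names is a hypothetical helper, it ignores scoping rules (any name bound anywhere counts as defined), and doc_example is a condensed copy of the snippet quoted above.

```python
import ast
import builtins


def check_names(source: str) -> tuple[set[str], set[str]]:
    """Return (imported_but_unused, used_but_never_bound) for a script.

    Hypothetical helper for illustration; scope-insensitive by design.
    """
    tree = ast.parse(source)  # parses only, nothing is imported or executed
    imported, defined, used = set(), set(), set()

    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # "import a.b" binds "a"; "import a as x" binds "x"
                imported.add((alias.asname or alias.name).split(".")[0])
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            defined.add(node.name)
        elif isinstance(node, ast.arg):
            defined.add(node.arg)  # function parameters, including **kwargs
        elif isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                used.add(node.id)
            else:
                defined.add(node.id)  # assignment targets, loop variables

    known = imported | defined | set(dir(builtins))
    return imported - used, used - known


# Condensed copy of the doc example quoted in this issue.
doc_example = '''
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)
    return torch.rand(len(completions)).tolist()

training_args = GRPOWithReplayBufferConfig(output_dir=self.tmp_dir, replay_buffer_size=8)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
'''

unused, undefined = check_names(doc_example)
print(sorted(unused))     # ['GRPOWithReplayBufferTrainer']
print(sorted(undefined))  # ['GRPOTrainer', 'GRPOWithReplayBufferConfig', 'self', 'torch']
```

The first list matches the "imported but never used" bullet, and the second contains the "used but never imported" config class along with the other unbound names that make the snippet fail at runtime.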
System Info
NA
Checklist