Skip to content

Rollout generation#5299

Closed
AmineDiro wants to merge 80 commits into
mainfrom
rollout-generation
Closed

Rollout generation#5299
AmineDiro wants to merge 80 commits into
mainfrom
rollout-generation

Conversation

@AmineDiro

@AmineDiro AmineDiro commented Mar 18, 2026

Copy link
Copy Markdown
Member

What does this PR do?

Staleness Control (max_staleness)

  • Added a max_staleness to AsyncRolloutWorker. Groups generated from a model version more than max_staleness steps behind the current version are automatically cancelled and dropped.
  • A min-heap (version_heap) tracks the model version of each in-flight group, enabling O(log n) detection of stale groups.
  • When a new model version is pushed, in-flight tasks for stale groups are cancelled via asyncio.CancelledError and their slots are immediately reclaimed.
  • update_model_version now signals an asyncio.Event (_version_updated) so the generation loop can react immediately without polling.

Sequence Length Safety (max_seq_length, max_model_len)

  • Added max_seq_length config option: the maximum allowed length for training. Samples exceeding this are dropped before being enqueued.
  • At worker startup, max_model_len is fetched from the vLLM /v1/models endpoint so the worker always knows the server's context window.
  • Per-turn effective_max_tokens is dynamically clipped to min(max_completion_tokens, max_seq_length - prompt_len, max_model_len - prompt_len), preventing 400 errors from vLLM when prompts are long.
  • Prompts that exceed max_model_len outright are skipped immediately with a logged warning.
  • Both pre-turn and mid-multi-turn length checks are applied, with early stopping when the growing prompt would overflow.
  • Added truncated and num_turns tracking to RolloutGroup and RolloutCompletion, and exposed them as rollout/truncated and rollout/num_turns metrics per sample.

Pause/Resume Correctness

  • _resume_event (an asyncio.Event) gates the dispatch inner loop: when pause() is called, the event is cleared and the generation loop stops dispatching new requests before vLLM is paused — avoiding a race where requests land between the pause signal and the server actually pausing.
  • resume() sets the event after the HTTP resume call, re-enabling dispatch.

Error Handling and Retry Logic

  • _generate_one_turn now has a bounded retry loop (max_generation_retry=10) instead of an infinite loop.
  • Retry logic now distinguishes between transient server errors (5xx — retried) and client errors (4xx — re-raised immediately), avoiding wasteful retries on bad requests.
  • _generate_one now wraps task.result() in a try/except: per-sample exceptions are logged and counted as failures rather than crashing the entire generation loop. Groups where all samples fail are dropped with a warning.
  • Cancelled tasks (stale cancellation) are handled gracefully via asyncio.CancelledError.

Initial Weight Sync via Callback

  • Removed the inline _sync_weight() call from _inner_training_loop. Instead, a new _InitialWeightSyncCallback fires on on_train_begin — after accelerator.prepare() has placed parameters on GPU but before the first training step. This ensures the rollout worker always uses the correct (possibly resumed) policy weights.
  • The rollout worker is started in _inner_training_loop after the sync completes.

Return Type Refactor

  • _generate_one now returns a RolloutCompletion dataclass instead of a bare tuple, improving readability and making it easier to extend with new fields.

MBPP Example

  • Added examples/scripts/async_grpo_mbpp.py: an end-to-end example using AsyncGRPOTrainer on the MBPP coding benchmark with a tool-use environment that runs Python code and checks test cases.
  • Updated async_grpo.py launch command to show multi-GPU (FSDP2 + data-parallel vLLM) configuration.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


Note

Medium Risk
Touches core async rollout generation/scoring and weight-sync sequencing, which can affect training stability and throughput if cancellation/length-clipping logic is wrong. No auth/security changes, but concurrency and integration with vLLM make regressions moderately likely.

Overview
Improves AsyncGRPO rollout generation robustness by adding staleness control (max_staleness) that cancels/drops in-flight groups from old model versions and skips overly stale groups before scoring.

Adds sequence-length safety via new AsyncGRPOConfig.max_seq_length, fetching max_model_len from vLLM, and dynamically clipping per-request max_tokens; overlong prompts/samples are skipped or truncated to avoid vLLM 400s and training-side overflow.

Refactors rollout results to a structured RolloutCompletion/turn record with per-tool timing/failure metrics, improves pause/resume coordination during weight sync, bounds generation retries with better 4xx/5xx handling, and moves initial weight sync to an on_train_begin callback. Also adds an async_grpo_mbpp.py tool-use example and updates the existing async_grpo.py launch instructions.

Written by Cursor Bugbot for commit 674073a. This will update automatically on new commits. Configure here.

Comment thread trl/experimental/async_grpo/async_grpo_trainer.py
- Track completion status with `is_done()` method in environments
- Convert environment methods to JSON schema before passing to vLLM
- Early exit from generation when environment signals completion
- Handle `AttributeError` in response parsing for malformed tool calls
- Update example script with tool-calling iteration limits
Comment thread trl/experimental/async_grpo/async_rollout_worker.py
Introduce structured data classes (`TurnRecord`, `ToolCallRecord`,
`RolloutCompletion`) to track per-turn and per-tool-call metrics
including wall-clock durations. Add helper functions
`_build_completion()` and `_extract_tool_metrics()` to derive completion
objects from turn records and extract per-tool duration and failure
metrics. Refactor `_generate_one()` and `_execute_tool_calls()` to
populate these records and propagate turn-level timing data through
`RolloutGroup` and `RolloutSample`.
Comment thread trl/experimental/async_grpo/async_grpo_trainer.py
Comment thread trl/experimental/async_grpo/async_rollout_worker.py
Comment thread trl/experimental/async_grpo/async_rollout_worker.py
del pending_failures[group_id]
else:
group.queued_at = time.monotonic()
await self._groups_to_score.put(group)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocking await put() can hang on shutdown

Medium Severity

The old put_nowait() retry loop that checked stop_event between attempts was replaced with a bare await self._groups_to_score.put(group). If the scoring queue is full when stop() is called, the score loop exits (stops consuming), and this await put() blocks forever. Since both loops run inside asyncio.gather, the generate loop never finishes, preventing the worker thread from shutting down cleanly.

Fix in Cursor Fix in Web

Replace flat fields (completion, completion_ids, etc.) in
`RolloutCompletion` and `RolloutGroup` with turn-based storage and lazy
accessor methods. Introduce `TaggedMessage` to track which messages are
part of the model's completion vs. context. Simplify `_build_completion`
and remove the now-redundant `append_message` method from `TurnRecord`.

@qgallouedec qgallouedec left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some comments

Comment on lines +65 to +81
class _InitialWeightSyncCallback(TrainerCallback):
"""
Syncs model weights to the vLLM rollout worker once, on train begin.

This fires after FSDP wrapping / ``accelerator.prepare()`` (so parameters are on GPU) but before
the first training step. After the sync the rollout worker is started and this callback removes
itself from the trainer.
"""

def __init__(self, trainer: "AsyncGRPOTrainer"):
self._trainer = trainer

def on_train_begin(self, _args, _state, _control, **_kwargs):
self._trainer._sync_weight()
self._trainer.remove_callback(type(self))


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What the motivation to use a callback instead of simply calling the weight sync in the trainer?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_sync_weight() was called at the top of _inner_training_loop, before super()._inner_training_loop() ran. it's super()._inner_training_loop() that performs FSDP wrapping and accelerator.prepare() (which moves the model to GPU). So weights were still on CPU when NCCL tried to broadcast them.

So I had:

  ┃  [rank0]:   File "/fsx/amine_dirhoussi/trl/.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/pynccl.py", line
  ┃  363, in broadcast
  ┃  [rank0]:     assert tensor.device == self.device, (
  ┃  [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ┃  [rank0]: AssertionError: this nccl communicator is created to work on cuda:0, but the input tensor is on cpu

Comment thread examples/scripts/async_grpo_mbpp.py Outdated
Comment thread examples/scripts/async_grpo_mbpp.py Outdated
Comment thread examples/scripts/async_grpo.py Outdated
Comment on lines +50 to +53
def __init__(self):
self.test_list = []
self.done = False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we can remove this

Suggested change
def __init__(self):
self.test_list = []
self.done = False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's safer because it will fail if we try to step before the reset

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think test_list has to be init to [] if I am not mistaken

os.remove(temp_path)

def is_done(self) -> bool:
return self.done

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because it's missing the docstring, it won't be exposed as a tool to the model. If it's not meant to be exposed, then I'd recommend def _is_done. If it is meant to be exposed as a tool, it requires a docstring.

@AmineDiro AmineDiro Mar 19, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is not a tool function. the _is_done means it it a private function but we cant to signal that it needs to be implemented.
The issue we currently have is that user can't define a stop function to exist the generation loop. For example, when in code environment, we want to exist if the tests success but the current setup keeps rolling until we reach max_turns or token exhaustion

Comment thread examples/scripts/async_grpo_mbpp.py Outdated
Comment thread examples/scripts/async_grpo_mbpp.py Outdated
Comment thread examples/scripts/async_grpo_mbpp.py Outdated
max_seq_length (`int`, *optional*):
Maximum total sequence length (prompt + completion) for training. When set, generation `max_tokens` is
dynamically clipped per-prompt so that `prompt_len + completion_len <= max_seq_length`, and any sample
that still exceeds this limit is dropped before training. If `None`, no training-side length limit is

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if one sequence in the group reaches the limit? The whole group is discarded?

@AmineDiro AmineDiro Mar 19, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_seq_length is only used in _generate_one function to truncate generation or to skip prompt. It is also skipped in scoring loop:

  if self.max_seq_length is not None and seq_len > self.max_seq_length:
      logger.warning(f"Dropping overlong sample (seq_len={seq_len}, max_seq_length={self.max_seq_length})")

@qgallouedec qgallouedec left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some additional comments

Comment thread trl/experimental/async_grpo/async_grpo_trainer.py Outdated

self.tools = base_tools + (
# Pre-convert any bound methods to JSON schema dicts so they pass transformers' `isfunction` check.
[get_json_schema(t) for t in environment_methods[0]] if self.environments is not None else []

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need this change, apply_chat_template will get the json schema for you

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this because I kept getting error in my async_grpo_mbpp.py env. The isfunction check fails as it is a method.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be precise: transformers.utils.chat_template_utils.render_jinja_template uses inspect.isfunction() to validate tools, but isfunction() returns False for bound methods. The environment's execute_python_code is collected as a bound method via inspect.getmembers so I was getting

  ┃    File "/fsx/amine_dirhoussi/trl/trl/experimental/async_grpo/async_rollout_worker.py", line 287, in _run_loops
  ┃      await asyncio.gather(
  ┃    File "/fsx/amine_dirhoussi/trl/trl/experimental/async_grpo/async_rollout_worker.py", line 392, in _generate_loop
  ┃      prompt_ids = self.tokenizer.apply_chat_template(
  ┃                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ┃    File "/fsx/amine_dirhoussi/trl/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 1667, in
  ┃  apply_chat_template
  ┃      rendered_chat, generation_indices = render_jinja_template(
  ┃                                          ^^^^^^^^^^^^^^^^^^^^^^
  ┃    File "/fsx/amine_dirhoussi/trl/.venv/lib/python3.11/site-packages/transformers/utils/chat_template_utils.py", line 493, in
  ┃  render_jinja_template
  ┃      raise ValueError(
  ┃  ValueError: Tools should either be a JSON schema, or a callable function with type hints and a docstring suitable for auto-conversion to a schema.

Maybe I am missing a quick fix or I defined the Env wrongly?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you must upgrade transformers to at least 5.2.0

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird, I am sure I have the correct transformers version in env:

(trl) amine_dirhoussi@login-node-1 trl ±|rollout-generation ✗|→ uv pip list | rg transformers
transformers                             5.2.0

I'll retry this specific parsing function is isolation 👍🏼

except asyncio.QueueFull:
pass

def _cancel_stale_tasks(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you sure it's needed? if you set max_inflight_tasks = max_staleness * per_device_train_batch_size * gradient_accumulation_steps * num_processes (default), then you shouldn't need this in my understanding

@AmineDiro AmineDiro Mar 19, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we still need this to cancel current inflight. max_inflight_tasks will limit fresh requests. If for example there is a generation that takes a long time that gets stale, we need to cancel it. If we don't we'll be wasting vllm's compute resources on a stale requests that will get dropped.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes good point

self.num_generations = num_generations
self.max_inflight_tasks = max_inflight_tasks
self.environments = None
self._is_done_methods = [None] * self.max_inflight_tasks

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly think that we shouldn't add any done signal, see #5235 (comment)

Requiring extra methods unnecessarily narrows the set of valid implementations and reduces substitutability. We should depend on the minimal protocol the class actually needs. There is nothing we can't do without is_done, so I'd recommend simply not having it. Plus at inference, the model generation won't have access to this flag anyway, so it's useless

@AmineDiro AmineDiro Mar 19, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The done was added for this case (cf earlier comment):

The issue we currently have is that user can't define a stop function to exit the generation loop. For example, when in code environment, we want to exist if the tests succeed but the current setup keeps rolling until we reach max_turns or token exhaustion

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I know, but ideally, the model should learn no to continue calling the tool if it gets an error each time. IMO the right way to handle done would be:

class MyEnv:
    def reset(self, **kwargs):
        self._done = False

    def my_func(self):
        """Some nice documentation"""
        if self._done:
            raise Exception("Session expired, you can't use this tool anymore!")
        if some_condition:
            self._done = True
        return

@qgallouedec qgallouedec Mar 19, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the generation stops when the environment is done, then the model never learns to know that the [environment is done] implies that [it can't use it anymore].

@AmineDiro AmineDiro Mar 19, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, maybe I didn't explain well the usage of the is_done in this portion codebase. The current impl does the following:

async def _generate_one(...):
    while True:
        # .... 
        if tool_calls is None or (max_iterations is not None and iteration_num >= max_iterations):
            return ...

Now we can have done() raise an Exception but I am not sure it provides a clean signal to know that this rollout had exception because of some env specific errors vs it genuinely reached end of turn. If we go with Exception route, I would image that we need to define a library side Exception like StopGeneration(Exception) to cleanly catch it and know when env has stopped. But then this is equivalent to defining a done function IMO.

Also having a done env defined tool extends tools context for the LLM which can be a good choice in some specific cases but overall maybe unnecesary ??

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if we don't use the done flag to stop the generation, what's the point of having it? I can see it's used here, but the reason is not clear:

https://github.com/huggingface/trl/pull/5299/changes#diff-175873f8c2363ec8322c4a0af91437e83da991bb5d388f2d01342b2d31eaf8a3R822

it seems like a proxy the a new is_truncated flag, which I don't quite understand either. The truncation is equivalent to "the last token isn't an EOS", why adding this flag here?

But more generally, don't you think it's out of the scope of this PR?

@AmineDiro AmineDiro Mar 19, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if we don't use the done flag to stop the generation, what's the point of having it? I can see it's used here, but the reason is not clear:

To be precise, I put the previous generation loop. The changed one does use it to break : line

 if is_done is not None and is_done():
     return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)

it seems like a proxy the a new is_truncated flag, which I don't quite understand either. The truncation is equivalent to "the last token isn't an EOS", why adding this flag here?

Truncation is separate from done and is mainly for metrics and debugging purposes. It can be computed from "isn't last EOS" but its just a bool and recomputing something that small seems unnecessary ?

But more generally, don't you think it's out of the scope of this PR?

I think there was some feature creep, but this PR was originally intended to change the async-grpo branch directly. All the commits were added to debug the mbpp example: metrics, trajectories, the done flag, etc.

I believe the first async-grpo PR iteration had metrics, so I viewed this as a logical extension, especially since there wasn’t a strictly defined feature scope for this specific PR

AmineDiro and others added 8 commits March 19, 2026 17:06
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
- Rename `tm` to `completion_tool_metrics` for clarity
- Add trajectory metrics comment to clarify the section
- Extract `advantage` into metrics dict for logging
- Remove duplicate `rollout/truncated` and `rollout/num_turns`
  assignments (already set in `traj_metrics`)
Comment thread trl/experimental/async_grpo/async_grpo_trainer.py
Comment thread trl/experimental/async_grpo/async_rollout_worker.py
seq_len = len(group.prompt_ids) + len(completion_ids)
if self.max_seq_length is not None and seq_len > self.max_seq_length:
logger.warning(f"Dropping overlong sample (seq_len={seq_len}, max_seq_length={self.max_seq_length})")
continue

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Advantage normalization precedes sample dropping, breaking GRPO

Medium Severity

In _score_group, advantages are normalized across all completions in the group, then individual samples exceeding max_seq_length are dropped. The surviving samples retain advantage values computed relative to the full group (including dropped samples). This breaks the zero-mean property of GRPO normalization and can bias training—e.g., if the highest-reward sample is dropped, remaining samples all train with distorted advantages.

Fix in Cursor Fix in Web

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 3 total unresolved issues (including 2 from previous reviews).

Fix All in Cursor

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

del pending_completed[group_id]
pending_failures.pop(group_id, None)

return cancelled

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale cancellation of partially-dispatched group causes permanent leak

Medium Severity

When _cancel_stale_tasks drops a group that has only been partially dispatched (fewer than num_generations tasks created so far), the remaining yields from _repeat_iterator for that same group_id will re-create the group via the if group_id not in pending_groups branch, but only the leftover yields produce tasks. The re-created group's pending_completed can never reach num_generations, so it permanently leaks in pending_groups and pending_completed, occupying slots that complete but never trigger group finalization.

Additional Locations (1)
Fix in Cursor Fix in Web

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants