Rollout generation#5299
Conversation
Add experimental AsyncGRPOTrainer implementation with asynchronous generation and scoring loops using vLLM's weight transfer engine. Includes example script, metrics computation, and type safety improvements.
- Track completion status with `is_done()` method in environments - Convert environment methods to JSON schema before passing to vLLM - Early exit from generation when environment signals completion - Handle `AttributeError` in response parsing for malformed tool calls - Update example script with tool-calling iteration limits
Introduce structured data classes (`TurnRecord`, `ToolCallRecord`, `RolloutCompletion`) to track per-turn and per-tool-call metrics including wall-clock durations. Add helper functions `_build_completion()` and `_extract_tool_metrics()` to derive completion objects from turn records and extract per-tool duration and failure metrics. Refactor `_generate_one()` and `_execute_tool_calls()` to populate these records and propagate turn-level timing data through `RolloutGroup` and `RolloutSample`.
| del pending_failures[group_id] | ||
| else: | ||
| group.queued_at = time.monotonic() | ||
| await self._groups_to_score.put(group) |
There was a problem hiding this comment.
Blocking await put() can hang on shutdown
Medium Severity
The old put_nowait() retry loop that checked stop_event between attempts was replaced with a bare await self._groups_to_score.put(group). If the scoring queue is full when stop() is called, the score loop exits (stops consuming), and this await put() blocks forever. Since both loops run inside asyncio.gather, the generate loop never finishes, preventing the worker thread from shutting down cleanly.
Replace flat fields (completion, completion_ids, etc.) in `RolloutCompletion` and `RolloutGroup` with turn-based storage and lazy accessor methods. Introduce `TaggedMessage` to track which messages are part of the model's completion vs. context. Simplify `_build_completion` and remove the now-redundant `append_message` method from `TurnRecord`.
| class _InitialWeightSyncCallback(TrainerCallback): | ||
| """ | ||
| Syncs model weights to the vLLM rollout worker once, on train begin. | ||
|
|
||
| This fires after FSDP wrapping / ``accelerator.prepare()`` (so parameters are on GPU) but before | ||
| the first training step. After the sync the rollout worker is started and this callback removes | ||
| itself from the trainer. | ||
| """ | ||
|
|
||
| def __init__(self, trainer: "AsyncGRPOTrainer"): | ||
| self._trainer = trainer | ||
|
|
||
| def on_train_begin(self, _args, _state, _control, **_kwargs): | ||
| self._trainer._sync_weight() | ||
| self._trainer.remove_callback(type(self)) | ||
|
|
||
|
|
There was a problem hiding this comment.
What the motivation to use a callback instead of simply calling the weight sync in the trainer?
There was a problem hiding this comment.
_sync_weight() was called at the top of _inner_training_loop, before super()._inner_training_loop() ran. it's super()._inner_training_loop() that performs FSDP wrapping and accelerator.prepare() (which moves the model to GPU). So weights were still on CPU when NCCL tried to broadcast them.
So I had:
┃ [rank0]: File "/fsx/amine_dirhoussi/trl/.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/pynccl.py", line
┃ 363, in broadcast
┃ [rank0]: assert tensor.device == self.device, (
┃ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
┃ [rank0]: AssertionError: this nccl communicator is created to work on cuda:0, but the input tensor is on cpu| def __init__(self): | ||
| self.test_list = [] | ||
| self.done = False | ||
|
|
There was a problem hiding this comment.
IMO we can remove this
| def __init__(self): | |
| self.test_list = [] | |
| self.done = False |
There was a problem hiding this comment.
it's safer because it will fail if we try to step before the reset
There was a problem hiding this comment.
I think test_list has to be init to [] if I am not mistaken
| os.remove(temp_path) | ||
|
|
||
| def is_done(self) -> bool: | ||
| return self.done |
There was a problem hiding this comment.
because it's missing the docstring, it won't be exposed as a tool to the model. If it's not meant to be exposed, then I'd recommend def _is_done. If it is meant to be exposed as a tool, it requires a docstring.
There was a problem hiding this comment.
it is not a tool function. the _is_done means it it a private function but we cant to signal that it needs to be implemented.
The issue we currently have is that user can't define a stop function to exist the generation loop. For example, when in code environment, we want to exist if the tests success but the current setup keeps rolling until we reach max_turns or token exhaustion
| max_seq_length (`int`, *optional*): | ||
| Maximum total sequence length (prompt + completion) for training. When set, generation `max_tokens` is | ||
| dynamically clipped per-prompt so that `prompt_len + completion_len <= max_seq_length`, and any sample | ||
| that still exceeds this limit is dropped before training. If `None`, no training-side length limit is |
There was a problem hiding this comment.
what if one sequence in the group reaches the limit? The whole group is discarded?
There was a problem hiding this comment.
max_seq_length is only used in _generate_one function to truncate generation or to skip prompt. It is also skipped in scoring loop:
if self.max_seq_length is not None and seq_len > self.max_seq_length:
logger.warning(f"Dropping overlong sample (seq_len={seq_len}, max_seq_length={self.max_seq_length})")|
|
||
| self.tools = base_tools + ( | ||
| # Pre-convert any bound methods to JSON schema dicts so they pass transformers' `isfunction` check. | ||
| [get_json_schema(t) for t in environment_methods[0]] if self.environments is not None else [] |
There was a problem hiding this comment.
you don't need this change, apply_chat_template will get the json schema for you
There was a problem hiding this comment.
I added this because I kept getting error in my async_grpo_mbpp.py env. The isfunction check fails as it is a method.
There was a problem hiding this comment.
To be precise: transformers.utils.chat_template_utils.render_jinja_template uses inspect.isfunction() to validate tools, but isfunction() returns False for bound methods. The environment's execute_python_code is collected as a bound method via inspect.getmembers so I was getting
┃ File "/fsx/amine_dirhoussi/trl/trl/experimental/async_grpo/async_rollout_worker.py", line 287, in _run_loops
┃ await asyncio.gather(
┃ File "/fsx/amine_dirhoussi/trl/trl/experimental/async_grpo/async_rollout_worker.py", line 392, in _generate_loop
┃ prompt_ids = self.tokenizer.apply_chat_template(
┃ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
┃ File "/fsx/amine_dirhoussi/trl/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 1667, in
┃ apply_chat_template
┃ rendered_chat, generation_indices = render_jinja_template(
┃ ^^^^^^^^^^^^^^^^^^^^^^
┃ File "/fsx/amine_dirhoussi/trl/.venv/lib/python3.11/site-packages/transformers/utils/chat_template_utils.py", line 493, in
┃ render_jinja_template
┃ raise ValueError(
┃ ValueError: Tools should either be a JSON schema, or a callable function with type hints and a docstring suitable for auto-conversion to a schema.Maybe I am missing a quick fix or I defined the Env wrongly?
There was a problem hiding this comment.
you must upgrade transformers to at least 5.2.0
There was a problem hiding this comment.
Weird, I am sure I have the correct transformers version in env:
(trl) amine_dirhoussi@login-node-1 trl ±|rollout-generation ✗|→ uv pip list | rg transformers
transformers 5.2.0I'll retry this specific parsing function is isolation 👍🏼
| except asyncio.QueueFull: | ||
| pass | ||
|
|
||
| def _cancel_stale_tasks( |
There was a problem hiding this comment.
are you sure it's needed? if you set max_inflight_tasks = max_staleness * per_device_train_batch_size * gradient_accumulation_steps * num_processes (default), then you shouldn't need this in my understanding
There was a problem hiding this comment.
Yes we still need this to cancel current inflight. max_inflight_tasks will limit fresh requests. If for example there is a generation that takes a long time that gets stale, we need to cancel it. If we don't we'll be wasting vllm's compute resources on a stale requests that will get dropped.
| self.num_generations = num_generations | ||
| self.max_inflight_tasks = max_inflight_tasks | ||
| self.environments = None | ||
| self._is_done_methods = [None] * self.max_inflight_tasks |
There was a problem hiding this comment.
I strongly think that we shouldn't add any done signal, see #5235 (comment)
Requiring extra methods unnecessarily narrows the set of valid implementations and reduces substitutability. We should depend on the minimal protocol the class actually needs. There is nothing we can't do without is_done, so I'd recommend simply not having it. Plus at inference, the model generation won't have access to this flag anyway, so it's useless
There was a problem hiding this comment.
The done was added for this case (cf earlier comment):
The issue we currently have is that user can't define a stop function to exit the generation loop. For example, when in code environment, we want to exist if the tests succeed but the current setup keeps rolling until we reach max_turns or token exhaustion
There was a problem hiding this comment.
yes I know, but ideally, the model should learn no to continue calling the tool if it gets an error each time. IMO the right way to handle done would be:
class MyEnv:
def reset(self, **kwargs):
self._done = False
def my_func(self):
"""Some nice documentation"""
if self._done:
raise Exception("Session expired, you can't use this tool anymore!")
if some_condition:
self._done = True
returnThere was a problem hiding this comment.
If the generation stops when the environment is done, then the model never learns to know that the [environment is done] implies that [it can't use it anymore].
There was a problem hiding this comment.
Sorry, maybe I didn't explain well the usage of the is_done in this portion codebase. The current impl does the following:
async def _generate_one(...):
while True:
# ....
if tool_calls is None or (max_iterations is not None and iteration_num >= max_iterations):
return ...Now we can have done() raise an Exception but I am not sure it provides a clean signal to know that this rollout had exception because of some env specific errors vs it genuinely reached end of turn. If we go with Exception route, I would image that we need to define a library side Exception like StopGeneration(Exception) to cleanly catch it and know when env has stopped. But then this is equivalent to defining a done function IMO.
Also having a done env defined tool extends tools context for the LLM which can be a good choice in some specific cases but overall maybe unnecesary ??
There was a problem hiding this comment.
So if we don't use the done flag to stop the generation, what's the point of having it? I can see it's used here, but the reason is not clear:
it seems like a proxy the a new is_truncated flag, which I don't quite understand either. The truncation is equivalent to "the last token isn't an EOS", why adding this flag here?
But more generally, don't you think it's out of the scope of this PR?
There was a problem hiding this comment.
So if we don't use the done flag to stop the generation, what's the point of having it? I can see it's used here, but the reason is not clear:
To be precise, I put the previous generation loop. The changed one does use it to break : line
if is_done is not None and is_done():
return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)it seems like a proxy the a new is_truncated flag, which I don't quite understand either. The truncation is equivalent to "the last token isn't an EOS", why adding this flag here?
Truncation is separate from done and is mainly for metrics and debugging purposes. It can be computed from "isn't last EOS" but its just a bool and recomputing something that small seems unnecessary ?
But more generally, don't you think it's out of the scope of this PR?
I think there was some feature creep, but this PR was originally intended to change the async-grpo branch directly. All the commits were added to debug the mbpp example: metrics, trajectories, the done flag, etc.
I believe the first async-grpo PR iteration had metrics, so I viewed this as a logical extension, especially since there wasn’t a strictly defined feature scope for this specific PR
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
- Rename `tm` to `completion_tool_metrics` for clarity - Add trajectory metrics comment to clarify the section - Extract `advantage` into metrics dict for logging - Remove duplicate `rollout/truncated` and `rollout/num_turns` assignments (already set in `traj_metrics`)
| seq_len = len(group.prompt_ids) + len(completion_ids) | ||
| if self.max_seq_length is not None and seq_len > self.max_seq_length: | ||
| logger.warning(f"Dropping overlong sample (seq_len={seq_len}, max_seq_length={self.max_seq_length})") | ||
| continue |
There was a problem hiding this comment.
Advantage normalization precedes sample dropping, breaking GRPO
Medium Severity
In _score_group, advantages are normalized across all completions in the group, then individual samples exceeding max_seq_length are dropped. The surviving samples retain advantage values computed relative to the full group (including dropped samples). This breaks the zero-mean property of GRPO normalization and can bias training—e.g., if the highest-reward sample is dropped, remaining samples all train with distorted advantages.
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 3 total unresolved issues (including 2 from previous reviews).
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
| del pending_completed[group_id] | ||
| pending_failures.pop(group_id, None) | ||
|
|
||
| return cancelled |
There was a problem hiding this comment.
Stale cancellation of partially-dispatched group causes permanent leak
Medium Severity
When _cancel_stale_tasks drops a group that has only been partially dispatched (fewer than num_generations tasks created so far), the remaining yields from _repeat_iterator for that same group_id will re-create the group via the if group_id not in pending_groups branch, but only the leftover yields produce tasks. The re-created group's pending_completed can never reach num_generations, so it permanently leaks in pending_groups and pending_completed, occupying slots that complete but never trigger group finalization.


What does this PR do?
Staleness Control (
max_staleness)max_stalenesstoAsyncRolloutWorker. Groups generated from a model version more thanmax_stalenesssteps behind the current version are automatically cancelled and dropped.version_heap) tracks the model version of each in-flight group, enabling O(log n) detection of stale groups.asyncio.CancelledErrorand their slots are immediately reclaimed.update_model_versionnow signals anasyncio.Event(_version_updated) so the generation loop can react immediately without polling.Sequence Length Safety (
max_seq_length,max_model_len)max_seq_lengthconfig option: the maximum allowed length for training. Samples exceeding this are dropped before being enqueued.max_model_lenis fetched from the vLLM/v1/modelsendpoint so the worker always knows the server's context window.effective_max_tokensis dynamically clipped tomin(max_completion_tokens, max_seq_length - prompt_len, max_model_len - prompt_len), preventing 400 errors from vLLM when prompts are long.max_model_lenoutright are skipped immediately with a logged warning.truncatedandnum_turnstracking toRolloutGroupandRolloutCompletion, and exposed them asrollout/truncatedandrollout/num_turnsmetrics per sample.Pause/Resume Correctness
_resume_event(anasyncio.Event) gates the dispatch inner loop: whenpause()is called, the event is cleared and the generation loop stops dispatching new requests before vLLM is paused — avoiding a race where requests land between the pause signal and the server actually pausing.resume()sets the event after the HTTP resume call, re-enabling dispatch.Error Handling and Retry Logic
_generate_one_turnnow has a bounded retry loop (max_generation_retry=10) instead of an infinite loop._generate_onenow wrapstask.result()in a try/except: per-sample exceptions are logged and counted as failures rather than crashing the entire generation loop. Groups where all samples fail are dropped with a warning.asyncio.CancelledError.Initial Weight Sync via Callback
_sync_weight()call from_inner_training_loop. Instead, a new_InitialWeightSyncCallbackfires onon_train_begin— afteraccelerator.prepare()has placed parameters on GPU but before the first training step. This ensures the rollout worker always uses the correct (possibly resumed) policy weights._inner_training_loopafter the sync completes.Return Type Refactor
_generate_onenow returns aRolloutCompletiondataclass instead of a bare tuple, improving readability and making it easier to extend with new fields.MBPP Example
examples/scripts/async_grpo_mbpp.py: an end-to-end example usingAsyncGRPOTraineron the MBPP coding benchmark with a tool-use environment that runs Python code and checks test cases.async_grpo.pylaunch command to show multi-GPU (FSDP2 + data-parallel vLLM) configuration.Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Note
Medium Risk
Touches core async rollout generation/scoring and weight-sync sequencing, which can affect training stability and throughput if cancellation/length-clipping logic is wrong. No auth/security changes, but concurrency and integration with vLLM make regressions moderately likely.
Overview
Improves
AsyncGRPOrollout generation robustness by adding staleness control (max_staleness) that cancels/drops in-flight groups from old model versions and skips overly stale groups before scoring.Adds sequence-length safety via new
AsyncGRPOConfig.max_seq_length, fetchingmax_model_lenfrom vLLM, and dynamically clipping per-requestmax_tokens; overlong prompts/samples are skipped or truncated to avoid vLLM 400s and training-side overflow.Refactors rollout results to a structured
RolloutCompletion/turn record with per-tool timing/failure metrics, improves pause/resume coordination during weight sync, bounds generation retries with better 4xx/5xx handling, and moves initial weight sync to anon_train_begincallback. Also adds anasync_grpo_mbpp.pytool-use example and updates the existingasync_grpo.pylaunch instructions.Written by Cursor Bugbot for commit 674073a. This will update automatically on new commits. Configure here.