Add GenerateFnInput compatibility shim in sglang_rollout#1008
Add GenerateFnInput compatibility shim in sglang_rollout#1008
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a new script for fully-async agentic training of GLM-4.7-Flash using SWE-bench data and updates the rollout logic to support a new single-argument signature for custom generation functions. The review feedback suggests improving the robustness of function signature detection by checking argument counts, warns about the broad process termination patterns in the cleanup function, and recommends retrieving model parameters from configuration instead of using hardcoded values.
| sig = inspect.signature(custom_generate_func) | ||
| params = list(sig.parameters.values()) | ||
| # Support GenerateFnInput-style generate functions (single-arg with typed input) | ||
| if len(params) == 1 and params[0].annotation is GenerateFnInput: |
There was a problem hiding this comment.
The check params[0].annotation is GenerateFnInput is fragile. If the module defining the custom function uses from __future__ import annotations, the annotation will be a string, causing this check to fail. Since the legacy signatures use 3 or 4 arguments, checking for len(params) == 1 is a more reliable way to identify the new single-argument signature.
| if len(params) == 1 and params[0].annotation is GenerateFnInput: | |
| if len(params) == 1: |
| targets = ["sglang", "train.py", "train_async.py", "MegatronTrain"] | ||
| exclude = f"grep -v '^{my_pid}$' | grep -v '^{ppid}$'" | ||
| for t in targets: | ||
| subprocess.run( | ||
| f"pgrep -f '{t}' | {exclude} | xargs -r kill 2>/dev/null || true", | ||
| shell=True, | ||
| ) |
There was a problem hiding this comment.
Using pgrep -f with broad patterns like 'sglang' or 'train.py' in the cleanup function can be risky on shared systems, as it might match and kill unrelated processes (e.g., a text editor with a file open that matches the pattern). Consider using more specific patterns or restricting the cleanup to processes owned by the current user.
|
|
||
| def prepare(args: ScriptArgs): | ||
| """Convert HF checkpoint to torch_dist format.""" | ||
| max_convert_nodes = 92 // args.num_gpus_per_node |
There was a problem hiding this comment.
The value 92 appears to be a hardcoded parameter. Model parameters should be retrieved from the model configuration rather than being hardcoded to ensure consistency and maintainability. Additionally, this value does not seem to be a multiple of the default num_gpus_per_node (8).
References
- Model parameters should be retrieved from the model configuration rather than being hardcoded.
10e7d47 to
165855c
Compare
Lets generate_and_rm dispatch to the new single-argument GenerateFnInput-style custom generate functions (e.g. miles/rollout/generate_hub/agentic_tool_call.py::generate) alongside the existing 3-arg and 4-arg signatures, so async agentic rollouts don't crash with "TypeError: generate() takes 1 positional argument but 3 were given".
165855c to
2d231e4
Compare
|
Superseded by #1016 which lands the same fix more cleanly: PR #1016 centralizes the legacy ↔ new |
Summary
Small compatibility shim so
generate_and_rmcan dispatch to the new single-argumentGenerateFnInput-style custom generate functions (e.g.miles/rollout/generate_hub/agentic_tool_call.py::generate) in addition to the legacy 3-arg and 4-arg signatures.Without the shim, any rollout wired to a
GenerateFnInput-style function crashes immediately with:because the old dispatcher always called the target as
custom_generate_func(args, sample, sampling_params[, evaluation=...]).The shim inspects the target's signature: if it has exactly one parameter, the four runtime values are packed into a
GenerateFnInputdataclass and awaited as a single-arg call, andoutput.samplesreplacessample. Otherwise it falls through to the existing 3-arg / 4-arg paths. Fully backwards-compatible with all existing custom generate functions.Test plan
len(params) == 1branch only fires for the new contract).miles/rollout/generate_hub/agentic_tool_call.py::generate(single-argGenerateFnInput) dispatches through the new branch without theTypeError, verified on a 2-node async agentic rollout against GLM-4.7-Flash on 2026-04-18 PT.