tool mask support #5682
Code Review
This pull request introduces support for tool_mask (or env_mask) within GRPO training by updating the replacement logic and modifying the loss calculation to use a combined loss_mask. Feedback highlights a potential TypeError on older versions of unsloth_zoo: tool_mask is passed to grpo_accumulated_loss unconditionally, but the compatibility check is skipped when the mask is None. Additionally, it is recommended to use regular expressions with word boundaries when checking for the existence of tool_mask in function strings to avoid false positives.
    if tool_mask is not None and not getattr(
        self, "_unsloth_grpo_tool_mask_zoo_checked", False
    ):
        _supports_tool_mask = (
            "tool_mask" in inspect.signature(grpo_accumulated_loss).parameters
        )
        if not _supports_tool_mask:
            try:
                _zoo_src = inspect.getsource(grpo_accumulated_loss)
            except (TypeError, OSError):
                _zoo_src = ""
            _supports_tool_mask = "tool_mask" in _zoo_src
        if not _supports_tool_mask:
            raise RuntimeError(
                "env_mask/tool_mask GRPO requires an unsloth_zoo build whose "
                "grpo_accumulated_loss handles tool_mask. Please upgrade "
                "unsloth_zoo."
            )
        self._unsloth_grpo_tool_mask_zoo_checked = True
This check for tool_mask support only runs when tool_mask is not None. However, the calls to grpo_accumulated_loss at lines 1744 and 1770 pass tool_mask as a keyword argument regardless of its value. This will cause a TypeError on older versions of unsloth_zoo that do not have tool_mask in the signature, even for users who are not using the tool mask feature.
Consider performing the support check unconditionally (once) and using that result to determine whether to pass the tool_mask argument in the function calls.
References
- Centralize recurring or complex logical checks into a single helper function and reuse it across the codebase to ensure consistency and simplify maintenance.
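The suggestion above (check once, then gate the keyword argument) could be sketched roughly as follows. This is only an illustration, not the code in this PR: `accepts_kwarg`, `call_with_optional_tool_mask`, `old_loss` and `new_loss` are hypothetical names, and the two loss functions are stand-ins for an older and a newer `grpo_accumulated_loss`.

```python
import inspect
from functools import lru_cache


@lru_cache(maxsize=None)
def accepts_kwarg(func, name):
    """Return True if calling func(name=...) would not raise a TypeError.

    Hypothetical helper. The result is cached per (func, name), so the
    signature is inspected only once.
    """
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return False
    if name in params:
        return True
    # A **kwargs parameter swallows unknown keywords. Whether the function
    # actually uses the value is a separate question this check cannot answer.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def call_with_optional_tool_mask(loss_fn, tool_mask=None, **kwargs):
    """Forward tool_mask only when loss_fn can receive it."""
    if accepts_kwarg(loss_fn, "tool_mask"):
        kwargs["tool_mask"] = tool_mask
    elif tool_mask is not None:
        # The user asked for tool masking but the installed build cannot do it.
        raise RuntimeError("tool_mask was provided but loss_fn does not accept it")
    return loss_fn(**kwargs)


# Stand-ins for an older and a newer loss function (assumptions for the demo).
def old_loss(completion_mask):
    return ("old", completion_mask)


def new_loss(completion_mask, tool_mask=None):
    return ("new", completion_mask, tool_mask)


# Plain training on an old build keeps working because the kwarg is omitted.
result_old = call_with_optional_tool_mask(old_loss, completion_mask=[1, 1])
result_new = call_with_optional_tool_mask(
    new_loss, tool_mask=[1, 0], completion_mask=[1, 1]
)
```

With this shape, users who never set a tool mask are unaffected by the installed version, and the error is raised only when the feature is requested but unavailable.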
    )
    function = function.replace(_save_search, _save_replace)

    if "tool_mask" in function and 'output["tool_mask"]' not in function:
The check if "tool_mask" in function is too broad and may match substrings in comments or unrelated code. If a match occurs but tool_mask is not actually a defined variable in the function's scope, the injected code will raise a NameError at runtime. Using a regular expression with word boundaries ensures that tool_mask is treated as a distinct identifier.
    - if "tool_mask" in function and 'output["tool_mask"]' not in function:
    + if re.search(r"\btool_mask\b", function) and 'output["tool_mask"]' not in function:
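To see what the word-boundary pattern changes in practice, here is a small sketch comparing it with the plain substring test. The helper name `mentions_tool_mask` and the sample strings are made up for illustration.

```python
import re

# Same pattern as in the suggestion: tool_mask as a whole identifier.
TOOL_MASK_IDENT = re.compile(r"\btool_mask\b")


def mentions_tool_mask(source):
    """Return True if `source` contains tool_mask as a standalone word."""
    return TOOL_MASK_IDENT.search(source) is not None


samples = [
    "loss = f(tool_mask = tool_mask)",  # real use of the identifier
    "use_tool_mask_v2 = True",          # different identifier containing the text
    "tool_masks = []",                  # different identifier (plural)
    "# tool_mask is unused here",       # mention inside a comment
]

# Pair each sample with (substring result, word-boundary result).
comparison = [
    (s, "tool_mask" in s, mentions_tool_mask(s)) for s in samples
]
for s, substring_hit, regex_hit in comparison:
    print(f"{s!r}: substring={substring_hit} regex={regex_hit}")
```

The pattern rejects longer identifiers such as `use_tool_mask_v2` because `_` and letters count as word characters. It still matches a mention inside a comment or string literal, so it narrows the false positives rather than eliminating them; ruling those out would need tokenizing or parsing the source.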
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 526930f85b
    sampling_per_token_logps = sampling_per_token_logps,
    token_type_ids = token_type_ids,
    mm_token_type_ids = mm_token_type_ids,
    tool_mask = tool_mask,
Gate tool_mask kwarg before calling grpo_accumulated_loss
compute_loss now always forwards tool_mask=tool_mask to grpo_accumulated_loss, but the compatibility check for older unsloth_zoo builds only runs when tool_mask is not None. In environments where grpo_accumulated_loss does not accept a tool_mask parameter, this raises TypeError: unexpected keyword argument 'tool_mask' even when users are not using tool masking, so regular GRPO training can fail after this change.
#5673
Currently we only have a completion mask. We want to extend this so that users can specify a tool mask that excludes tool-call tokens from the loss calculation.
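As a rough illustration of the idea, the sketch below combines a completion mask with a tool mask and averages a per-token loss over the remaining tokens. It uses plain Python lists rather than tensors, and it assumes that 1 marks a token that contributes to the loss and 0 marks a token to skip (padding or tool output). That convention, the function names, and the mean-over-kept-tokens normalization are assumptions for the example, not a description of what unsloth_zoo does.

```python
def combine_masks(completion_mask, tool_mask=None):
    """Elementwise product of the two masks; tool_mask=None means no tool masking."""
    if tool_mask is None:
        return list(completion_mask)
    if len(tool_mask) != len(completion_mask):
        raise ValueError("completion_mask and tool_mask must have the same length")
    return [c * t for c, t in zip(completion_mask, tool_mask)]


def masked_mean(per_token_loss, loss_mask):
    """Average the loss over tokens whose mask entry is 1."""
    total = sum(loss * m for loss, m in zip(per_token_loss, loss_mask))
    kept = sum(loss_mask)
    return total / max(kept, 1)  # avoid dividing by zero when nothing is kept


per_token_loss = [0.5, 1.0, 2.0, 4.0]
completion_mask = [1, 1, 1, 0]  # last position is padding
tool_mask = [1, 0, 1, 1]        # second position came from a tool call

loss_mask = combine_masks(completion_mask, tool_mask)
loss_with_tool_mask = masked_mean(per_token_loss, loss_mask)
loss_without_tool_mask = masked_mean(per_token_loss, combine_masks(completion_mask))
```

Here the combined mask is `[1, 0, 1, 0]`, so only the first and third tokens are averaged; without the tool mask the second token would also be counted.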