Add Cohere training chat template#5627
Conversation
Adds SFT assistant_only_loss support for the Cohere Command model family
by registering a training-variant chat template with {% generation %} /
{% endgeneration %} markers around the assistant output.
Changes vs the original Cohere template (cohere.jinja):
- Wrap assistant output with {% generation %} / {% endgeneration %} so
return_assistant_tokens_mask=True produces correct masks.
- Drop the user/assistant alternation check that raises on tool messages.
The original template doesn't render tool_calls or tool role content in
any case; silently ignoring tool messages (matching Cohere2's existing
behaviour) keeps the training template prefix-preserving without
changing any observable output for user/assistant-only conversations.
Also catches jinja2.TemplateError in is_chat_template_prefix_preserving
(mirroring the existing TemplateError handling in _has_native_tool_support).
Previously, probing a template that rejects the tool-message role sequence
propagated the exception out of get_training_chat_template before the
explicit chat_template dispatch could run; now such templates are treated
as not prefix-preserving, letting the explicit dispatch pick up known
cases like Cohere.
Test coverage adds tiny-CohereForCausalLM to the
TestGetTrainingChatTemplate parametrize list. All 12 training-template
tests pass on the new tokenizer, including
test_new_chat_template_is_prefix_preserving and both assistant-mask tests.
Full test_chat_template_utils.py suite: 114 passed, no regressions.
| except TemplateError: | ||
| # Template rejects the role sequence (e.g. Cohere, FalconMamba enforce strict user/assistant alternation | ||
| # and raise on tool messages). Not prefix-preserving by this definition — patching is still supported via | ||
| # an explicit chat_template match in get_training_chat_template. | ||
| return False |
There was a problem hiding this comment.
I think instead we should condition that call to is_chat_template_prefix_preserving by supports_tool_calling
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 2 total unresolved issues (including 1 from previous review).
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit 8282949. Configure here.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
- Drop `except TemplateError` from `is_chat_template_prefix_preserving` per @qgallouedec's review. Instead, gate the call in `get_training_chat_template` with `supports_tool_calling`, so the prefix-preservation probe only runs on templates that can actually render the tool-message sequence the probe relies on. - `supports_tool_calling` gains a matching `except TypeError: return False` catch next to the existing `except TemplateError: return False`, so the new gate short-circuits cleanly on DeepSeek-V3 (whose original template raises TypeError on dict tool arguments) without having to wrap the gate in try/except at the call site. - Move the `<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>` turn header outside the `{% generation %}` block in `cohere_training.jinja`, so the SFT assistant-only loss isn't applied to the constant formatting prefix (matching the llama3 / qwen2.5 training-template convention).
|
Thanks for the review, @qgallouedec! e391982 addresses the three points:
Full |
…ng-chat-template-5471 # Conflicts: # trl/chat_template_utils.py
|
@codex review |
|
Codex Review: Didn't find any major issues. Chef's kiss. ℹ️ About Codex in GitHubYour team has set up Codex to review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. Codex can also answer questions or update the PR. Try commenting "@codex address that feedback". |

What does this PR do?
Addresses the Cohere slot of the tracker issue #5471 by registering a training-variant chat template for the Cohere Command model family, so SFT with `assistant_only_loss=True` produces correct masks on Cohere tokenizers.
Changes
New templates
`trl/chat_template_utils.py`
`trl/chat_templates/README.md`
`tests/test_chat_template_utils.py`
Tests
All 12 `TestGetTrainingChatTemplate[cohere]` tests pass, including `test_new_chat_template_is_prefix_preserving` and both `test_assistant_masks*`.
Full `tests/test_chat_template_utils.py` run: 114 passed, 64 skipped (pre-existing), 2 xfailed (pre-existing DeepSeek xfails). No regressions on other tokenizers.
`ruff check` and `ruff format --check` clean on all touched files.
Design notes
Cohere is the first non-tool-supporting model family added to the training-template set. The existing DeepSeek-V3 / GPT-OSS / LLaMA 3 / Qwen2.5 / Qwen3 all support tool calls in their original templates, so `{% generation %}` markers alone were enough. For Cohere, the original template raises on any `tool` role message, which breaks the prefix-preserving probe in `is_chat_template_prefix_preserving` that drives `get_training_chat_template`'s dispatch.
Two changes were needed to close that gap:
Happy to split (1) into a separate PR if preferred, though the test failure on `test_new_chat_template_is_prefix_preserving[cohere]` couples them naturally.
Before submitting
Fixes one slot of #5471.
Note
Low Risk
Low risk: changes are additive (new Cohere templates, docs, and a test) and only affect template patching when a tokenizer exactly matches the Cohere reference template.
Overview
Adds Cohere Command to TRL’s supported chat-template families by bundling
cohere.jinja(reference) and a newcohere_training.jinjathat wraps assistant output in{% generation %}/{% endgeneration %}to enable correctassistant_only_loss=Truemasking.Updates
get_training_chat_templateto recognize the Cohere reference template and swap in the training variant, extends the docs/trl/chat_templates/README.mdto document the new templates, and adds Cohere coverage to theTestGetTrainingChatTemplateparametrized tests.Reviewed by Cursor Bugbot for commit ca7ccc7. Bugbot is set up for automated code reviews on this repo. Configure here.