Add Cohere training chat template #5627

Merged
qgallouedec merged 7 commits into huggingface:main from dschulmeist:fix/add-cohere-training-chat-template-5471
Apr 28, 2026

@dschulmeist (Contributor) commented Apr 22, 2026

What does this PR do?

Addresses the Cohere slot of the tracker issue #5471 by registering a training-variant chat template for the Cohere Command model family, so SFT with `assistant_only_loss=True` produces correct masks on Cohere tokenizers.

Changes

New templates

  • `trl/chat_templates/cohere.jinja` — verbatim copy of the Cohere Command chat template (for identity comparison via `get_training_chat_template`).
  • `trl/chat_templates/cohere_training.jinja` — patched training variant. Diff vs `cohere.jinja`:
    1. Assistant output wrapped in `{% generation %}` / `{% endgeneration %}` so `return_assistant_tokens_mask=True` produces correct masks.
    2. Dropped the user/assistant alternation check that raised on tool messages. The original Cohere template doesn't render `tool_calls` or tool-role content in any case, so silently ignoring tool messages (matching how the Cohere2 original template already behaves) doesn't change any observable output on user/assistant-only conversations, and is required for the training template to be prefix-preserving.
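To make the effect of the `{% generation %}` markers concrete, here is a minimal sketch in plain Python. It is not the real Jinja template and does not use the `transformers` renderer: `render_with_mask` is a hypothetical helper, the mask is per character rather than per token, and the `<|SYSTEM_TOKEN|>`, `<|USER_TOKEN|>` and `<|END_OF_TURN_TOKEN|>` strings as well as the exact span boundaries (header outside, end-of-turn token inside) are assumptions made for illustration. It only shows which part of the rendered text would be marked as assistant output, and that tool messages are skipped rather than rejected.

```python
START = "<|START_OF_TURN_TOKEN|>"
END = "<|END_OF_TURN_TOKEN|>"  # assumed end-of-turn marker
ROLE_TOKENS = {
    "system": "<|SYSTEM_TOKEN|>",  # assumed
    "user": "<|USER_TOKEN|>",  # assumed
    "assistant": "<|CHATBOT_TOKEN|>",
}


def render_with_mask(messages):
    """Return (text, mask) where mask[i] == 1 iff text[i] is assistant output."""
    text, mask = "", []
    for message in messages:
        role_token = ROLE_TOKENS.get(message["role"])
        if role_token is None:
            # Tool messages (and any other unknown role) are silently skipped.
            continue
        header = START + role_token
        body = message["content"] + END
        is_assistant = message["role"] == "assistant"
        text += header + body
        # The turn header is never part of the generation span.
        mask += [0] * len(header)
        mask += [1 if is_assistant else 0] * len(body)
    return text, mask


messages = [
    {"role": "user", "content": "Hi"},
    {"role": "tool", "content": "ignored"},
    {"role": "assistant", "content": "Hello!"},
]
text, mask = render_with_mask(messages)
assistant_text = "".join(ch for ch, m in zip(text, mask) if m)
print(assistant_text)  # Hello!<|END_OF_TURN_TOKEN|>
```

With the real tokenizer the same information comes back per token when `return_assistant_tokens_mask=True` is requested; the character-level mask above is only a stand-in for that.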

`trl/chat_template_utils.py`

  • Register `cohere_chat_template` and `cohere_training_chat_template` variables.
  • Add a dispatch branch to `get_training_chat_template`: `if tokenizer.chat_template == cohere_chat_template: return cohere_training_chat_template`.
  • Added `except TemplateError: return False` to `is_chat_template_prefix_preserving`, mirroring the existing `except TemplateError` in the sibling `_has_native_tool_support` at the top of the module. Previously, probing a template that rejects the tool-message role sequence (Cohere, FalconMamba) propagated `TemplateError` out of `is_chat_template_prefix_preserving` — and consequently out of `get_training_chat_template` — before the explicit `chat_template` dispatch could run. With the catch in place, such templates are treated as not prefix-preserving, so the explicit dispatch picks them up.
  • Docstring updated: "Currently Cohere, DeepSeek-V3, GPT-OSS, LLaMA 3, Qwen2.5, and Qwen3 are supported."
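The dispatch branch relies on exact string equality with the bundled reference template. A rough sketch of that idea, using placeholder strings instead of the real `.jinja` contents and a hypothetical `pick_training_template` helper (the real `get_training_chat_template` takes a tokenizer and may handle unknown templates differently; returning `None` here is an assumption):

```python
# Placeholders standing in for the contents of the bundled .jinja files.
cohere_chat_template = "<reference Cohere template text>"
cohere_training_chat_template = "<patched Cohere template text>"

# Reference template -> training variant, keyed by the exact template string.
REFERENCE_TO_TRAINING = {cohere_chat_template: cohere_training_chat_template}


def pick_training_template(chat_template):
    """Return the training variant for a known reference template, else None."""
    return REFERENCE_TO_TRAINING.get(chat_template)


# An exact match is swapped for the training variant.
print(pick_training_template(cohere_chat_template) == cohere_training_chat_template)
# Any deviation from the reference text (even trailing whitespace) does not match.
print(pick_training_template(cohere_chat_template + "\n"))
```

Because the match is exact, a tokenizer whose template has been customized would not be patched by this branch.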

`trl/chat_templates/README.md`

  • Entries for `cohere.jinja` and `cohere_training.jinja` in the existing structure.

`tests/test_chat_template_utils.py`

  • Added `pytest.param("trl-internal-testing/tiny-CohereForCausalLM", id="cohere")` to the `TestGetTrainingChatTemplate` parametrize list.

Tests

All 12 `TestGetTrainingChatTemplate[cohere]` tests pass, including `test_new_chat_template_is_prefix_preserving` and both `test_assistant_masks*`.

Full `tests/test_chat_template_utils.py` run: 114 passed, 64 skipped (pre-existing), 2 xfailed (pre-existing DeepSeek xfails). No regressions on other tokenizers.

`ruff check` and `ruff format --check` clean on all touched files.

Design notes

Cohere is the first non-tool-supporting model family added to the training-template set. The existing DeepSeek-V3 / GPT-OSS / LLaMA 3 / Qwen2.5 / Qwen3 all support tool calls in their original templates, so `{% generation %}` markers alone were enough. For Cohere, the original template raises on any `tool` role message, which breaks the prefix-preserving probe in `is_chat_template_prefix_preserving` that drives `get_training_chat_template`'s dispatch.

Two changes were needed to close that gap:

  1. Catch `TemplateError` in `is_chat_template_prefix_preserving` so the probe returns `False` instead of propagating — a module-level fix that also unblocks any future non-tool-supporting additions (FalconMamba, Gemma, etc.).
  2. Drop the alternation check in `cohere_training.jinja` so the training template is itself prefix-preserving. Matches the existing Cohere2 behaviour of silently dropping tool messages rather than raising.
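The interaction between the alternation check and the prefix-preserving probe can be sketched with two toy renderers. Everything below is illustrative: `strict_render` and `lenient_render` are hypothetical stand-ins for the reference and training templates, `TemplateError` is a local stand-in for `jinja2.TemplateError`, and the probe conversation is an assumed shape rather than the exact one used by `is_chat_template_prefix_preserving`.

```python
class TemplateError(Exception):
    """Stand-in for jinja2.TemplateError so the sketch has no dependencies."""


def strict_render(messages):
    """Toy template that enforces user/assistant alternation."""
    out = ""
    for i, m in enumerate(messages):
        expected = "user" if i % 2 == 0 else "assistant"
        if m["role"] != expected:
            raise TemplateError("Conversation roles must alternate user/assistant")
        out += f"<{m['role']}>{m['content']}</{m['role']}>"
    return out


def lenient_render(messages):
    """Toy training variant: no alternation check, tool messages are dropped."""
    return "".join(
        f"<{m['role']}>{m['content']}</{m['role']}>"
        for m in messages
        if m["role"] in ("user", "assistant")
    )


def is_prefix_preserving(render):
    """The rendering of a conversation must start with the rendering of its prefix."""
    prefix = [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": ""},
    ]
    full = prefix + [{"role": "tool", "content": "4"}]
    return render(full).startswith(render(prefix))


print(is_prefix_preserving(lenient_render))  # True
try:
    is_prefix_preserving(strict_render)
except TemplateError as err:
    print("probe raised:", err)
```

The strict renderer never gets as far as returning a result, which is the failure mode described above; the lenient one renders the tool message as nothing, so the prefix check holds.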

Happy to split (1) into a separate PR if preferred, though the test failure on `test_new_chat_template_is_prefix_preserving[cohere]` couples them naturally.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case)
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes (added `cohere.jinja` / `cohere_training.jinja` entries to `trl/chat_templates/README.md`)?
  • Did you write any new necessary tests?

Fixes one slot of #5471.


Note

Low Risk
Low risk: changes are additive (new Cohere templates, docs, and a test) and only affect template patching when a tokenizer exactly matches the Cohere reference template.

Overview
Adds Cohere Command to TRL’s supported chat-template families by bundling `cohere.jinja` (reference) and a new `cohere_training.jinja` that wraps assistant output in `{% generation %}` / `{% endgeneration %}` to enable correct `assistant_only_loss=True` masking.

Updates `get_training_chat_template` to recognize the Cohere reference template and swap in the training variant, extends the `trl/chat_templates/README.md` docs to cover the new templates, and adds Cohere coverage to the `TestGetTrainingChatTemplate` parametrized tests.

Reviewed by Cursor Bugbot for commit ca7ccc7. Bugbot is set up for automated code reviews on this repo.

Adds SFT assistant_only_loss support for the Cohere Command model family
by registering a training-variant chat template with {% generation %} /
{% endgeneration %} markers around the assistant output.

Changes vs the original Cohere template (cohere.jinja):
- Wrap assistant output with {% generation %} / {% endgeneration %} so
  return_assistant_tokens_mask=True produces correct masks.
- Drop the user/assistant alternation check that raises on tool messages.
  The original template doesn't render tool_calls or tool role content in
  any case; silently ignoring tool messages (matching Cohere2's existing
  behaviour) keeps the training template prefix-preserving without
  changing any observable output for user/assistant-only conversations.

Also catches jinja2.TemplateError in is_chat_template_prefix_preserving
(mirroring the existing TemplateError handling in _has_native_tool_support).
Previously, probing a template that rejects the tool-message role sequence
propagated the exception out of get_training_chat_template before the
explicit chat_template dispatch could run; now such templates are treated
as not prefix-preserving, letting the explicit dispatch pick up known
cases like Cohere.

Test coverage adds tiny-CohereForCausalLM to the
TestGetTrainingChatTemplate parametrize list. All 12 training-template
tests pass on the new tokenizer, including
test_new_chat_template_is_prefix_preserving and both assistant-mask tests.
Full test_chat_template_utils.py suite: 114 passed, no regressions.
Comment thread trl/chat_template_utils.py Outdated
Comment on lines +507 to +511
except TemplateError:
    # Template rejects the role sequence (e.g. Cohere, FalconMamba enforce strict user/assistant alternation
    # and raise on tool messages). Not prefix-preserving by this definition — patching is still supported via
    # an explicit chat_template match in get_training_chat_template.
    return False

Member

no we don't need that

Member

I think we should instead gate the call to `is_chat_template_prefix_preserving` on `supports_tool_calling`.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).

Reviewed by Cursor Bugbot for commit 8282949.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

- Drop `except TemplateError` from `is_chat_template_prefix_preserving`
  per @qgallouedec's review. Instead, gate the call in
  `get_training_chat_template` with `supports_tool_calling`, so the
  prefix-preservation probe only runs on templates that can actually
  render the tool-message sequence the probe relies on.
- `supports_tool_calling` gains a matching `except TypeError: return
  False` catch next to the existing `except TemplateError: return
  False`, so the new gate short-circuits cleanly on DeepSeek-V3 (whose
  original template raises TypeError on dict tool arguments) without
  having to wrap the gate in try/except at the call site.
- Move the `<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>` turn header
  outside the `{% generation %}` block in `cohere_training.jinja`, so
  the SFT assistant-only loss isn't applied to the constant formatting
  prefix (matching the llama3 / qwen2.5 training-template convention).
@dschulmeist

Contributor Author

Thanks for the review, @qgallouedec!

e391982 addresses the three points:

  1. Dropped `except TemplateError` from `is_chat_template_prefix_preserving`. The call in `get_training_chat_template` is now gated with `supports_tool_calling`, so the prefix-preservation probe only runs on templates that can actually render the tool-message sequence it relies on.
  2. `supports_tool_calling` gains a matching `except TypeError: return False` next to the existing `except TemplateError` catch. Without this, the new gate propagated `TypeError` out of `get_training_chat_template` for DeepSeek-V3 (whose original template raises on dict arguments). With it, the gate short-circuits cleanly for both Cohere-class templates (`TemplateError`, rejects tool role) and DeepSeek-class templates (`TypeError`, dict args). The existing `test_deepseek_tool_calling` xfail still fails strictly as before.
  3. Moved `<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>` outside the `{% generation %}` block in `cohere_training.jinja`, matching the llama3 / qwen2.5 convention, so the constant formatting prefix is no longer counted toward SFT loss.
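The gate from points 1 and 2 can be sketched as follows. This is an illustration under assumptions, not the code in `trl/chat_template_utils.py`: the three renderers are toy stand-ins for Cohere-class, DeepSeek-class and tool-capable templates, `TemplateError` is a local stand-in for `jinja2.TemplateError`, `PROBE` is an assumed probe conversation, and `should_run_prefix_probe` is a hypothetical name for the gating condition.

```python
import json


class TemplateError(Exception):
    """Stand-in for jinja2.TemplateError so the sketch has no dependencies."""


PROBE = [
    {"role": "user", "content": "What is 2+2?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"type": "function", "function": {"name": "add", "arguments": {"a": 2, "b": 2}}}
        ],
    },
    {"role": "tool", "name": "add", "content": "4"},
]


def cohere_like(messages):
    """Toy template that rejects any role it does not know, such as `tool`."""
    out = ""
    for m in messages:
        if m["role"] not in ("system", "user", "assistant"):
            raise TemplateError("unsupported role: " + m["role"])
        out += m["role"] + ": " + m["content"] + "\n"
    return out


def deepseek_like(messages):
    """Toy template that concatenates tool arguments as if they were a string."""
    out = ""
    for m in messages:
        out += m["role"] + ": " + m["content"] + "\n"
        for call in m.get("tool_calls", []):
            # Raises TypeError when `arguments` is a dict rather than a str.
            out += "args: " + call["function"]["arguments"] + "\n"
    return out


def tool_aware(messages):
    """Toy template that serializes dict arguments and renders tool messages."""
    out = ""
    for m in messages:
        out += m["role"] + ": " + m["content"] + "\n"
        for call in m.get("tool_calls", []):
            out += "args: " + json.dumps(call["function"]["arguments"]) + "\n"
    return out


def supports_tool_calling(render):
    """Return False when the template cannot render the tool-call probe."""
    try:
        render(PROBE)
    except TemplateError:
        return False
    except TypeError:
        return False
    return True


def should_run_prefix_probe(render):
    """Gate: only probe prefix preservation on tool-capable templates."""
    return supports_tool_calling(render)


for render in (cohere_like, deepseek_like, tool_aware):
    print(render.__name__, should_run_prefix_probe(render))
```

Catching both exception types inside `supports_tool_calling` keeps the call site free of `try`/`except`: templates that cannot render the probe skip the prefix check and fall through to the explicit template match.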

Full `tests/test_chat_template_utils.py` run after the changes: 126 passed, 64 skipped, 2 xfailed. All 12 `TestGetTrainingChatTemplate[cohere]` tests pass, with no regressions on other tokenizers.

@qgallouedec

Member

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Chef's kiss.

@qgallouedec changed the title from "Add Cohere training chat template (#5471)" to "Add Cohere training chat template" Apr 28, 2026
@qgallouedec qgallouedec merged commit 788555a into huggingface:main Apr 28, 2026
6 of 12 checks passed
3 participants