[Cross-entropy-loss] return mean token accuracy metric with CE loss #910
Conversation
|
@vaibhavjindal would you be able to kindly review? |
|
@shimizust this will be a breaking change i believe BTW |
|
@kashif could you please elaborate on how it will be a breaking change? Will it break the integration with transformers or trl? |
|
Yes, if someone is using the raw functions in their library, those functions now return one more value. On the HF side, this PR takes care of it. |
|
@kashif got it. So if I understand correctly, it will make sure that Liger remains compatible with newer versions from HF. However, I just want to confirm: will it break Liger support with older transformers/trl versions? |
|
No, I believe my changes here will work with older versions of HF. I just meant non-HF frameworks. |
|
TRL relies on HF integration for the CE loss so in TRL I will just pin to the liger version that has these changes |
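To make the compatibility concern concrete, here is a minimal sketch of how an extra return value can break a non-HF caller that unpacks the raw function's result. The functions `ce_before` and `ce_after` are hypothetical stand-ins, not the actual Liger functions, and the exact tuple layout is an assumption for illustration only.

```python
# Hypothetical stand-ins: NOT the real Liger functions, just tuples of the
# assumed "before" and "after" shapes.
def ce_before():
    return 0.7, None  # assumed layout: (loss, z_loss)


def ce_after():
    return 0.7, None, 0.9  # assumed layout: (loss, z_loss, token_accuracy)


def strict_caller(fn):
    # Typical downstream code: unpacks exactly two values.
    loss, z_loss = fn()
    return loss


def tolerant_caller(fn):
    # Star-unpacking keeps working if more values are appended later.
    loss, *extras = fn()
    return loss


def unpacking_breaks(fn) -> bool:
    """Return True if strict two-value unpacking fails for fn."""
    try:
        strict_caller(fn)
    except ValueError:
        return True
    return False
```

A caller written like `tolerant_caller` is unaffected by the change, while one written like `strict_caller` raises `ValueError` on the three-value result, which is why pinning the Liger version downstream is a reasonable precaution.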
|
@vaibhavjindal let me fix up the new qwen3-vl model to update its API |
|
@vaibhavjindal all good from my side |
Thanks a lot! I will do some final checks on correctness and benchmarks and will try to get it merged soon. |
|
thank you so much.. also see here: huggingface/trl#4302 (comment) |
|
thanks @vaibhavjindal for the typo fix and making it more robust! |
## Summary

Add a `return_predicted_tokens` flag to `LigerCrossEntropyLoss` and `LigerFusedLinearCrossEntropyLoss` that returns per-token argmax predictions (as `int64` tensor) **without materializing full logits**.

## Motivation

During training, it is often useful to access the model's predicted tokens (argmax of logits) for logging, visualization, and metric computation: for example, inspecting what the model actually predicts at each position, or tracking prediction distributions over time.

Currently, obtaining predicted tokens requires either:

1. **Materializing full logits** and calling `.argmax(dim=-1)`, which defeats the memory savings of `FusedLinearCrossEntropy`, or
2. **Recomputing** the forward pass separately for metrics.

Since the cross-entropy kernel already tracks `argmax` internally (for `return_token_accuracy`, introduced in #910), we can return the predicted token indices as a byproduct at near-zero additional cost.

## Design

This builds on the `return_token_accuracy` infrastructure (#910). The existing `argmax_idx` tracking in the Triton kernel is reused, so:

- When `return_predicted_tokens=False` (default), there is **zero overhead**: the `RETURN_PREDICTED_TOKENS` constexpr is compiled out.
- When both `return_token_accuracy` and `return_predicted_tokens` are enabled, the argmax computation is **shared** (no duplicate work).
- Ignored tokens (`ignore_index`) return `-1` as a sentinel value.

## Changes

- **`ops/cross_entropy.py`**, **`ops/fused_linear_cross_entropy.py`**: Add `RETURN_PREDICTED_TOKENS` constexpr to the Triton kernel; store `argmax_idx` for non-ignored tokens, `-1` for ignored tokens.
- **`transformers/cross_entropy.py`**, **`transformers/fused_linear_cross_entropy.py`**, **`transformers/functional.py`**: Propagate `return_predicted_tokens` through module and functional APIs. Return `CrossEntropyOutput` when any extra output is requested.
- **`transformers/model/loss_utils.py`**: Thread `return_predicted_tokens` through `LigerForCausalLMLoss` → `fixed_fused_linear_cross_entropy`.
- **`transformers/model/output_classes.py`**: Add `predicted_tokens` field to all `Liger*CausalLMOutputWithPast` dataclasses.
- **`transformers/model/*.py`** (32 model files): Unpack and forward `predicted_tokens` in both tuple and dict return paths, following the same pattern as `token_accuracy`.

## Usage

```python
# Standalone
loss_fn = LigerCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(logits, target)  # logits: (B*T, V), target: (B*T,)
result.loss  # scalar loss
result.predicted_tokens  # (B*T,) int64 tensor, -1 for ignored tokens

# Fused (no logits materialization)
loss_fn = LigerFusedLinearCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(lm_head_weight, hidden_states, target)  # hidden_states: (B*T, H)
result.predicted_tokens  # (B*T,) int64 tensor

# Can combine with token_accuracy
loss_fn = LigerCrossEntropyLoss(
    return_token_accuracy=True,
    return_predicted_tokens=True,
)
result = loss_fn(logits, target)
result.token_accuracy  # scalar
result.predicted_tokens  # (B*T,) int64 tensor
```

> **Note:** `predicted_tokens` is returned as a flat `(B*T,)` tensor, matching the input shape convention of the cross-entropy API (which expects `(B*T, V)` logits and `(B*T,)` targets, consistent with `torch.nn.CrossEntropyLoss`). Reshape as needed:
> ```python
> result.predicted_tokens.view(B, T)
> ```

## Testing Done

- Hardware Type: NVIDIA GPU
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

### New/updated tests:

- `test_correctness_with_predicted_tokens` (cross-entropy): Verifies predicted tokens match reference argmax, ignored tokens are `-1`, backward works. Tests multiple dtypes, shapes, and ignore indices.
- `test_correctness_with_predicted_tokens` (fused linear cross-entropy): Same coverage with logit-value comparison (handles chunked bfloat16 matmul tie-breaking).
- `test_liger_cross_entropy_structured_output`: Extended to parametrize `return_predicted_tokens` across all 8 combinations of `(return_z_loss, return_token_accuracy, return_predicted_tokens)`. Includes consistency check between `predicted_tokens` and `token_accuracy` when both are enabled.

Co-authored-by: Chun-Mao (Michael) Lai <72752478+Mecoli1219@users.noreply.github.com>
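The `-1` sentinel and the consistency check between `predicted_tokens` and `token_accuracy` described above can be sketched in plain Python. This is only a reference for the semantics, not the Triton kernel; the function names, the `ignore_index=-100` default, and the `0.0` result when every token is ignored are assumptions made for illustration.

```python
from typing import List, Sequence


def predicted_tokens_reference(
    logits: Sequence[Sequence[float]],  # (N, V) rows of logits
    targets: Sequence[int],             # (N,) target ids
    ignore_index: int = -100,           # assumed default
) -> List[int]:
    """Per-row argmax, with -1 written for ignored positions."""
    out = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            out.append(-1)  # sentinel for ignored tokens
        else:
            out.append(max(range(len(row)), key=row.__getitem__))
    return out


def accuracy_from_predictions(
    predicted: Sequence[int],
    targets: Sequence[int],
    ignore_index: int = -100,
) -> float:
    """Mean accuracy over non-ignored tokens, derived from predictions."""
    pairs = [(p, t) for p, t in zip(predicted, targets) if t != ignore_index]
    if not pairs:
        return 0.0  # assumption: behaviour with no valid tokens is not stated in the PR
    return sum(p == t for p, t in pairs) / len(pairs)
```

When both outputs are requested, the accuracy recomputed from `predicted_tokens` this way should agree with the reported `token_accuracy`, which is the kind of consistency the structured-output test describes.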
Summary
Returns the mean token accuracy metric when minimizing the cross-entropy loss, without materializing the logits.
https://x.com/jeremyphoward/status/1703246293802586155
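As a rough illustration of the idea in the summary, the sketch below computes mean token accuracy chunk by chunk, so that only one row of logits exists at a time. It is a plain-Python reference under stated assumptions, not the Triton kernel from this PR: the function name, the chunking scheme, the `ignore_index=-100` default, and the `0.0` result when every token is ignored are all hypothetical.

```python
from typing import Sequence


def chunked_token_accuracy(
    hidden: Sequence[Sequence[float]],  # (N, H) hidden states
    weight: Sequence[Sequence[float]],  # (V, H) lm_head weight
    targets: Sequence[int],             # (N,) target ids
    chunk_size: int = 2,
    ignore_index: int = -100,           # assumed default
) -> float:
    """Mean accuracy over non-ignored tokens without a full (N, V) logits matrix."""
    correct = 0
    counted = 0
    for start in range(0, len(hidden), chunk_size):
        rows = hidden[start:start + chunk_size]
        chunk_targets = targets[start:start + chunk_size]
        for row, target in zip(rows, chunk_targets):
            if target == ignore_index:
                continue  # ignored tokens do not affect the metric
            # Logits for a single position only; discarded after the argmax.
            logits = [sum(h * w for h, w in zip(row, w_row)) for w_row in weight]
            predicted = max(range(len(logits)), key=logits.__getitem__)
            correct += int(predicted == target)
            counted += 1
    return correct / counted if counted else 0.0
```

The result does not depend on `chunk_size`; the chunking only bounds how much of the logits is alive at once, which is the memory property the fused loss is meant to preserve.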
Testing Done
- run make test to ensure correctness
- run make checkstyle to ensure code style
- run make test-convergence to ensure convergence