Skip to content

[Cross-entropy-loss] return mean token accuracy metric with CE loss#910

Merged
vaibhavjindal merged 27 commits into
linkedin:mainfrom
kashif:mean_token_accuracy
Nov 5, 2025
Merged

[Cross-entropy-loss] return mean token accuracy metric with CE loss#910
vaibhavjindal merged 27 commits into
linkedin:mainfrom
kashif:mean_token_accuracy

Conversation

@kashif

@kashif kashif commented Oct 16, 2025

Copy link
Copy Markdown
Contributor

Summary

Returns the mean token accuracy metric when minimizing the cross-entropy loss without materializing the logits

https://x.com/jeremyphoward/status/1703246293802586155

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Comment thread src/liger_kernel/transformers/functional.py Outdated
Comment thread src/liger_kernel/transformers/functional.py Outdated
Comment thread src/liger_kernel/transformers/model/falcon_h1.py Outdated
@kashif

kashif commented Oct 20, 2025

Copy link
Copy Markdown
Contributor Author

@vaibhavjindal would you be able to kindly review?

@kashif

kashif commented Oct 28, 2025

Copy link
Copy Markdown
Contributor Author

@shimizust this will be a breaking change i believe BTW

@kashif kashif changed the title [Cross-entropy-loss] add return_token_accuracy flag to fused_linear_cross_entropy [Cross-entropy-loss] return mean token accuracy metric with CE loss Nov 1, 2025
@vaibhavjindal

Copy link
Copy Markdown
Collaborator

@kashif could you please elaborate on how it will be a breaking change? Will it break the intergration with transformers or trl?

@kashif

kashif commented Nov 3, 2025

Copy link
Copy Markdown
Contributor Author

yes if someone is using the raw functions in their lib. then now that functions returns one more thing... but on the HF side this PR takes care of this

@kashif

kashif commented Nov 3, 2025

Copy link
Copy Markdown
Contributor Author

@vaibhavjindal

Copy link
Copy Markdown
Collaborator

@kashif got it. So if i understand correctly, it will make sure that liger remains compatible with newer versions from HF. However, just want to confirm it will break liger support with older transformers/trl versions?

@kashif

kashif commented Nov 3, 2025

Copy link
Copy Markdown
Contributor Author

no i believe my changes here will work with older version of HF.. i just meant non-HF frameworks

@kashif

kashif commented Nov 3, 2025

Copy link
Copy Markdown
Contributor Author

TRL relies on HF integration for the CE loss so in TRL I will just pin to the liger version that has these changes

@kashif

kashif commented Nov 5, 2025

Copy link
Copy Markdown
Contributor Author

@vaibhavjindal let me fix up the new qwen3-vl model to update its API

@kashif

kashif commented Nov 5, 2025

Copy link
Copy Markdown
Contributor Author

@vaibhavjindal all good from my side

@vaibhavjindal

Copy link
Copy Markdown
Collaborator

@vaibhavjindal all good from my side

Thanks a lot! I will do some final checks on correctness and benchmarks and will try to get it merged soon.

@kashif

kashif commented Nov 5, 2025

Copy link
Copy Markdown
Contributor Author

thank you so much.. also see here: huggingface/trl#4302 (comment)

@vaibhavjindal vaibhavjindal merged commit 7dd8ecc into linkedin:main Nov 5, 2025
3 of 7 checks passed
@kashif

kashif commented Nov 6, 2025

Copy link
Copy Markdown
Contributor Author

thanks @vaibhavjindal for the typo fix and making it more robust!

@kashif kashif deleted the mean_token_accuracy branch November 22, 2025 22:09
github-merge-queue Bot pushed a commit that referenced this pull request Feb 19, 2026
## Summary

Add a `return_predicted_tokens` flag to `LigerCrossEntropyLoss` and
`LigerFusedLinearCrossEntropyLoss` that returns per-token argmax
predictions (as `int64` tensor) **without materializing full logits**.

## Motivation

During training, it is often useful to access the model's predicted
tokens (argmax of logits) for logging, visualization, and metric
computation — for example, inspecting what the model actually predicts
at each position, or tracking prediction distributions over time.

Currently, obtaining predicted tokens requires either:
1. **Materializing full logits** and calling `.argmax(dim=-1)`, which
defeats the memory savings of `FusedLinearCrossEntropy`, or
2. **Recomputing** the forward pass separately for metrics.

Since the cross-entropy kernel already tracks `argmax` internally (for
`return_token_accuracy`, introduced in #910), we can return the
predicted token indices as a byproduct at near-zero additional cost.

## Design

This builds on the `return_token_accuracy` infrastructure (#910). The
existing `argmax_idx` tracking in the Triton kernel is reused, so:

- When `return_predicted_tokens=False` (default), there is **zero
overhead** — the `RETURN_PREDICTED_TOKENS` constexpr is compiled out.
- When both `return_token_accuracy` and `return_predicted_tokens` are
enabled, the argmax computation is **shared** (no duplicate work).
- Ignored tokens (`ignore_index`) return `-1` as a sentinel value.

## Changes

- **`ops/cross_entropy.py`**, **`ops/fused_linear_cross_entropy.py`**:
Add `RETURN_PREDICTED_TOKENS` constexpr to the Triton kernel; store
`argmax_idx` for non-ignored tokens, `-1` for ignored tokens.
- **`transformers/cross_entropy.py`**,
**`transformers/fused_linear_cross_entropy.py`**,
**`transformers/functional.py`**: Propagate `return_predicted_tokens`
through module and functional APIs. Return `CrossEntropyOutput` when any
extra output is requested.
- **`transformers/model/loss_utils.py`**: Thread
`return_predicted_tokens` through `LigerForCausalLMLoss` →
`fixed_fused_linear_cross_entropy`.
- **`transformers/model/output_classes.py`**: Add `predicted_tokens`
field to all `Liger*CausalLMOutputWithPast` dataclasses.
- **`transformers/model/*.py`** (32 model files): Unpack and forward
`predicted_tokens` in both tuple and dict return paths, following the
same pattern as `token_accuracy`.

## Usage

```python
# Standalone
loss_fn = LigerCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(logits, target)  # logits: (B*T, V), target: (B*T,)
result.loss              # scalar loss
result.predicted_tokens  # (B*T,) int64 tensor, -1 for ignored tokens

# Fused (no logits materialization)
loss_fn = LigerFusedLinearCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(lm_head_weight, hidden_states, target)  # hidden_states: (B*T, H)
result.predicted_tokens  # (B*T,) int64 tensor

# Can combine with token_accuracy
loss_fn = LigerCrossEntropyLoss(
    return_token_accuracy=True,
    return_predicted_tokens=True,
)
result = loss_fn(logits, target)
result.token_accuracy    # scalar
result.predicted_tokens  # (B*T,) int64 tensor
```

> **Note:** `predicted_tokens` is returned as a flat `(B*T,)` tensor,
matching the input shape convention of the cross-entropy API (which
expects `(B*T, V)` logits and `(B*T,)` targets, consistent with
`torch.nn.CrossEntropyLoss`). Reshape as needed:
> ```python
> result.predicted_tokens.view(B, T)
> ```

## Testing Done

- Hardware Type: NVIDIA GPU
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

### New/updated tests:
- `test_correctness_with_predicted_tokens` (cross-entropy): Verifies
predicted tokens match reference argmax, ignored tokens are `-1`,
backward works. Tests multiple dtypes, shapes, and ignore indices.
- `test_correctness_with_predicted_tokens` (fused linear cross-entropy):
Same coverage with logit-value comparison (handles chunked bfloat16
matmul tie-breaking).
- `test_liger_cross_entropy_structured_output`: Extended to parametrize
`return_predicted_tokens` across all 8 combinations of `(return_z_loss,
return_token_accuracy, return_predicted_tokens)`. Includes consistency
check between `predicted_tokens` and `token_accuracy` when both are
enabled.

Co-authored-by: Chun-Mao (Michael) Lai <72752478+Mecoli1219@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants