Add return_predicted_tokens support for cross-entropy kernels #1091

Merged
Mecoli1219 merged 2 commits into
linkedin:mainfrom
yukiu00:feat/return-predicted-tokens
Feb 19, 2026
Conversation

yukiu00 commented Feb 10, 2026

Contributor

Summary

Add a return_predicted_tokens flag to LigerCrossEntropyLoss and LigerFusedLinearCrossEntropyLoss that returns per-token argmax predictions (as an int64 tensor) without materializing full logits.

Motivation

During training, it is often useful to access the model's predicted tokens (argmax of logits) for logging, visualization, and metric computation — for example, inspecting what the model actually predicts at each position, or tracking prediction distributions over time.

Currently, obtaining predicted tokens requires either:

  1. Materializing full logits and calling .argmax(dim=-1), which defeats the memory savings of FusedLinearCrossEntropy, or
  2. Recomputing the forward pass separately for metrics.

Since the cross-entropy kernel already tracks argmax internally (for return_token_accuracy, introduced in #910), we can return the predicted token indices as a byproduct at near-zero additional cost.

Design

This builds on the return_token_accuracy infrastructure (#910). The existing argmax_idx tracking in the Triton kernel is reused, so:

  • When return_predicted_tokens=False (default), there is zero overhead — the RETURN_PREDICTED_TOKENS constexpr is compiled out.
  • When both return_token_accuracy and return_predicted_tokens are enabled, the argmax computation is shared (no duplicate work).
  • Ignored tokens (ignore_index) return -1 as a sentinel value.
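The output contract described above can be written down as a plain PyTorch reference. This is only a sketch of the expected semantics, not the Triton kernel: it materializes the full logits, which is exactly what the fused path avoids. The helper name reference_predicted_tokens and the default ignore_index of -100 are assumptions made for illustration.

```python
import torch


def reference_predicted_tokens(logits, target, ignore_index=-100):
    """Eager reference: per-row argmax, with -1 wherever the target is ignored.

    Hypothetical helper for illustration; it is not part of liger_kernel.
    """
    preds = logits.argmax(dim=-1)  # int64 indices, shape (B*T,)
    return preds.masked_fill(target == ignore_index, -1)


logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1],
                       [0.0, 0.1, 3.0]])
target = torch.tensor([1, -100, 2])  # the second position is ignored
predicted = reference_predicted_tokens(logits, target)
```

A reference like this is mainly useful as a test oracle on small inputs.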

Changes

  • ops/cross_entropy.py, ops/fused_linear_cross_entropy.py: Add RETURN_PREDICTED_TOKENS constexpr to the Triton kernel; store argmax_idx for non-ignored tokens, -1 for ignored tokens.
  • transformers/cross_entropy.py, transformers/fused_linear_cross_entropy.py, transformers/functional.py: Propagate return_predicted_tokens through module and functional APIs. Return CrossEntropyOutput when any extra output is requested.
  • transformers/model/loss_utils.py: Thread return_predicted_tokens through LigerForCausalLMLoss and fixed_fused_linear_cross_entropy.
  • transformers/model/output_classes.py: Add predicted_tokens field to all Liger*CausalLMOutputWithPast dataclasses.
  • transformers/model/*.py (32 model files): Unpack and forward predicted_tokens in both tuple and dict return paths, following the same pattern as token_accuracy.
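The rule that a structured output is returned only when an extra output is requested could look roughly like the sketch below. SketchCrossEntropyOutput and package_outputs are hypothetical stand-ins written for this illustration; they are not the library's CrossEntropyOutput or its wrapper code, whose exact fields and signatures are not shown in this PR description.

```python
from dataclasses import dataclass
from typing import Optional, Union

import torch


@dataclass
class SketchCrossEntropyOutput:
    """Hypothetical stand-in for a structured loss output."""
    loss: torch.Tensor
    z_loss: Optional[torch.Tensor] = None
    token_accuracy: Optional[torch.Tensor] = None
    predicted_tokens: Optional[torch.Tensor] = None


def package_outputs(
    loss,
    z_loss=None,
    token_accuracy=None,
    predicted_tokens=None,
    *,
    return_z_loss=False,
    return_token_accuracy=False,
    return_predicted_tokens=False,
) -> Union[torch.Tensor, SketchCrossEntropyOutput]:
    # With every flag at its default, callers get a bare loss tensor,
    # so code written before the new flag existed keeps working.
    if not (return_z_loss or return_token_accuracy or return_predicted_tokens):
        return loss
    # Otherwise expose only the outputs that were explicitly requested.
    return SketchCrossEntropyOutput(
        loss=loss,
        z_loss=z_loss if return_z_loss else None,
        token_accuracy=token_accuracy if return_token_accuracy else None,
        predicted_tokens=predicted_tokens if return_predicted_tokens else None,
    )


loss = torch.tensor(1.25)
preds = torch.tensor([3, -1, 7])
plain = package_outputs(loss, predicted_tokens=preds)
structured = package_outputs(loss, predicted_tokens=preds, return_predicted_tokens=True)
```

Keeping all flags False by default is what lets the module and functional APIs behave as before for existing callers.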

Usage

from liger_kernel.transformers import LigerCrossEntropyLoss, LigerFusedLinearCrossEntropyLoss

# Standalone
loss_fn = LigerCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(logits, target)  # logits: (B*T, V), target: (B*T,)
result.loss              # scalar loss
result.predicted_tokens  # (B*T,) int64 tensor, -1 for ignored tokens

# Fused (no logits materialization)
loss_fn = LigerFusedLinearCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(lm_head_weight, hidden_states, target)  # hidden_states: (B*T, H)
result.predicted_tokens  # (B*T,) int64 tensor

# Can combine with token_accuracy
loss_fn = LigerCrossEntropyLoss(
    return_token_accuracy=True,
    return_predicted_tokens=True,
)
result = loss_fn(logits, target)
result.token_accuracy    # scalar
result.predicted_tokens  # (B*T,) int64 tensor

Note: predicted_tokens is returned as a flat (B*T,) tensor, matching the input shape convention of the cross-entropy API (which expects (B*T, V) logits and (B*T,) targets, consistent with torch.nn.CrossEntropyLoss). Reshape as needed:

result.predicted_tokens.view(B, T)
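Because predicted_tokens lines up with the flat targets, metrics such as token accuracy can be derived outside the kernel. The following is a minimal sketch under that assumption; token_accuracy_from_predictions is a hypothetical helper, not a liger_kernel function, and the default ignore_index of -100 is assumed.

```python
import torch


def token_accuracy_from_predictions(predicted_tokens, target, ignore_index=-100):
    """Fraction of non-ignored positions where the prediction equals the target.

    Hypothetical helper for illustration only.
    """
    valid = target != ignore_index
    n_valid = valid.sum()
    if n_valid == 0:
        # No scored positions: return 0.0 rather than dividing by zero.
        return predicted_tokens.new_zeros((), dtype=torch.float32)
    # Mask with `valid` so the -1 sentinel can never count as a match.
    correct = (predicted_tokens == target) & valid
    return correct.sum().float() / n_valid.float()


predicted_tokens = torch.tensor([1, -1, 2, 0])  # -1 marks an ignored position
target = torch.tensor([1, -100, 0, 0])
accuracy = token_accuracy_from_predictions(predicted_tokens, target)
```

In this toy input two of the three scored positions match, so the result is 2/3.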

Testing Done

  • Hardware Type: NVIDIA GPU
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

New/updated tests:

  • test_correctness_with_predicted_tokens (cross-entropy): Verifies predicted tokens match reference argmax, ignored tokens are -1, backward works. Tests multiple dtypes, shapes, and ignore indices.
  • test_correctness_with_predicted_tokens (fused linear cross-entropy): Same coverage with logit-value comparison (handles chunked bfloat16 matmul tie-breaking).
  • test_liger_cross_entropy_structured_output: Extended to parametrize return_predicted_tokens across all 8 combinations of (return_z_loss, return_token_accuracy, return_predicted_tokens). Includes consistency check between predicted_tokens and token_accuracy when both are enabled.
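The logit-value comparison mentioned for the fused test can be sketched as follows. Rather than requiring the predicted index to equal the reference argmax index, the check requires the reference logit at the predicted index to be close to the row maximum, so near-ties resolved differently are accepted. This is an illustration of the idea, not the PR's actual test code; the helper name and the tolerances are assumptions.

```python
import torch


def check_predictions_by_logit_value(
    predicted_tokens, ref_logits, target, ignore_index=-100, atol=1e-2, rtol=1e-2
):
    """Accept any predicted index whose reference logit is (nearly) the row max.

    Hypothetical helper; tolerances are illustrative, not taken from the PR.
    """
    valid = target != ignore_index
    # Ignored positions must carry the -1 sentinel.
    assert torch.all(predicted_tokens[~valid] == -1)
    rows = ref_logits[valid].float()
    # Reference logit at the index the kernel picked for each valid row.
    chosen = rows.gather(1, predicted_tokens[valid].unsqueeze(1)).squeeze(1)
    best = rows.max(dim=-1).values
    torch.testing.assert_close(chosen, best, atol=atol, rtol=rtol)


ref_logits = torch.tensor([[1.0, 3.0, 3.0],   # exact tie between indices 1 and 2
                           [0.5, 0.1, 0.2],
                           [2.0, 0.0, 1.0]])
target = torch.tensor([2, -100, 0])
# Either side of the tie in row 0 passes the check.
check_predictions_by_logit_value(torch.tensor([2, -1, 0]), ref_logits, target)
check_predictions_by_logit_value(torch.tensor([1, -1, 0]), ref_logits, target)
```

An index-equality check would reject one of the two tie-breaking choices above even though both are valid argmax results.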

…r cross-entropy kernels

Enable returning per-token argmax predictions (as int64 tensor) without materializing
full logits, reusing the existing argmax tracking shared with token_accuracy.
Propagate the new flag through ops, transformers wrappers, functional API, loss_utils,
all model forward functions, and output classes. Add tests for correctness, combined
flags (return_predicted_tokens + return_token_accuracy), and backward pass.

yukiu00 commented Feb 16, 2026

Contributor Author

@Tcc0403 Could you take a look at this PR when you have a chance? I'd appreciate your review and feedback on the proposal.

Tcc0403 commented Feb 17, 2026

Collaborator

Overall lgtm! @Mecoli1219 could you take a second look? Thank you

@Mecoli1219

Collaborator

Looks good to me! I’ve run both the functional and convergence tests on an H100, and everything passed in both transformers v4 & v5. Thanks for the contribution, @yukiu00!

One minor concern: this looks like a breaking change. If users are currently importing cross_entropy_forward directly from liger_kernel.ops.cross_entropy, their code will break in the next release. wdyt?

Tcc0403 commented Feb 19, 2026

Collaborator

> One minor concern: this looks like a breaking change. If users are currently importing cross_entropy_forward directly from liger_kernel.ops.cross_entropy, their code will break in the next release. wdyt?

It is certainly a breaking change for liger_kernel.ops.cross_entropy, but I believe most users only use the high-level APIs, which are the monkey-patch functions, nn.Modules, and autograd.Function. As long as we set default values on those interfaces so they behave the same, it should be fine in most use cases.

@Mecoli1219 Mecoli1219 added this pull request to the merge queue Feb 19, 2026
Merged via the queue into linkedin:main with commit 9c1ddb7 Feb 19, 2026
3 of 9 checks passed
@Mecoli1219

Collaborator

Sounds great. Let me merge it. Thanks for the contribution @yukiu00!

kashif commented Feb 20, 2026

Contributor

@yukiu00 do you think we should remove the token accuracy output as one can do that outside now?

yukiu00 commented Feb 20, 2026

Contributor Author

@kashif

I agree with you, but I also think we should keep it for backward compatibility. What do you think, @Tcc0403 ?

kashif commented Feb 20, 2026

Contributor

@yukiu00 @Tcc0403 I know TRL is using it, but if we change it and make a release, I can switch to the new API. It was an oversight on my part; I should have taken this approach rather than the actual metric approach...

Tcc0403 commented Feb 20, 2026

Collaborator

I think we can keep metrics computations in ce kernel to reduce launch overhead from performance perspective.
Nvm, we still need another kernel for accuracy computation. Yeah we can remove it as return_token_accuracy was defaulted to False.

4 participants