Add return_predicted_tokens support for cross-entropy kernels#1091
Conversation
…r cross-entropy kernels Enable returning per-token argmax predictions (as int64 tensor) without materializing full logits, reusing the existing argmax tracking shared with token_accuracy. Propagate the new flag through ops, transformers wrappers, functional API, loss_utils, all model forward functions, and output classes. Add tests for correctness, combined flags (return_predicted_tokens + return_token_accuracy), and backward pass.
|
@Tcc0403 Could you take a look at this PR when you have a chance? I'd appreciate your review and feedback on the proposal. |
|
Overall lgtm! @Mecoli1219 could you take a second look? Thank you |
|
Looks good to me! I’ve run both the functional and convergence tests on an H100, and everything passed in both transformers v4 & v5. Thanks for the contribution, @yukiu00! One minor concern: this looks like a breaking change. If users are currently importing |
It is certainly a breaking change on |
|
Sounds great. Let me merge it. Thanks for the contribution @yukiu00! |
|
@yukiu00 do you think we should remove the token accuracy output as one can do that outside now? |
|
|
Summary
Add a
return_predicted_tokensflag toLigerCrossEntropyLossandLigerFusedLinearCrossEntropyLossthat returns per-token argmax predictions (asint64tensor) without materializing full logits.Motivation
During training, it is often useful to access the model's predicted tokens (argmax of logits) for logging, visualization, and metric computation — for example, inspecting what the model actually predicts at each position, or tracking prediction distributions over time.
Currently, obtaining predicted tokens requires either:
.argmax(dim=-1), which defeats the memory savings ofFusedLinearCrossEntropy, orSince the cross-entropy kernel already tracks
argmaxinternally (forreturn_token_accuracy, introduced in #910), we can return the predicted token indices as a byproduct at near-zero additional cost.Design
This builds on the
return_token_accuracyinfrastructure (#910). The existingargmax_idxtracking in the Triton kernel is reused, so:return_predicted_tokens=False(default), there is zero overhead — theRETURN_PREDICTED_TOKENSconstexpr is compiled out.return_token_accuracyandreturn_predicted_tokensare enabled, the argmax computation is shared (no duplicate work).ignore_index) return-1as a sentinel value.Changes
ops/cross_entropy.py,ops/fused_linear_cross_entropy.py: AddRETURN_PREDICTED_TOKENSconstexpr to the Triton kernel; storeargmax_idxfor non-ignored tokens,-1for ignored tokens.transformers/cross_entropy.py,transformers/fused_linear_cross_entropy.py,transformers/functional.py: Propagatereturn_predicted_tokensthrough module and functional APIs. ReturnCrossEntropyOutputwhen any extra output is requested.transformers/model/loss_utils.py: Threadreturn_predicted_tokensthroughLigerForCausalLMLoss→fixed_fused_linear_cross_entropy.transformers/model/output_classes.py: Addpredicted_tokensfield to allLiger*CausalLMOutputWithPastdataclasses.transformers/model/*.py(32 model files): Unpack and forwardpredicted_tokensin both tuple and dict return paths, following the same pattern astoken_accuracy.Usage
Testing Done
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergenceNew/updated tests:
test_correctness_with_predicted_tokens(cross-entropy): Verifies predicted tokens match reference argmax, ignored tokens are-1, backward works. Tests multiple dtypes, shapes, and ignore indices.test_correctness_with_predicted_tokens(fused linear cross-entropy): Same coverage with logit-value comparison (handles chunked bfloat16 matmul tie-breaking).test_liger_cross_entropy_structured_output: Extended to parametrizereturn_predicted_tokensacross all 8 combinations of(return_z_loss, return_token_accuracy, return_predicted_tokens). Includes consistency check betweenpredicted_tokensandtoken_accuracywhen both are enabled.