KTO changes to return aux outputs #589

Merged: vaibhavjindal merged 11 commits into linkedin:main from vaibhavjindal:kto_trl on Mar 1, 2025
@vaibhavjindal commented Feb 26, 2025
Collaborator

Summary

This PR introduces the following changes to enable integration with Hugging Face TRL:

  1. KTO loss can now return the following along with the loss:
     • chosen_logps
     • rejected_logps
     • sum(chosen_logits)
     • sum(rejected_logits)
     • chosen_rewards
     • rejected_rewards
  2. Adds an option to enable/disable log_probs averaging while calculating the loss.
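
To make the averaging option concrete, here is a minimal, dependency-free sketch of how per-token log-probs could be reduced to one value per sequence, either summed or averaged over the tokens that count toward the loss. The function name `sequence_logp` and the flag name `average_log_prob` are assumptions for illustration only; this is not the code from this PR, which operates on tensors.

```python
def sequence_logp(token_logps, mask, average_log_prob=False):
    """Reduce per-token log-probs to a single value for one sequence.

    token_logps: log-probability of each target token
    mask: 1 for tokens that count toward the loss, 0 for ignored ones
    average_log_prob: hypothetical flag name; when True, divide the
        masked sum by the number of counted tokens
    """
    total = sum(lp * m for lp, m in zip(token_logps, mask))
    if average_log_prob:
        count = sum(mask)
        # Guard against a fully masked sequence.
        return total / count if count else 0.0
    return total


token_logps = [-0.5, -1.5, -2.0, -9.0]  # last position is padding
mask = [1, 1, 1, 0]

summed = sequence_logp(token_logps, mask)
averaged = sequence_logp(token_logps, mask, average_log_prob=True)
```

With averaging disabled, longer sequences accumulate larger-magnitude log-probs; with it enabled, the value is normalized by sequence length, which changes the scale of any quantity derived from it.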

Details

Benchmark results with the new implementation:
[Figure: kto_loss_memory]
[Figure: kto_loss_speed]

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@vaibhavjindal

Collaborator Author

@shivam15s @hebiao064 @kashif, this PR contains the changes needed for TRL integration. PTAL, thanks!

@hebiao064

Contributor

Can you run a benchmark to see if the memory/speed still performs at certain level?

@vaibhavjindal

Collaborator Author

> Can you run a benchmark to see if the memory/speed still performs at certain level?

Sure, will add it.

Comment on lines +342 to +343
chosen_logits_sum = chosen_logits.nansum()
rejected_logits_sum = rejected_logits.nansum()

Collaborator

Do you expect these to have NaNs?

Collaborator Author

I was a bit unsure about this, as TRL uses nansum() here: https://github.com/huggingface/trl/blob/491921c1a4167e7c84429382470b0bb3158e66b0/trl/trainer/kto_trainer.py#L1271. I therefore kept nansum() to handle the worst case.
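
For readers following the trade-off: an ordinary sum propagates any NaN to the result, while a NaN-aware sum treats NaN entries as zero. The sketch below shows that difference in plain Python. The `nansum` helper is a hypothetical stand-in written for this illustration, not the tensor method used in the diff.

```python
import math


def nansum(values):
    """Sum values while skipping NaN entries (hypothetical stand-in)."""
    return math.fsum(v for v in values if not math.isnan(v))


logits = [0.25, float("nan"), 1.5]

plain = sum(logits)    # a single NaN makes the whole sum NaN
safe = nansum(logits)  # the NaN entry contributes nothing
```

The cost of the NaN-aware version is that it can hide a genuine numerical problem upstream, which is the concern behind the reviewer's question.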

@shivam15s left a comment

Collaborator

LGTM, let's figure out the speed drop in the next PR.

@vaibhavjindal merged commit d63b888 into linkedin:main on Mar 1, 2025