Add average_log_prob args for cpo #510

Merged
austin362667 merged 7 commits into linkedin:main from Mecoli1219:cpo-average-log-prob on Jan 8, 2025
Mecoli1219 commented Jan 3, 2025

Collaborator

Summary

The TRL CPO implementation does not average the log-probs, whereas Liger Kernel averages them when computing the loss. This causes a mismatch when integrating the two.
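
For illustration, the difference between the two conventions can be sketched in plain Python. This is a minimal sketch, not the Liger Kernel or TRL code: the helper `sequence_log_prob`, the mask convention, and the sample values are hypothetical, and only the `average_log_prob` flag name comes from this PR.

```python
def sequence_log_prob(token_log_probs, mask, average_log_prob):
    """Combine per-token log-probs into one sequence-level score.

    Hypothetical helper for illustration only.
    mask[i] is 1 for real tokens and 0 for padding.
    """
    kept = [lp for lp, m in zip(token_log_probs, mask) if m]
    total = sum(kept)
    if average_log_prob:
        # Length-normalized score: mean over the unmasked tokens.
        return total / max(len(kept), 1)
    # Unnormalized score: plain sum over the unmasked tokens.
    return total


# Made-up per-token log-probs; the last position is padding and is ignored.
token_log_probs = [-0.5, -1.5, -1.0, -9.0]
mask = [1, 1, 1, 0]

summed = sequence_log_prob(token_log_probs, mask, average_log_prob=False)
averaged = sequence_log_prob(token_log_probs, mask, average_log_prob=True)
print(summed, averaged)  # -3.0 -1.0
```

The two scores differ by a factor of the sequence length, so a loss built on one convention will not match a reference built on the other.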

Testing Done

  • Updated the unit test (still investigating why it fails locally)
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
Comment thread src/liger_kernel/chunked_loss/fused_linear_preference.py Outdated
kashif commented Jan 3, 2025

Contributor

TRL uses the same default as the official CPO repo: https://github.com/fe1ixxu/CPO_SIMPO/blob/main/scripts/cpo_trainer.py#L626

Mecoli1219 and others added 3 commits January 4, 2025 05:56
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
    "scalar, dtype, atol, rtol",
    [
-       (1.0, torch.bfloat16, 5e-3, 5e-3),
+       (1.0, torch.bfloat16, 5e-2, 5e-2),

austin362667 Jan 8, 2025

Contributor

What's the reasoning behind this adjustment?

Collaborator Author

@kashif and I found that after disabling average_log_prob for CPO, the results deviate more from the HF implementation when the model is large and the data type is bf16. Since the two implementations still agree closely, we increased atol and rtol to make this test pass.
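
To make the effect of the tolerance change concrete, here is a minimal sketch of the closeness criterion documented for `torch.allclose`, written in plain Python. The `reference` and `candidate` values are made up for illustration and are not measurements from this test.

```python
def is_close(actual, expected, atol, rtol):
    """Criterion documented for torch.allclose: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


reference = 1.0   # hypothetical reference value
candidate = 1.02  # hypothetical result that deviates by about 0.02

passes_tight = is_close(candidate, reference, atol=5e-3, rtol=5e-3)  # allowed error: 0.01
passes_loose = is_close(candidate, reference, atol=5e-2, rtol=5e-2)  # allowed error: 0.1
print(passes_tight, passes_loose)  # False True
```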

Contributor

Since bfloat16 is less precise for larger numbers, this is needed to make the test pass, and it matches the tolerances used in the other bfloat16 tests.
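
The precision point can be sketched without PyTorch by emulating bfloat16 rounding on float32 bit patterns. This is an illustrative sketch only: `to_bfloat16` is a hypothetical helper (it is not how the test or the kernel converts values), and the two sample inputs are chosen simply because they are exactly representable in float32.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper for illustration; NaN is not handled specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


small = 1.005859375  # magnitude typical of an averaged score (made-up value)
large = 257.5        # the same bit pattern scaled by 256 (made-up value)

small_err = abs(to_bfloat16(small) - small)
large_err = abs(to_bfloat16(large) - large)
print(small_err, large_err)  # 0.001953125 0.5
```

The relative error is the same in both cases, but the absolute error grows with the magnitude of the value, which is why a fixed atol becomes harder to meet once the log-probs are summed rather than averaged.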

Contributor

Then adjusting the tolerances makes sense. ❤️

austin362667 left a comment

Contributor

Thank you both for making this PR. Hopefully, it unblocks huggingface/trl#2506.

austin362667 merged commit 23e3772 into linkedin:main Jan 8, 2025

kashif commented Jan 8, 2025

Contributor

Awesome, thank you! We would still need a liger-kernel release for the CI to pass, but yes, it will hopefully unblock it!
