Add average_log_prob args for cpo#510
Conversation
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
|
TRL is using the default as in the official repo for CPO: https://github.com/fe1ixxu/CPO_SIMPO/blob/main/scripts/cpo_trainer.py#L626 |
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Signed-off-by: Mecoli1219 <michaellai901026@gmail.com>
| "scalar, dtype, atol, rtol", | ||
| [ | ||
| (1.0, torch.bfloat16, 5e-3, 5e-3), | ||
| (1.0, torch.bfloat16, 5e-2, 5e-2), |
There was a problem hiding this comment.
What's the reasoning behind this adjustment?
There was a problem hiding this comment.
@kashif and I find that after disabling average_log_prob for CPO, it will have a higher deviation from HF implementation when the model is large and the data type is bf16. Since the result is still close within both methods, we increase atol and rtol to make this test pass.
There was a problem hiding this comment.
as bfloat16 is less accurate for larger numbers, this is needed to make the test pass and is the same as in the other bfloat16 tests
There was a problem hiding this comment.
Then adjusting tol makes sense. ❤️
austin362667
left a comment
There was a problem hiding this comment.
Thank you both for making this PR. Hopefully, it unblocks huggingface/trl#2506.
|
awesome thank you! we would still need a release of liger-kernel for the CI to pass but yes it will hopefully unblock! |
Summary
trlCPO implementation didn't average the log_probs, while the liger kernel averages it when computing the loss. This will cause a mismatch when integrating them.Testing Done
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence