fix chosen_nll_loss in chunked losses #486
Merged
shivam15s
approved these changes
Dec 18, 2024
shivam15s
left a comment
Collaborator
Great catch, and thanks for the quick PR! With the TRL fixes we discussed today, we should be able to closely match the loss curve obtained from the TRL Trainer.
    compute_nll_loss=compute_nll_loss,
    is_encoder_decoder=is_encoder_decoder,
)
chosen_nll_loss = (
Collaborator
I believe we also have to fix how we do normalization. My hunch is that's the reason for the failing tests.
Comment on lines +298 to +304
if not is_encoder_decoder:
    shifted_logits = log_probs_chunk[:len_chosen_chunk, :-1].contiguous()
    shifted_target = target_chunk[:len_chosen_chunk, 1:].contiguous()
else:
    shifted_logits = log_probs_chunk[:len_chosen_chunk].contiguous()
    shifted_target = target_chunk[:len_chosen_chunk].contiguous()
Collaborator
The shifted logits/target should also be used to calculate chosen/rejected logps and in general for everything that follows to compute the loss.
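To illustrate the point, here is a minimal sketch of shifting once and reusing the shifted values for both the NLL term and the sequence log-probability. It is not the code from this PR: it uses plain Python lists for a single sequence so it runs without PyTorch, and the names `log_softmax`, `shift_for_decoder`, `sequence_stats`, and the `IGNORE_INDEX` value of -100 are assumptions made for the example.

```python
import math

IGNORE_INDEX = -100  # assumed padding label, chosen for this example


def log_softmax(row):
    """Numerically stable log-softmax over one row of raw logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def shift_for_decoder(log_probs, targets, is_encoder_decoder=False):
    """For decoder-only models, pair position t's prediction with label t + 1."""
    if is_encoder_decoder:
        return log_probs, targets
    return log_probs[:-1], targets[1:]


def sequence_stats(logits, targets, is_encoder_decoder=False):
    """Return (summed NLL, summed log-prob, token count) for one sequence.

    Both quantities are computed from the same shifted pairs, so they
    cannot disagree about alignment.
    """
    log_probs = [log_softmax(row) for row in logits]
    shifted_lp, shifted_tgt = shift_for_decoder(log_probs, targets, is_encoder_decoder)
    picked = [lp[t] for lp, t in zip(shifted_lp, shifted_tgt) if t != IGNORE_INDEX]
    logp = sum(picked)
    return -logp, logp, len(picked)


# Each position strongly predicts the *next* token of `targets`.
logits = [[0.0, 0.0, 5.0], [0.0, 5.0, 0.0], [5.0, 0.0, 0.0]]
targets = [0, 2, 1]

shifted_nll, shifted_logp, n_tokens = sequence_stats(logits, targets)
unshifted_nll, _, _ = sequence_stats(logits, targets, is_encoder_decoder=True)
print(shifted_nll < unshifted_nll, n_tokens)
```

In this toy input the shifted NLL is small because each position's prediction is scored against the following token, while scoring the same logits without the shift compares each position against its own label and yields a much larger loss.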
shivam15s
added a commit
that referenced
this pull request
Dec 19, 2024
This reverts commit 61eefe9.
shivam15s
added a commit
that referenced
this pull request
Dec 19, 2024
This reverts commit 61eefe9.
Summary
Fix the NLL loss in the chunked losses when the model is a decoder-only model by shifting the logits and targets.
- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence