fix chosen_nll_loss in chunked losses #486

Merged: shivam15s merged 9 commits into linkedin:main from kashif:fix-orpo-nll on Dec 18, 2024

Conversation

@kashif (Contributor) commented Dec 17, 2024

Summary

Fix the NLL loss in the chunked losses when the model is decoder-only by shifting the logits and targets, so that the logits at position t are scored against the token at position t + 1.
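
To illustrate the shift described above, here is a minimal pure-Python sketch, not the code from this PR. The helper names `log_softmax` and `causal_nll`, the toy logits, and the `-100` ignore label are assumptions made for illustration; the actual change operates on PyTorch tensors.

```python
import math

IGNORE_INDEX = -100  # assumed padding label, following the common PyTorch convention


def log_softmax(row):
    """Numerically stable log-softmax over one row of raw logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def causal_nll(logits, targets, is_encoder_decoder=False):
    """Return (summed NLL, number of scored tokens) for one sequence.

    logits:  list of T rows, each a list of V raw scores
    targets: list of T token ids, aligned with the input positions
    """
    if not is_encoder_decoder:
        # Decoder-only: position t predicts token t + 1, so drop the last
        # logit row and the first target before scoring.
        logits, targets = logits[:-1], targets[1:]
    total, count = 0.0, 0
    for row, tgt in zip(logits, targets):
        if tgt == IGNORE_INDEX:
            continue  # masked positions do not contribute to the loss
        total -= log_softmax(row)[tgt]
        count += 1
    return total, count


# Toy sequence [0, 1, 2]: each position confidently predicts the NEXT token.
toy_logits = [[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [10.0, 0.0, 0.0]]
toy_targets = [0, 1, 2]

shifted_total, shifted_count = causal_nll(toy_logits, toy_targets)
# Scoring without the shift compares each position against its own token.
unshifted_total, unshifted_count = causal_nll(
    toy_logits, toy_targets, is_encoder_decoder=True
)
```

On this toy input the shifted loss is close to zero while the unshifted one is large, which is the kind of mismatch the shift is meant to remove for decoder-only models.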

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@kashif changed the title from "fix chosen_nll_loss in chunked loses" to "fix chosen_nll_loss in chunked losses" Dec 18, 2024

@shivam15s (Collaborator) left a comment

Great catch, and thanks for the quick PR! With the TRL fixes we discussed today, we should be able to closely match the loss curve obtained from the TRL trainer.

compute_nll_loss=compute_nll_loss,
is_encoder_decoder=is_encoder_decoder,
)
chosen_nll_loss = (

I believe we also have to fix how we do normalization. My hunch is that this is the reason for the failing tests.

Comment on lines +298 to +304
if not is_encoder_decoder:
    shifted_logits = log_probs_chunk[:len_chosen_chunk, :-1].contiguous()
    shifted_target = target_chunk[:len_chosen_chunk, 1:].contiguous()
else:
    shifted_logits = log_probs_chunk[:len_chosen_chunk].contiguous()
    shifted_target = target_chunk[:len_chosen_chunk].contiguous()

The shifted logits/target should also be used to calculate chosen/rejected logps and in general for everything that follows to compute the loss.
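
A rough sketch of what this suggestion could look like, in plain Python rather than the tensor code under review. The helpers `sequence_logp` and `chunk_logps`, the toy probabilities, the `-100` ignore label, and the choice of a length-averaged log-probability are all assumptions for illustration; only the chosen-first layout and the `len_chosen_chunk` split come from the diff above.

```python
import math

IGNORE_INDEX = -100  # assumed label for prompt/padding positions


def sequence_logp(log_probs, targets, is_encoder_decoder=False):
    """Average log-probability of the unmasked target tokens of one sequence."""
    if not is_encoder_decoder:
        # Same shift as for the NLL term: position t scores token t + 1.
        log_probs, targets = log_probs[:-1], targets[1:]
    picked = [row[t] for row, t in zip(log_probs, targets) if t != IGNORE_INDEX]
    return sum(picked) / len(picked) if picked else 0.0


def chunk_logps(log_probs_chunk, target_chunk, len_chosen_chunk,
                is_encoder_decoder=False):
    """Split a chosen-then-rejected chunk into two lists of sequence log-probs."""
    logps = [
        sequence_logp(lp, tg, is_encoder_decoder)
        for lp, tg in zip(log_probs_chunk, target_chunk)
    ]
    return logps[:len_chosen_chunk], logps[len_chosen_chunk:]


log = math.log
# One chosen and one rejected sequence, 3 positions, vocabulary of 2 tokens.
log_probs_chunk = [
    [[log(0.2), log(0.8)], [log(0.5), log(0.5)], [log(0.5), log(0.5)]],  # chosen
    [[log(0.1), log(0.9)], [log(0.5), log(0.5)], [log(0.5), log(0.5)]],  # rejected
]
target_chunk = [
    [IGNORE_INDEX, 1, 0],             # chosen: prompt token masked
    [IGNORE_INDEX, 0, IGNORE_INDEX],  # rejected: prompt and padding masked
]

chosen_logps, rejected_logps = chunk_logps(
    log_probs_chunk, target_chunk, len_chosen_chunk=1
)
```

Because the shift happens once inside `sequence_logp`, the chosen and rejected log-probs are computed on the same alignment as the NLL term, which is the consistency the comment asks for.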

@shivam15s shivam15s merged commit 61eefe9 into linkedin:main Dec 18, 2024
shivam15s added a commit that referenced this pull request Dec 19, 2024
shivam15s added a commit that referenced this pull request Dec 19, 2024
This reverts commit 61eefe9.

2 participants