Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
muellerzr
left a comment
Nice job! Looks good to me. Is `get_base_model` an old enough API that we don't need to worry about breaking anything?
Should be, we have already been using it for a while, e.g. here:

```python
model_forward = (
    unwrapped_model.forward
    if not _is_peft_model(unwrapped_model)
    else unwrapped_model.get_base_model().forward
)
```
cc @ArthurZucker gentle ping

ping @ArthurZucker
ArthurZucker
left a comment
Thanks let's make it go green and merge
@bot /style

Style fixes have been applied. View the workflow run here.
What does this PR do?
This PR fixes the gradient accumulation issue with the Gemma3 model. The issue was that the model's forward accepts `**kwargs`, so we assumed it was receiving `**loss_kwargs` (i.e. `num_items_in_batch`) to calculate the loss. Not sure what the best way to fix this is in general @ArthurZucker, as this will probably happen again. Maybe set `accepts_loss_kwargs` to `False` in general and set it to `True` for models that we fixed? I'm also fine with just setting it to `False` for models that don't use the `kwargs` for the loss.

As for why I didn't keep the loss function: in the code, they are filtering the logits/labels, so I decided to simply not use `num_items_in_batch` to calculate the loss. Otherwise, the loss won't be correctly calculated for one of the cases.

I also fixed an issue related to PEFT: we couldn't access that attribute because the model was a PEFT model.
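To illustrate why the bare `**kwargs` caused the wrong assumption, here is a hedged sketch of the signature heuristic. The function names (`gemma3_like_forward`, `looks_like_it_accepts_loss_kwargs`) are hypothetical; the real check is inside the `transformers` Trainer, but the false positive works the same way:

```python
import inspect


def gemma3_like_forward(input_ids, labels=None, **kwargs):
    """Stand-in for a forward that takes **kwargs but never actually
    reads num_items_in_batch from them (the Gemma3 situation)."""
    return input_ids


def plain_forward(input_ids, labels=None):
    """Stand-in for a forward with an explicit signature and no kwargs."""
    return input_ids


def looks_like_it_accepts_loss_kwargs(forward):
    # Sketch of the heuristic: a bare **kwargs parameter makes the model
    # look like it accepts loss kwargs (num_items_in_batch) even when the
    # body ignores them -- exactly the false positive described above.
    params = inspect.signature(forward).parameters.values()
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD or p.name == "num_items_in_batch"
        for p in params
    )
```

A signature-only check cannot tell these two cases apart, which is why an explicit per-model flag like `accepts_loss_kwargs` is being discussed.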
To reproduce
winglian script
https://gist.github.com/winglian/569924fe154824c8ce148f6e185cd4cd
After fix
grad acc 2 bs 1 and grad acc 1 bs 2

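As a toy illustration (not the Trainer code) of why the two configurations should match, and why `num_items_in_batch` matters: averaging per-micro-batch mean losses agrees with the full-batch loss only when micro-batches are equal-sized, while summing token losses and dividing by the global item count stays correct either way. All numbers here are made up for the demonstration.

```python
# Per-token losses for 4 tokens; values are arbitrary.
losses = [2.0, 4.0, 6.0, 8.0]

# bs=2, grad acc=1: one pass over all 4 tokens, plain mean.
full_batch = sum(losses) / len(losses)

# bs=1, grad acc=2 with equal-sized micro-batches: the naive
# mean-of-micro-batch-means happens to agree.
micro_means = [sum(losses[:2]) / 2, sum(losses[2:]) / 2]
naive = sum(micro_means) / len(micro_means)

# With unequal micro-batches (e.g. variable sequence lengths) the naive
# mean-of-means diverges, while sum / num_items_in_batch stays correct.
uneven = [losses[:1], losses[1:]]
naive_uneven = sum(sum(m) / len(m) for m in uneven) / len(uneven)
correct_uneven = sum(sum(m) for m in uneven) / len(losses)
```

Here `full_batch`, `naive`, and `correct_uneven` all come out to the same value, while `naive_uneven` does not, which is the mismatch the plots above compare.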
Fixes #37197
cc @winglian