fix gemma3 grad acc #37208

Merged
ArthurZucker merged 15 commits into main from fix-gemma3-grad-acc on Jun 25, 2025

Conversation

@SunMarc SunMarc commented Apr 2, 2025

What does this PR do?

This PR fixes the gradient accumulation issue with the gemma3 model. The issue was that the model forward takes **kwargs, so we assumed it was passing them on as **loss_kwargs (including num_items_in_batch) to calculate the loss, which it does not. @ArthurZucker, I'm not sure what the best general fix is, as this will probably happen again. Maybe set accepts_loss_kwargs to False by default and set it to True for the models we have fixed? I'm also fine with just setting it to False for models that don't use the kwargs for the loss.
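A minimal sketch of the failure mode described above, assuming the capability check relies on whether `forward` declares `**kwargs`. The classes `UsesLossKwargs` and `IgnoresLossKwargs` and the helpers `forward_accepts_var_kwargs` and `supports_loss_kwargs` are hypothetical stand-ins for illustration, not the Trainer's actual code; only the attribute name `accepts_loss_kwargs` comes from this PR.

```python
import inspect


def forward_accepts_var_kwargs(model) -> bool:
    """Return True if model.forward declares a **kwargs parameter."""
    params = inspect.signature(model.forward).parameters.values()
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


class UsesLossKwargs:
    """Hypothetical model that really uses num_items_in_batch in its loss."""

    def forward(self, losses, **loss_kwargs):
        n = loss_kwargs.get("num_items_in_batch")
        return sum(losses) / (n if n is not None else len(losses))


class IgnoresLossKwargs:
    """Hypothetical model that takes **kwargs but never uses them for the loss."""

    accepts_loss_kwargs = False  # explicit opt-out, as proposed in this PR

    def forward(self, losses, **kwargs):
        return sum(losses) / len(losses)


def supports_loss_kwargs(model) -> bool:
    """Prefer the explicit flag; fall back to the signature heuristic."""
    explicit = getattr(model, "accepts_loss_kwargs", None)
    if explicit is not None:
        return explicit
    return forward_accepts_var_kwargs(model)
```

The signature heuristic alone reports `True` for both classes, which is the wrong assumption this PR describes; the explicit flag lets a model that ignores the kwargs opt out.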

As for why I didn't change the loss function: the code filters the logits/labels, so I decided to simply not use num_items_in_batch to calculate the loss. Otherwise, the loss would be calculated incorrectly in one of the cases.
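For background, here is a small worked example of what num_items_in_batch is for in general. It is not what this PR does for gemma3. The per-token loss values are made up, and the arithmetic only assumes that the loss is a mean over label tokens.

```python
def mean(values):
    return sum(values) / len(values)


# Made-up per-token losses for two samples with different token counts.
sample_a = [1.0, 1.0, 1.0, 1.0]  # 4 label tokens
sample_b = [3.0, 3.0]            # 2 label tokens

# Batch size 2, gradient accumulation 1: one mean over all 6 tokens.
full_batch_loss = mean(sample_a + sample_b)

# Batch size 1, gradient accumulation 2, naive: average of per-step means.
naive_accum_loss = mean([mean(sample_a), mean(sample_b)])

# Same setup, but each step sums its token losses and divides by the
# total token count across the accumulated steps (num_items_in_batch).
num_items_in_batch = len(sample_a) + len(sample_b)
scaled_accum_loss = (
    sum(sample_a) / num_items_in_batch + sum(sample_b) / num_items_in_batch
)

print(full_batch_loss, naive_accum_loss, scaled_accum_loss)
```

The naive average weights the short sample as heavily as the long one, so it diverges from the full-batch loss. Dividing by the total token count restores the match, but only if the model's loss actually applies that count to the same tokens it keeps after filtering.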

I also fixed a PEFT-related issue: we couldn't access that attribute when the model was wrapped as a PEFT model.
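A rough sketch of the unwrapping idea, using hypothetical stand-in classes instead of the real PEFT and Transformers ones. `BaseModel`, `Wrapper`, and `read_flag` are invented for illustration, and the `hasattr` check is a simplification of the `_is_peft_model` helper quoted later in this thread. Only `get_base_model` and `accepts_loss_kwargs` are names taken from the discussion.

```python
class BaseModel:
    """Hypothetical underlying model that carries the flag."""

    accepts_loss_kwargs = False

    def forward(self, x):
        return x


class Wrapper:
    """Hypothetical adapter wrapper that does not expose the base model's attributes."""

    def __init__(self, base):
        self._base = base

    def get_base_model(self):
        return self._base

    def forward(self, *args, **kwargs):
        return self._base.forward(*args, **kwargs)


def read_flag(model, name="accepts_loss_kwargs", default=None):
    """Read an attribute from the base model when the model is wrapped."""
    # Simplification: the real code checks for a PEFT model explicitly.
    target = model.get_base_model() if hasattr(model, "get_base_model") else model
    return getattr(target, name, default)
```

Reading the flag directly from the wrapper finds nothing, while going through `get_base_model()` returns the value set on the underlying model.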

To reproduce

winglian script
https://gist.github.com/winglian/569924fe154824c8ce148f6e185cd4cd

After fix

[Screenshot 2025-04-02: grad acc 2 / bs 1 compared with grad acc 1 / bs 2]

Fixes #37197

cc @winglian

@github-actions github-actions bot marked this pull request as draft April 2, 2025 13:41
@SunMarc SunMarc marked this pull request as ready for review April 2, 2025 13:42
@SunMarc SunMarc requested a review from muellerzr April 2, 2025 14:43
@muellerzr muellerzr left a comment

Nice job! Looks good to me, is get_base_model an old enough api we don’t need to worry about breaking stuff?

SunMarc commented Apr 7, 2025

> Nice job! Looks good to me, is get_base_model an old enough api we don’t need to worry about breaking stuff?

It should be. We are already using it here:

model_forward = (
    unwrapped_model.forward
    if not _is_peft_model(unwrapped_model)
    else unwrapped_model.get_base_model().forward
)

SunMarc commented Apr 7, 2025

cc @ArthurZucker gentle ping

@SunMarc SunMarc requested review from ArthurZucker and removed request for ArthurZucker April 10, 2025 09:13
SunMarc commented May 6, 2025

ping @ArthurZucker

@ArthurZucker ArthurZucker left a comment

Thanks, let's make it go green and merge.

@ArthurZucker

@bot /style

@ArthurZucker ArthurZucker left a comment

Oops, sorry.

@ArthurZucker

@bot /style

@github-actions

Style fixes have been applied. View the workflow run here.

@ArthurZucker ArthurZucker merged commit 3c322c9 into main Jun 25, 2025
21 checks passed
@ArthurZucker ArthurZucker deleted the fix-gemma3-grad-acc branch June 25, 2025 14:28
Development

Successfully merging this pull request may close these issues.

Gemma3 Gradient Accumulation loss

4 participants