[core/ gradient_checkpointing] Refactor GC - part 2#27073

Merged
younesbelkada merged 12 commits into huggingface:main from younesbelkada:finalize-gc
Oct 27, 2023

Conversation

younesbelkada (Contributor) commented Oct 25, 2023

What does this PR do?

Extends #27020 by further simplifying the GC enable / disable mechanism. We can simply iterate over all submodules of the PreTrainedModel and check for the attribute gradient_checkpointing.
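
The iteration described above can be sketched roughly as follows. This is a torch-free mock with illustrative names, not the actual transformers implementation; the real logic lives on PreTrainedModel and walks `self.modules()` from torch.nn.Module:

```python
class MockModule:
    """Minimal stand-in for torch.nn.Module (illustration only)."""

    def __init__(self, *children):
        self.children = list(children)

    def modules(self):
        # Yield self and all submodules, like nn.Module.modules()
        yield self
        for child in self.children:
            yield from child.modules()


class MockLayer(MockModule):
    # Layers that support GC expose a `gradient_checkpointing` flag
    gradient_checkpointing = False


def set_gradient_checkpointing(model, enable, gradient_checkpointing_func):
    """Toggle GC on every submodule that declares support for it."""
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = enable
            module._gradient_checkpointing_func = gradient_checkpointing_func


model = MockModule(MockLayer(), MockModule(MockLayer()))
set_gradient_checkpointing(model, True, lambda fn, *args: fn(*args))
```

Because the check is attribute-based rather than driven by per-model overrides, any submodule that declares `gradient_checkpointing` is picked up automatically.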

Some models had supports_gradient_checkpointing attribute set to True whereas they actually don't. So this PR fixes that as well.

Some models were also calling torch.utils.checkpoint.checkpoint instead of self.gradient_checkpointing_func; this PR fixes that as well.

Also, gradient_checkpointing_func is now private, to avoid exposing it as a public attribute.

cc @ArthurZucker

Comment on lines +1388 to 1391
self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()
younesbelkada (Contributor Author):

Moving this before post_init() because post_init() calls gradient_checkpointing_enable() if the config has a gradient_checkpointing attribute.
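
A torch-free sketch of why the order matters (all names are illustrative): post_init() may itself enable checkpointing based on a legacy config attribute, so the flag has to exist before the call, and must not be assigned after it.

```python
class MockConfig:
    gradient_checkpointing = True  # legacy configs could carry this flag


class MockModel:
    def __init__(self, config):
        self.config = config
        self.gradient_checkpointing = False  # must exist before post_init()
        self.post_init()
        # If the assignment above ran *after* post_init(), the flag that
        # post_init() enabled from the config would be silently reset.

    def post_init(self):
        if getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True


model = MockModel(MockConfig())
```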


  if self.gradient_checkpointing and self.training:
-     layer_outputs = checkpoint(
+     layer_outputs = self.gradient_checkpointing_func(
younesbelkada (Contributor Author):

Here, some models were calling torch.utils.checkpoint directly instead of self.gradient_checkpointing_func, so I fixed it.
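
The point of routing through the stored function rather than calling torch.utils.checkpoint.checkpoint directly is that the function becomes swappable (for example to pass different checkpointing options). A torch-free sketch, with illustrative names:

```python
def checkpointed_call(fn, *args):
    """Stand-in for torch.utils.checkpoint.checkpoint (illustration only)."""
    return fn(*args)


class MockDecoderLayer:
    def __init__(self):
        self.gradient_checkpointing = True
        self.training = True
        # Set by the enable mechanism; callers can swap in another function.
        self.gradient_checkpointing_func = checkpointed_call

    def block(self, hidden_states):
        return hidden_states * 2

    def forward(self, hidden_states):
        if self.gradient_checkpointing and self.training:
            # After this PR: use the configurable function instead of
            # calling the checkpoint util directly.
            return self.gradient_checkpointing_func(self.block, hidden_states)
        return self.block(hidden_states)


out = MockDecoderLayer().forward(3)
```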


  if self.gradient_checkpointing and self.training:
-     layer_outputs = checkpoint(
+     layer_outputs = self.gradient_checkpointing_func(
younesbelkada (Contributor Author):

same here

HuggingFaceDocBuilderDev commented Oct 25, 2023

The documentation is not available anymore as the PR was closed or merged.

younesbelkada (Contributor Author) commented:

@ArthurZucker @LysandreJik - as discussed offline, this PR now restores the previous behaviour (i.e. if a user sets module.gradient_checkpointing = True in a module that supports it, everything should work fine) + I have set gradient_checkpointing_func as a private attribute. This PR is ready for review.

ArthurZucker (Collaborator) left a comment:


Thanks a lot. Maybe if "checkpointing_function" were an attribute it would be more accessible and would allow us to document it. WDYT?

# Apply it on the top-level module in case the top-level module supports it;
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
self._gradient_checkpointing_func = gradient_checkpointing_func
ArthurZucker (Collaborator):

Suggested change:
- self._gradient_checkpointing_func = gradient_checkpointing_func
+ self._checkpoint = gradient_checkpointing_func

This name would make it clearer to users that it's basically just torch.utils.checkpoint.

younesbelkada (Contributor Author):

Hmmm, what I like about _gradient_checkpointing_func is that it tells users that it is a function; also, _checkpoint seems a bit ambiguous to me (it can sound like a model checkpoint?).

younesbelkada and others added 5 commits October 27, 2023 15:36
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@younesbelkada younesbelkada merged commit ffff9e7 into huggingface:main Oct 27, 2023
@younesbelkada younesbelkada deleted the finalize-gc branch October 27, 2023 14:15
jambayk added a commit to microsoft/Olive that referenced this pull request Nov 2, 2023
## Describe your changes
The latest version of transformers (>= 4.35.0) is not compatible with the model. PRs huggingface/transformers#27020 and
huggingface/transformers#27073 change the expected signature of `_set_gradient_checkpointing`, which no longer matches the model's:
https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py#L802
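
The mismatch the Olive authors describe is roughly the following (a hedged sketch with mock classes; check the installed transformers version for the exact definitions). The old hook was called once per submodule as `_set_gradient_checkpointing(module, value=False)`, while after these PRs the model-level method takes an enable flag plus the checkpointing function:

```python
class _MockLayer:
    gradient_checkpointing = False


class _MockModel:
    def __init__(self):
        self.layer = _MockLayer()

    def modules(self):
        # Mimic nn.Module.modules(): self first, then submodules
        yield self
        yield self.layer

    # Post-#27073 shape (roughly): an enable flag and the checkpointing
    # function, instead of the old per-module (module, value=False) pair.
    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = enable
                module._gradient_checkpointing_func = gradient_checkpointing_func


model = _MockModel()
model._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=lambda fn, *a: fn(*a))
```

Remote code that still overrides the old per-module signature therefore fails when the new transformers calls it with the new keyword arguments.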

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
…27073)

* fix

* more fixes

* fix other models

* fix long t5

* use `gradient_checkpointing_func` instead

* fix copies

* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* replace it with `is_gradient_checkpointing_set`

* remove default

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>