[core / gradient_checkpointing] Refactor GC - part 2 #27073

younesbelkada merged 12 commits into huggingface:main

Conversation
```python
self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()
```
Moving this before `post_init()`, because `post_init()` calls `gradient_checkpointing_enable()` if the config has a `gradient_checkpointing` attribute.
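As a torch-free illustration of this ordering constraint (all class and attribute names below are invented stand-ins, not the actual transformers code):

```python
# Minimal sketch of why `gradient_checkpointing` must be initialized before
# post_init(): post_init() may call gradient_checkpointing_enable() when the
# config carries a `gradient_checkpointing` flag, and that call flips the
# attribute. TinyConfig / TinyModel are hypothetical stand-ins.
class TinyConfig:
    def __init__(self, gradient_checkpointing=False):
        self.gradient_checkpointing = gradient_checkpointing


class TinyModel:
    def __init__(self, config):
        self.config = config
        # must exist before post_init() touches it
        self.gradient_checkpointing = False
        self.post_init()

    def post_init(self):
        # mirrors the behaviour described in the comment above
        if getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True
```

With this ordering, `TinyModel(TinyConfig(gradient_checkpointing=True)).gradient_checkpointing` ends up `True`; initializing the flag after `post_init()` would silently overwrite it back to `False`.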
```diff
  if self.gradient_checkpointing and self.training:
-     layer_outputs = checkpoint(
+     layer_outputs = self.gradient_checkpointing_func(
```
Some models were calling `torch.utils.checkpoint.checkpoint` directly instead of `self.gradient_checkpointing_func`, so I fixed that here.
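A sketch of the fixed call pattern (in transformers the installed function would wrap `torch.utils.checkpoint.checkpoint`; here an identity wrapper stands in so the example runs without torch, and `TinyLayer` is an invented name):

```python
# The point of the fix: the layer calls whatever function the model installed
# (`self.gradient_checkpointing_func`) instead of hard-coding `checkpoint(...)`,
# so enable/disable and checkpointing kwargs are controlled in one place.
class TinyLayer:
    def __init__(self):
        self.gradient_checkpointing = True
        self.training = True
        # installed by the model; identity wrapper for demonstration only
        self.gradient_checkpointing_func = lambda fn, *args: fn(*args)

    def block(self, hidden_states):
        return hidden_states * 2

    def forward(self, hidden_states):
        if self.gradient_checkpointing and self.training:
            # was: layer_outputs = checkpoint(self.block, hidden_states)
            return self.gradient_checkpointing_func(self.block, hidden_states)
        return self.block(hidden_states)
```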
… previous behaviour

@ArthurZucker @LysandreJik - as discussed offline, this PR now reverts to the previous behaviour (i.e. if a user sets …)
ArthurZucker left a comment:
Thanks a lot. Maybe if the checkpointing function were an attribute it would be more accessible and would allow us to document it, WDYT?
```python
# Apply it on the top-level module in case the top-level module supports it,
# for example, LongT5Stack inherits from `PreTrainedModel`.
if hasattr(self, "gradient_checkpointing"):
    self._gradient_checkpointing_func = gradient_checkpointing_func
```

Suggested change:

```diff
- self._gradient_checkpointing_func = gradient_checkpointing_func
+ self._checkpoint = gradient_checkpointing_func
```
No, what would appear best for users, so they know that it's basically just `torch.utils.checkpoint`?
Hmmm, what I like about `_gradient_checkpointing_func` is that it tells users it is a function; `_checkpoint` seems a bit ambiguous to me (it can sound like a model checkpoint?).
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
## Describe your changes

The latest version of transformers (>= 4.35.0) is not compatible with the model. PRs huggingface/transformers#27020 and huggingface/transformers#27073 change the expected signature of `_set_gradient_checkpointing`, which now doesn't match the model's: https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py#L802

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
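A hedged sketch of the breakage: remote modeling code that defines the old hook signature fails with a `TypeError` once the caller switches to the new keyword arguments. Both signatures below are approximations for illustration, not copied from transformers:

```python
# Old-style hook, as custom modeling code might have defined it (approximate):
class OldStyleModel:
    def _set_gradient_checkpointing(self, module, value=False):
        module.gradient_checkpointing = value


model = OldStyleModel()
caught = False
try:
    # approximate new-style call site after #27020 / #27073
    model._set_gradient_checkpointing(
        enable=True, gradient_checkpointing_func=lambda fn, *a: fn(*a)
    )
except TypeError:
    caught = True  # the old signature does not accept the new keywords
```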
…27073)

* fix
* more fixes
* fix other models
* fix long t5
* use `gradient_checkpointing_func` instead
* fix copies
* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour
* Update src/transformers/modeling_utils.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* replace it with `is_gradient_checkpointing_set`
* remove default
* Update src/transformers/modeling_utils.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* fixup

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?

Extends #27020 by further simplifying the GC enable / disable mechanism. We can simply iterate over all submodules of the `PreTrainedModel` and check for the attribute `gradient_checkpointing`.

Some models had the `supports_gradient_checkpointing` attribute set to `True` whereas they actually don't support it, so this PR fixes that as well.

Some models were also calling `torch.utils.checkpoint.checkpoint` instead of `self.gradient_checkpointing_func`; this PR fixes that too.

Also, `gradient_checkpointing_func` is now private to avoid exposing it as a public attribute.

cc @ArthurZucker
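The iterate-over-submodules mechanism can be sketched without torch as follows. In transformers the installed callable would be built from `torch.utils.checkpoint.checkpoint` (e.g. via `functools.partial`); everything here (`fake_checkpoint`, `TinyStack`, `TinyPreTrainedModel`) is a simplified hypothetical stand-in:

```python
from functools import partial


def fake_checkpoint(fn, *args, use_reentrant=True):
    # stand-in for torch.utils.checkpoint.checkpoint: just runs the function
    return fn(*args)


class TinyStack:
    # only modules declaring this attribute opt in to checkpointing
    gradient_checkpointing = False


class TinyPreTrainedModel:
    def __init__(self):
        self.stack = TinyStack()

    def modules(self):
        # stand-in for nn.Module.modules(): yields self and all submodules
        return [self, self.stack]

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        kwargs = gradient_checkpointing_kwargs or {}
        func = partial(fake_checkpoint, **kwargs)
        # iterate over all submodules and flip the flag wherever it exists;
        # the callable is stored under the private name
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = True
                module._gradient_checkpointing_func = func


model = TinyPreTrainedModel()
model.gradient_checkpointing_enable({"use_reentrant": False})
```

After the call, only `model.stack` (which declares `gradient_checkpointing`) gets the flag and the private function; the top-level wrapper is left untouched.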