[core] Refactor of `gradient_checkpointing` #27020

younesbelkada merged 16 commits into huggingface:main

Conversation
```python
config_class = SwinConfig
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
```
Here I removed it because it's not relevant to TF models.
ArthurZucker left a comment:
Very nice cleanup!
| """ | ||
| if self.supports_gradient_checkpointing: | ||
| self.apply(partial(self._set_gradient_checkpointing, value=False)) | ||
| self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None)) |
When we disable gradient checkpointing, I think `module.gradient_checkpointing` will still be `True`.
Let's make `module.gradient_checkpointing` a property, so we always check whether the function is `None` or not, WDYT?
The property could go at the `ModelMixin` level?
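A minimal sketch of that suggestion (class and attribute names here are illustrative, not taken from the PR): derive the `gradient_checkpointing` flag from whether the checkpointing function is set, so disabling can never leave a stale `True` behind:

```python
class CheckpointingMixin:
    """Hypothetical sketch of the property-based suggestion (not the PR's code)."""

    def __init__(self):
        self._gradient_checkpointing_func = None

    @property
    def gradient_checkpointing(self):
        # True only while a checkpointing function is actually set, so
        # disabling (setting the func back to None) cannot leave a stale flag.
        return self._gradient_checkpointing_func is not None

    def _set_gradient_checkpointing(self, gradient_checkpointing_func=None):
        self._gradient_checkpointing_func = gradient_checkpointing_func


m = CheckpointingMixin()
m._set_gradient_checkpointing(lambda fn, *args: fn(*args))
assert m.gradient_checkpointing          # enabled
m._set_gradient_checkpointing(None)
assert not m.gradient_checkpointing      # disabled, flag follows the func
```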
Can you add a test to make sure setting and unsetting both work as expected (specifically for the fix we are implementing in TRL)?

+1
```python
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
    self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
```
BLIP-2 never propagated `gradient_checkpointing` to its `language_model`.
```python
for backbone_module in module.modules():
    if hasattr(backbone_module, "gradient_checkpointing"):
        backbone_module.gradient_checkpointing_func = gradient_checkpointing_func
        backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None
```
Another edge case here: the backbone has some modules that support GC, but that attribute was never propagated to them.
younesbelkada left a comment:
It turns out ~30 architectures were not properly using `gradient_checkpointing`; I left 3 comments flagging the ones to be aware of.
ArthurZucker left a comment:
I think we should use `__call__` rather than `forward` so that the hooks run!
ArthurZucker left a comment:
Thanks a lot, very nice cleanup! 🔥
```diff
- layer_outputs = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(layer_module),
+ layer_outputs = self.gradient_checkpointing_func(
+     layer_module.__call__,
```
Let's document this in the gradient checkpointing doc (IMO important to know why `forward` and `__call__` are different).
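The difference matters because `nn.Module.__call__` runs registered hooks around `forward`, while calling `forward` directly bypasses them. A pure-Python toy illustration of that dispatch (a sketch, not PyTorch's actual implementation):

```python
class MiniModule:
    """Toy stand-in for nn.Module's hook dispatch (illustration only)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        # __call__ wraps forward and runs the registered hooks; calling
        # forward directly skips them, which is why the refactor passes
        # layer_module.__call__ to the checkpointing function.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


seen = []
m = MiniModule()
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m.forward(3)   # bypasses hooks: seen is still []
m(3)           # runs hooks: seen becomes [6]
```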
Ran some training tests with PEFT + GC using this branch and everything seems to pass! Merging once the CI is green.
* v1
* fix
* remove `create_custom_forward`
* fixup
* fixup
* add test and fix all failing GC tests
* remove all remaining `create_custom_forward` methods
* fix idefics bug
* fixup
* replace with `__call__`
* add comment
* quality
## Describe your changes

The latest version of transformers (>= 4.35.0) is not compatible with the model. PRs huggingface/transformers#27020 and huggingface/transformers#27073 change the expected signature of `_set_gradient_checkpointing`, which no longer matches the model's: https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py#L802

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
What is the difference between `enable_gradient_checkpointing` and `gradient_checkpointing_enable`?
@lucasjinreal I can only see
🤯 `transformers/src/transformers/modeling_utils.py`, line 2195 at `af4c026`

What does this PR do?
Alternative to #26917
This way we make `set_gradient_checkpointing` more modular, as requested by some users - e.g. #21381 (comment)

Fixes some issues with DDP such as: huggingface/trl#835

Also removed GC support from `TFSwin`, as in theory `gradient_checkpointing` is used only for PT models. Also added a CI test for that.
For users that want to use `gradient_checkpointing` with `use_reentrant=False`:
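The snippet that originally followed is not preserved here. Under the function-passing pattern this PR introduces, wiring in non-reentrant checkpointing could look like the following sketch (the `TinyBlock` module is illustrative, not the PR's exact code):

```python
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    """Minimal module following the checkpointing pattern from this PR (sketch)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # Pass __call__ (not .forward) so registered hooks still run.
            return self.gradient_checkpointing_func(self.linear.__call__, x)
        return self.linear(x)


block = TinyBlock()
block.train()
# Non-reentrant checkpointing: wrap torch's checkpoint with use_reentrant=False
# and store it as the module's checkpointing function.
block.gradient_checkpointing_func = partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)
block.gradient_checkpointing = True

x = torch.randn(2, 4, requires_grad=True)
block(x).sum().backward()
```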