Introduce GradientCheckpointingLayer #37223

Merged
qubvel merged 16 commits into huggingface:main from qubvel:gradient-checkpointing-layer on Apr 22, 2025

@qubvel

@qubvel qubvel commented Apr 2, 2025

Contributor

What does this PR do?

A super minimal abstraction for a layer with gradient checkpointing that keeps the logic for enabling and disabling gradient checkpointing within PreTrainedModel for backward compatibility. It allows for a gradual rollout of the feature by supporting both checkpointing mechanisms: the current approach of wrapping calls with _gradient_checkpointing_func, and inheritance from GradientCheckpointingLayer.

I've applied this to Llama, but it's just a PoC for discussion. It may be better to start with a less popular model that has fewer dependent models, to see how it goes and to check whether it breaks custom code on the Hub.
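The inheritance-based mechanism described above can be illustrated with a minimal sketch. This is not the code merged in this PR: the class name SketchCheckpointingLayer and the ToyBlock layer are hypothetical, and the sketch assumes the parent model sets the gradient_checkpointing flag and the _gradient_checkpointing_func attribute on each layer when checkpointing is enabled.

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class SketchCheckpointingLayer(nn.Module):
    """Illustrative stand-in, not the class merged in this PR."""

    # Class-level default; the owning model is assumed to flip this per instance.
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Bind keyword arguments up front so that only positional
            # arguments are passed through the checkpointing function.
            bound_call = partial(super().__call__, **kwargs)
            return self._gradient_checkpointing_func(bound_call, *args)
        return super().__call__(*args, **kwargs)


class ToyBlock(SketchCheckpointingLayer):
    """Hypothetical layer that opts in purely by inheritance."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, scale=1.0):
        return torch.relu(self.proj(hidden_states)) * scale


torch.manual_seed(0)
block = ToyBlock().train()
x = torch.randn(2, 8, requires_grad=True)

plain = block(x, scale=2.0)  # regular forward pass

# What the parent model is assumed to do when checkpointing is enabled.
block._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)
block.gradient_checkpointing = True
recomputed = block(x, scale=2.0)  # forward pass routed through checkpoint
```

The point of the design is that forward() stays untouched: the layer's call site in the model does not need an explicit branch on the checkpointing flag, and in eval mode the call falls through to the normal path.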

Who can review?

@qubvel qubvel marked this pull request as ready for review April 2, 2025 22:15
@qubvel

qubvel commented Apr 2, 2025

Contributor Author

run-slow: llama

@github-actions

github-actions Bot commented Apr 2, 2025

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/llama']
quantizations: [] ...

@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker left a comment

Collaborator

Nice! Let's put it in another file to separate things a bit! Otherwise marvelous.

@qubvel

qubvel commented Apr 8, 2025

Contributor Author

Ended up applying this to all Llama-based models and SigLIP/2 for the first iteration; all relevant tests pass:

RUN_SLOW=1 pytest -k "gradient" tests/models/

The CI error is unrelated.

cc @ArthurZucker to merge if OK for you

@qubvel qubvel requested a review from ArthurZucker April 8, 2025 15:42

@ArthurZucker ArthurZucker left a comment

Collaborator

🧼 thanks a lot!



class GradientCheckpointingLayer(nn.Module):
    """Base class for layers with gradient checkpointing.

Collaborator

thanks for documenting as well!

@qubvel qubvel merged commit 9167fad into huggingface:main Apr 22, 2025
@sfc-gh-sbekman

sfc-gh-sbekman commented May 6, 2025

Contributor

This is very neat, Pavel!

Might be useful for us users to add to the doc:

  1. that "use_reentrant": True is the default, perhaps pointing to where it gets set if the user doesn't override it, and
  2. a pointer to gradient_checkpointing_enable, whose documentation explains how to override the use_reentrant value.

Perhaps even add redundancy/shortcut by adding to the doc:

            gradient_checkpointing_kwargs = {"use_reentrant": True}
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

It's not very intuitive; I first tried to call:

model.gradient_checkpointing_enable(use_reentrant=True)

and it failed with TypeError: PreTrainedModel.gradient_checkpointing_enable() got an unexpected keyword argument 'use_reentrant'
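The two call shapes from this comment can be reproduced with a small stand-in. The function below is hypothetical and is not PreTrainedModel.gradient_checkpointing_enable; it only assumes that the options dict ends up bound onto torch.utils.checkpoint.checkpoint, and that use_reentrant defaults to True as stated above.

```python
from functools import partial

from torch.utils.checkpoint import checkpoint


def enable_checkpointing_sketch(gradient_checkpointing_kwargs=None):
    """Hypothetical stand-in that mirrors only the signature discussed above."""
    if gradient_checkpointing_kwargs is None:
        # Assumed default, taken from the comment above.
        gradient_checkpointing_kwargs = {"use_reentrant": True}
    # Bind the options onto torch's checkpoint function.
    return partial(checkpoint, **gradient_checkpointing_kwargs)


# Works: the option travels inside the dict.
ckpt_fn = enable_checkpointing_sketch(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

# Fails: use_reentrant is not a parameter of the function itself.
try:
    enable_checkpointing_sketch(use_reentrant=False)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Because the option is nested one level down, a docstring example showing the dict form would likely save users from the TypeError above.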

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* GradientCheckpointingLayer

* trigger

* Move GC layer to a separate file

* Update import

* Expose and document GC layer

* Fix dummy

* Apply to llama-based models

* Update modulars

* Update a few more models for consistency

* Update glm4

* Update Janus
4 participants