
Introducing AutoPeftModelForxxx #694

Merged
younesbelkada merged 9 commits into huggingface:main from younesbelkada:add-auto-peft-model
Jul 14, 2023
Conversation


@younesbelkada younesbelkada commented Jul 13, 2023

This PR introduces a new paradigm, AutoPeftModelForxxx, intended for users who want to rapidly load and run peft models.
Currently a user needs to run all these steps:

from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM

# Resolve the base model from the adapter config
peft_config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
base_model_path = peft_config.base_model_name_or_path

# Load the base model, then attach the adapter weights
transformers_model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map="auto", load_in_8bit=True)
peft_model = PeftModel.from_pretrained(transformers_model, "ybelkada/opt-350m-lora")

to load a peft model from the Hub or locally, whereas now they could just do:

from peft import AutoPeftModelForCausalLM

peft_model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
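
For intuition, a rough, self-contained sketch of what such an Auto class does under the hood, with stub classes standing in for the real PeftConfig / AutoModelForCausalLM / PeftModel (all names and return values below are illustrative, not the actual implementation in src/peft/auto.py):

```python
class StubPeftConfig:
    """Stand-in for peft.PeftConfig: records the base model an adapter was trained on."""

    def __init__(self, base_model_name_or_path):
        self.base_model_name_or_path = base_model_name_or_path

    @classmethod
    def from_pretrained(cls, model_id):
        # pretend we downloaded and parsed adapter_config.json from the Hub
        return cls(base_model_name_or_path="facebook/opt-350m")


class StubAutoPeftModelForCausalLM:
    """Stand-in for the new AutoPeftModelForCausalLM entry point."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # 1. read the adapter config to find the base model
        peft_config = StubPeftConfig.from_pretrained(pretrained_model_name_or_path)
        # 2. load the base model (kwargs like device_map would be forwarded here)
        base_model = f"loaded:{peft_config.base_model_name_or_path}"
        # 3. wrap the base model with the adapter weights
        return {"base": base_model, "adapter": pretrained_model_name_or_path}


peft_model = StubAutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
```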

cc @pacman100 @BenjaminBossan

TODOs:

  • add tests
  • add docs


HuggingFaceDocBuilderDev commented Jul 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada marked this pull request as ready for review July 13, 2023 14:13

@BenjaminBossan BenjaminBossan left a comment


This looks quite good to me, I have only minor comments.

Admittedly, I'm not an expert in the concept of the Auto* models, so I can't comment on the overall design of this feature. Let's see what Sourab has to add to this.

)


class PeftAutoModelTester(unittest.TestCase):
Member


The individual tests all look very similar. I wonder if they could be parametrized.

Contributor Author


Yeah, they are quite similar; however, to check that the models are effectively converted to bfloat16, for instance, I check some custom module attributes in each case. Maybe let's keep it as it is.

Member


I see. Those could be split into separate tests or could be parametrized via operator.attrgetter but it's not super important.
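
The `operator.attrgetter` suggestion could look roughly like this hypothetical sketch, with a dummy model standing in for the real converted models (attribute paths and values are made up for illustration):

```python
import operator
import unittest


class DummyModel:
    """Stand-in for a converted model whose nested attributes we inspect."""

    class base_model:
        class lm_head:
            dtype = "bfloat16"


class PeftAutoModelTester(unittest.TestCase):
    # (attribute path, expected value) pairs replace near-duplicate test bodies
    dtype_cases = [
        ("base_model.lm_head.dtype", "bfloat16"),
    ]

    def test_dtypes(self):
        model = DummyModel()
        for path, expected in self.dtype_cases:
            # attrgetter resolves a dotted path, so each case stays one line
            with self.subTest(path=path):
                self.assertEqual(operator.attrgetter(path)(model), expected)
```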

src/peft/auto.py Outdated
)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *peft_model_args, **kwargs):
Member


Ah, I just noticed this distinction between *args and **kwargs now. IMHO it's not super intuitive that it works that way. Not sure what would be a better solution, but it should be at least documented what *peft_model_args are.

Contributor


As there are only 3 params (adapter_name, is_trainable, config) apart from the base model and peft model path that are already passed, and given that they are the same kwargs across the supported tasks, would it make sense to just mention them explicitly here too?

Contributor Author


Thanks for the valuable feedback, yes, that makes total sense.
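
Spelled out, the explicit signature could look roughly like this; note that `is_trainable` and the defaults are assumptions based on `PeftModel.from_pretrained`, not taken from the PR diff:

```python
class AutoPeftModelSketch:
    """Hypothetical sketch of an explicit from_pretrained signature."""

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        adapter_name="default",
        is_trainable=False,
        config=None,
        **kwargs,  # anything else is forwarded to the base model loader
    ):
        # the real class would load the base model here and then call
        # PeftModel.from_pretrained(base_model, pretrained_model_name_or_path,
        #     adapter_name=adapter_name, is_trainable=is_trainable, config=config)
        return {
            "adapter_name": adapter_name,
            "is_trainable": is_trainable,
            "config": config,
            "extra": kwargs,
        }
```

Being explicit this way means the peft-specific parameters show up in the signature and the docs, while remaining kwargs still flow to the base model's loader.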

src/peft/auto.py Outdated
_target_peft_class = None

def __init__(self, *args, **kwargs):
    raise TypeError(
Member


Oh, I see now that transformers uses EnvironmentError and that's why you used it. I still think it's the wrong error to raise, but maybe it should be kept for consistency, in case there is code somewhere that catches this error specifically?

Contributor Author


I see that makes sense, will revert it to EnvironmentError for consistency with transformers!

Member


Maybe add a comment that explains why.

Contributor Author


Sure, just added a comment

Contributor

@pacman100 pacman100 left a comment


Great work @younesbelkada, I really like the simplified UX ✨. I have a few queries based on the offline discussion.

  1. Methods such as LoRA, AdaLora, IA3, AdaptionPrompt, etc. are widely applicable to tasks other than those explicitly supported, since they make inline changes to the modules (e.g. ASR, image captioning using BLIP, Stable Diffusion, etc.). What are your thoughts on going about it, as users might expect it to work for those too? Or do we restrict this API to only the explicitly supported NLP tasks?
  2. Left a comment


@younesbelkada younesbelkada left a comment


Hi @pacman100
Thanks for your review! Regarding your first point, we could probably do a first version with only the officially supported NLP tasks, and then a second iteration to add a new auto mapping class that groups the model classes by modality.

@younesbelkada younesbelkada requested a review from pacman100 July 14, 2023 07:20

@pacman100 pacman100 left a comment


Thank you @younesbelkada for iterating, LGTM!

@younesbelkada younesbelkada merged commit 0675541 into huggingface:main Jul 14, 2023
@younesbelkada younesbelkada deleted the add-auto-peft-model branch July 14, 2023 09:07
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
* working v1 for LMs

* added tests.

* added documentation.

* fixed ruff issues.

* added `AutoPeftModelForFeatureExtraction`.

* replace with `TypeError`

* address last comments

* added comment.
cyyever pushed a commit to cyyever/peft that referenced this pull request Sep 4, 2025
* update to `prepare_model_for_kbit_training`

from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for huggingface#694

* workaround for gradient checkpointing issue

calling model.gradient_checkpointing_enable() twice causes issues;
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to False to make sure it isn't called again in the
huggingface trainer inner loop

also changes stack_llama_2 sft trainer to use correct device map for ddp
training so that you can test this issue
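
The workaround described in that commit amounts to two moves: let prepare_model_for_kbit_training enable gradient checkpointing, then flip the training arg off so the trainer does not enable it a second time. A self-contained sketch, with a stub standing in for the peft helper so the snippet runs on its own:

```python
from types import SimpleNamespace


def prepare_model_for_kbit_training_stub(model, use_gradient_checkpointing=True):
    # stand-in for peft.prepare_model_for_kbit_training, which (among other
    # things) calls model.gradient_checkpointing_enable() when requested
    if use_gradient_checkpointing:
        model.gc_enabled = True
    return model


args = SimpleNamespace(gradient_checkpointing=True)
model = SimpleNamespace(gc_enabled=False)

# enable gradient checkpointing once, inside the prepare helper
model = prepare_model_for_kbit_training_stub(
    model, use_gradient_checkpointing=args.gradient_checkpointing
)
# the workaround: clear the flag so the trainer's inner loop
# does not call gradient_checkpointing_enable() a second time
args.gradient_checkpointing = False
```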