Relaxing PreTrainedModel requirement in _save #6901

@prajjwal1

🚀 Feature request

It's great to see that Trainer is becoming more flexible. Each function now seems more self-contained, which makes inheritance easier. I've experimented with many custom models, for instance:

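from torch import nn
from transformers import AutoModel

class Model(nn.Module):
    def __init__(self, ..):
        super().__init__()  # must run before submodules are assigned
        self.encoder = AutoModel.from_pretrained(..)
        self.custom_modules = ..

    def forward(self, **kwargs):
        output = self.encoder(**kwargs)
        # some custom operations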

Many users need to create custom models when they want something other than a simple SequenceClassification head. In all such cases, I have to override the _save method because of this line, which explicitly restricts Trainer to models that inherit from PreTrainedModel. It would be good to relax this requirement and instead emit a warning when the model is not a PreTrainedModel.
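
To make the request concrete, here is a rough sketch of the behaviour I have in mind. It is not the actual Trainer code: PretrainedLike, CustomModel and relaxed_save are hypothetical stand-ins so the example runs without torch or transformers, and JSON files stand in for real checkpoints. With PyTorch, the fallback branch would presumably call torch.save(model.state_dict(), path) instead.

```python
import json
import os
import tempfile
import warnings


class PretrainedLike:
    """Hypothetical stand-in for a model that can save itself."""

    def __init__(self, weights):
        self.weights = weights

    def state_dict(self):
        return dict(self.weights)

    def save_pretrained(self, output_dir):
        with open(os.path.join(output_dir, "pretrained.json"), "w") as f:
            json.dump(self.state_dict(), f)


class CustomModel:
    """Hypothetical stand-in for a plain wrapper module: state_dict only."""

    def __init__(self, weights):
        self.weights = weights

    def state_dict(self):
        return dict(self.weights)


def relaxed_save(model, output_dir):
    """Save `model`; warn instead of raising when it is not pretrained-like."""
    os.makedirs(output_dir, exist_ok=True)
    if isinstance(model, PretrainedLike):
        model.save_pretrained(output_dir)
        return "save_pretrained"
    warnings.warn(
        "Model is not a PreTrainedModel; saving only its state dict.",
        UserWarning,
    )
    # Assumption: with PyTorch this would be torch.save(model.state_dict(), path).
    with open(os.path.join(output_dir, "state_dict.json"), "w") as f:
        json.dump(model.state_dict(), f)
    return "state_dict"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(relaxed_save(CustomModel({"w": 1.0}), tmp))
```

The point is that the non-PreTrainedModel path still produces a checkpoint and only warns, so custom wrappers would no longer need to override _save.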

Your contribution

I'll open a PR if I get approval.
