Generation: PreTrainedModel no longer inherits GenerationMixin 🚨 🚨 #33150
Closed
gante wants to merge 6 commits intohuggingface:mainfrom
Closed
Generation: PreTrainedModel no longer inherits GenerationMixin 🚨 🚨 #33150gante wants to merge 6 commits intohuggingface:mainfrom
PreTrainedModel no longer inherits GenerationMixin 🚨 🚨 #33150gante wants to merge 6 commits intohuggingface:mainfrom
Conversation
GenerationMixin inherited by PreTrainedModelPreTrainedModel no longer inherits GenerationMixin
8 tasks
PreTrainedModel no longer inherits GenerationMixinPreTrainedModel no longer inherits GenerationMixin 🚨 🚨
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Contributor
Author
|
hold up, found a way to make it fully BC 💛 |
Contributor
|
@gante Basically I was trying to recreate a gpt2 model by inheriting PreTrainedModel. from __future__ import annotations
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from src.transformer_block.layer_norm import LayerNorm
from src.transformer_block.t_block import GPTTransformerBlock
class GPTConfig(PretrainedConfig):
"""
Configuration class for GPT-2 model
"""
model_type = "gpt_fast_llm"
def __init__(
self,
vocab_size: int = 200019,
context_len: int = 256,
embedding_dim: int = 768,
n_heads: int = 12,
n_layers: int = 12,
drop_rate: float = 0.0,
qkv_bias: bool = True,
batch_size: int = 8,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.context_len = context_len
self.embedding_dim = embedding_dim
self.n_heads = n_heads
self.n_layers = n_layers
self.drop_rate = drop_rate
self.qkv_bias = qkv_bias
self.batch_size = batch_size
class GPTModel(PreTrainedModel):
"""
The base model for GPT-2 architecture
Input:
x : tensor of shape (batch_size, seq_len)
Output:
logits : tensor of shape (batch_size, seq_len, vocab_size)
"""
config_class = GPTConfig
def __init__(self, config: GPTConfig):
super().__init__(config)
self.token_embedding = nn.Embedding(
num_embeddings=config.vocab_size,
embedding_dim=config.embedding_dim,
)
self.pos_embedding = nn.Embedding(
num_embeddings=config.context_len,
embedding_dim=config.embedding_dim,
)
self.drop_embedding = nn.Dropout(config.drop_rate)
self.transformer_blocks = nn.Sequential(
*[GPTTransformerBlock(config=config) for _ in range(config.n_layers)],
)
self.final_norm = LayerNorm(config.embedding_dim)
self.out_head = nn.Linear(
config.embedding_dim,
config.vocab_size,
bias=False,
)
def forward(self, input_ids, attention_mask=None, labels=None):
batch_size, seq_len = input_ids.shape
token_emb = self.token_embedding(input_ids)
pos_emb = self.pos_embedding(torch.arange(seq_len, device=input_ids.device))
x = token_emb + pos_emb
x = self.drop_embedding(x)
x = self.transformer_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
if labels is not None:
loss = torch.nn.functional.cross_entropy(
logits.flatten(0, 1),
labels.flatten(),
)
return {"logits": logits, "loss": loss}
return {"logits": logits}I already have a out_head linear layer to map the logits to vocab. But when I call model.can_generate() it returns False and I run into this error: Is there a way to fix it? |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Step 2 of #32685 - Removes the
GenerationMixininheritance fromPreTrainedModel. Instead, models classes with generative capabilities directly inheritGenerationMixin.Why?
Currently, we have a circular dependency between
PreTrainedModelandGenerationMixin:PreTrainedModel👈GenerationMixin:PreTrainedModelhas acan_generate()method, which depends on methods that exist inGenerationMixin. Depending on the value ofcan_generate(), it may hold aGenerationConfigobject.GenerationMixin👈PreTrainedModel:GenerationMixinneeded to inspect the type of the model instance, to throw informative exceptions at the user. This was needed because ALL our models could callgenerate, but most of them didn't support it.This PR breaks this circular dependency:
GenerationMixinbecomes a stand-alone class with no dependencies onPreTrainedModel. It is now a proper mixin: it may be used with other model base classes, if users desire to do so.PreTrainedModeldoesn't inheritGenerationMixin. This means that non-generative models will become less bloated :)What else can we improve as a result of this change?
can_generate()can be simplified: if a model is a subclass ofGenerationMixinthen it can generategenerate-- allGenerationMixinsubclasses can callgenerateprepare_inputs_for_generationinto the generation mixin 🧹 #32685 become much simpler to implement (can_generate()no longer depends onprepare_inputs_for_generation-> easier to make structural changes there) 🤗GenerationConfiginstance toGenerationMixin, so that non-generative models don't hold ageneration_configattribute.🚨🚨 Caveats 🚨🚨
The changes in this PR have no visible consequences in the following cases:
✅ A user loads a
transformersmodel, likeLlamaForCausalLM✅ A user loads custom modeling code from the hub with our auto classes, like this example
However, there are breaking changes in the following situations:
❌ A user has custom code, inheriting
PreTrainedModel, and wants to callgenerate