
Add ModernBERT to Transformers #35158

Merged
tomaarsen merged 91 commits into huggingface:main from AnswerDotAI:modernbert on Dec 19, 2024

Conversation

@warner-benjamin
Contributor

This PR will add ModernBERT to Transformers.

@ArthurZucker
Collaborator

cc @Cyrilvallez ! 🤗

@tomaarsen tomaarsen self-requested a review December 10, 2024 15:09
Member

@tomaarsen tomaarsen left a comment


Beyond the obvious (sdpa, eager, flex attention, and documentation), I haven't seen anything outrageous or very unexpected in my first scroll-through.
I recognize that this implementation goes a bit beyond our "usual" with unpadding/padding when possible, but I personally don't mind. Beyond this change (and the other obvious upgrades like RoPE), I quite like how this still mirrors the original BERT rather closely.

I'll have to actually start running this to get a better feel, but so far so good.

Also, the SequenceClassification and TokenClassification classes don't exist yet.
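The unpadding/padding approach mentioned above can be sketched in plain Python (a hypothetical illustration on nested lists, not the PR's actual tensor-based implementation): real tokens from the whole batch are concatenated into one flat sequence, and their original positions are remembered so outputs can be re-padded afterwards.

```python
# Hypothetical sketch of the unpad/repad bookkeeping; the real code works
# on tensors, but the idea is the same.
def unpad(batch, attention_mask):
    """Concatenate only the real (unmasked) tokens, remembering positions."""
    tokens, indices = [], []
    for i, (row, mask) in enumerate(zip(batch, attention_mask)):
        for j, (tok, m) in enumerate(zip(row, mask)):
            if m:
                tokens.append(tok)
                indices.append((i, j))
    return tokens, indices

def repad(tokens, indices, shape, pad_value=0):
    """Scatter the flat tokens back into a padded (rows x cols) layout."""
    rows, cols = shape
    out = [[pad_value] * cols for _ in range(rows)]
    for tok, (i, j) in zip(tokens, indices):
        out[i][j] = tok
    return out
```

Attention then runs once over the flat token sequence, so no compute is spent on padding positions.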

Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/auto/modeling_auto.py Outdated
Comment thread src/transformers/models/auto/modeling_auto.py Outdated
@tomaarsen
Member

tomaarsen commented Dec 10, 2024

@ArthurZucker @Cyrilvallez
ModernBERT does not use token_type_ids, but its tokenizers rely on PreTrainedTokenizerFast, which produces token_type_ids by default. Is it preferable that we:

  1. Set model_input_names in the config.json of all ModernBERT models. This means that "fresh tokenizers" won't work out of the box, but people don't normally make fresh tokenizers.
  2. Create a custom ModernBertTokenizerFast that is literally just:
class ModernBertTokenizerFast(PreTrainedTokenizerFast):
    model_input_names = ["input_ids", "attention_mask"]
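Either way, the effect is the same: the encoding the tokenizer returns is restricted to the names in model_input_names. A hypothetical sketch of that filtering (not the actual PreTrainedTokenizerFast internals):

```python
# Hypothetical illustration of what model_input_names controls: only the
# listed keys survive in the tokenizer's output.
def filter_encoding(encoding: dict, model_input_names: list) -> dict:
    return {k: v for k, v in encoding.items() if k in model_input_names}

encoding = {
    "input_ids": [101, 2023, 102],
    "attention_mask": [1, 1, 1],
    "token_type_ids": [0, 0, 0],  # dropped for ModernBERT
}
filtered = filter_encoding(encoding, ["input_ids", "attention_mask"])
# token_type_ids is gone; input_ids and attention_mask remain
```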

cc @warner-benjamin @orionw


Comment thread src/transformers/models/auto/tokenization_auto.py Outdated
Comment thread src/transformers/models/modernbert/modular_modernbert.py
Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
@tomaarsen
Member

tomaarsen commented Dec 11, 2024

With all of these changes in place, I was able to confirm that the output of one of the trained models under the original research implementation nearly matches the output of the transformers-converted ModernBERT model. The only remaining difference is that the research implementation fuses self.mlp(self.mlp_norm(...)) using @torch.compile(dynamic=True). I have also only tested with a small input - I still have to test 1) larger inputs, 2) truncation, 3) batches, etc.

Do we allow something like this to get an exact 1-1 match?

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.mlp_norm(hidden_states))

Here's an indication of the difference between with and without:

tensor([[[ 0.0078,  0.0078,  0.0039,  ...,  0.0117,  0.0342,  0.0039],
         [ 0.0156,  0.0000,  0.0234,  ...,  0.0000,  0.0156, -0.0273],
         [-0.0015, -0.0010, -0.0020,  ...,  0.0000,  0.0022, -0.0015],
         ...,
         [-0.0195,  0.0000, -0.0029,  ..., -0.0156,  0.0146,  0.0039],
         [ 0.0078, -0.0273, -0.0059,  ..., -0.0215,  0.0103,  0.0088],
         [-0.0020,  0.0005, -0.0005,  ..., -0.0012, -0.0012, -0.0020]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SubBackward0>)
Min: -0.375
Max: 0.140625
Mean: -8.249282836914062e-05
Std: 0.01708984375
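Summary statistics like the ones above can be produced with a small helper; this is a hypothetical sketch on plain floats (the actual comparison was done on bfloat16 tensors):

```python
import statistics

# Hypothetical helper: summarize the elementwise difference between two
# model output vectors (flattened to plain floats).
def diff_stats(a, b):
    d = [x - y for x, y in zip(a, b)]
    return {
        "min": min(d),
        "max": max(d),
        "mean": statistics.fmean(d),
        "std": statistics.pstdev(d),
    }
```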

Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/configuration_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/configuration_modernbert.py Outdated
@warner-benjamin
Contributor Author

The model with FA2 and the RoPE kernel is not torch.compile compatible, so we can't compile the whole model while using them.

FA2 is compatible now, but the FA RoPE kernel isn't yet. I have an in-progress fix that I need to get merged into the FA repo.

@warner-benjamin
Contributor Author

We want mean pooling as an option for classification because, with Local Attention, the CLS token (unlike in BERT) doesn't see all the other tokens in every attention layer, so mean pooling could outperform CLS pooling on sequences longer than 128 tokens.

Also, I added the pooling head to TokenClassification because otherwise we would be throwing away a pretrained linear layer, ModernBertPoolingHead.dense.
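Masked mean pooling as described above can be sketched like this (a plain-Python illustration of the idea, not the PR's ModernBertPoolingHead code): average only the unmasked token embeddings, so padding doesn't dilute the result.

```python
# Hypothetical sketch: hidden_states is a [seq_len][dim] list of token
# vectors, attention_mask is a 0/1 list marking real tokens.
def mean_pool(hidden_states, attention_mask):
    dim = len(hidden_states[0])
    total = [0.0] * dim
    count = sum(attention_mask)
    for vec, m in zip(hidden_states, attention_mask):
        if m:
            total = [t + v for t, v in zip(total, vec)]
    return [t / count for t in total]
```

Unlike CLS pooling, every real token contributes to the sentence representation, which matters when local attention keeps the CLS token from attending to distant positions.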

Collaborator

@ArthurZucker ArthurZucker left a comment


Super cool! 2 things left:

  1. Remove the gradient thing: padding and unpadding is not model-weight dependent, and should never have gradients.
  2. Remove the 2 functions, and just call the head ClsPooling or MeanPooling, depending on which one was released / is most common, cf. our offline discussion @tomaarsen

Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
Comment thread src/transformers/models/modernbert/modular_modernbert.py Outdated
@ArthurZucker
Collaborator

Great addition! Thanks all for your hard work! 🤗
