Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
|
Hi, this makes sense.
|
Should I run the slow tests, or can this be merged as-is?
|
Might be better to run the slow tests on the Granite class @NielsRogge
|
Seems like the slow tests are failing (cc @ydshieh), but I assume it's safe to merge this PR since the following passes:

```python
from torch import nn
import torch


class GraniteRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        GraniteRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


a = torch.nn.RMSNorm(10)
b = GraniteRMSNorm(10)
assert a.weight.shape == b.weight.shape

c = torch.randn(1, 10)
assert torch.allclose(a(c), b(c))
```
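For a stricter comparison, here is a minimal sketch (not part of the PR) that pins both modules to the same epsilon, so the result does not depend on `torch.nn.RMSNorm`'s default epsilon handling:

```python
# Sketch only: compare the two implementations with an explicit, matching eps.
# Assumes GraniteRMSNorm from the snippet above is in scope.
import torch

a = torch.nn.RMSNorm(10, eps=1e-6)   # torch.nn.RMSNorm is available from PyTorch 2.4 onwards
b = GraniteRMSNorm(10, eps=1e-6)

x = torch.randn(4, 10)
torch.testing.assert_close(a(x), b(x))
```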
|
When will this be released to fix Transformers?
|
@ArthurZucker can you merge this? |
|
Okay |
|
cc @ydshieh if you see failures on this, it's expected!
|
Thanks @ArthurZucker, I'll fix this test in a new PR.
* Add GraniteRMSNorm
* [run_slow] granite
What does this PR do?
This PR is a follow-up of #31502, which broke Transformers for PyTorch < 2.4.
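For context, `torch.nn.RMSNorm` only exists from PyTorch 2.4 onwards, so referencing it directly fails on older installs. Below is an illustrative sketch of a version guard; the PR itself simply restores the local `GraniteRMSNorm` implementation, and the `RMSNormImpl` name here is hypothetical:

```python
# Illustrative sketch only: pick an RMSNorm implementation based on what
# the installed PyTorch provides. Not the approach taken in this PR.
import torch

if hasattr(torch.nn, "RMSNorm"):       # PyTorch >= 2.4
    RMSNormImpl = torch.nn.RMSNorm
else:                                  # PyTorch < 2.4
    RMSNormImpl = GraniteRMSNorm       # local implementation shown earlier in the thread
```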