Fixing PyTorch RMS norm implementation #133085
Conversation
✅ No failures as of commit 6458021 with merge base b7bcfda.
I have to process the CLA through the company's legal department. Unfortunately, it will probably be done early next week.
@kkontny would you mind if I open this PR again instead? |
@mayank31398 Please go ahead; I tried to hurry them, but with no effect. It is taking much more time than it should...
This PR is a replacement for #133085, pushing a quick fix for RMSNorm. The original author is @kkontny.

Previous PR summary: Since FP16 has a quite small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and this happens in real-world computation. I tried to use the fused `nn.RMSNorm` implementation instead of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in FP16 while still giving good ones in FP32. I figured out this happens due to overflow while computing the square of the input tensor. The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability:

```
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.

Pull Request resolved: #134106 Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
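The overflow is easy to reproduce: the largest finite FP16 value is 65504, so squaring any input with magnitude above roughly 256 already exceeds it. A minimal sketch (with a hypothetical input value and epsilon, using the public `torch` API) of both the failure mode and the upcast workaround described above:

```python
import torch

# FP16's max finite value is 65504, so 300**2 = 90000 overflows to inf.
x = torch.tensor([300.0], dtype=torch.float16)
sq_fp16 = x.pow(2)          # inf in fp16

# Upcasting to fp32 before squaring, as LlamaRMSNorm does, avoids the overflow.
x32 = x.to(torch.float32)
variance = x32.pow(2).mean(-1, keepdim=True)
y = (x32 * torch.rsqrt(variance + 1e-6)).to(torch.float16)  # finite, back in input dtype
print(sq_fp16, y)
```

For this single-element example the normalized output is simply `x / sqrt(x**2)`, i.e. 1.0, while the naive fp16 square is already `inf`.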
Fixed via #134106