
Fixing Pytorch RMS norm implementation#133085

Closed
kkontny wants to merge 2 commits intopytorch:mainfrom
kkontny:ampere/fix-rms-norm

Conversation

@kkontny

@kkontny kkontny commented Aug 9, 2024

Since FP16 has quite a small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and this happens in real-world computations.

I tried to use the fused `nn.RMSNorm` implementation in place of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started giving wrong answers in FP16 while still giving good ones in FP32. I figured out that this happens due to an overflow while computing the square of the input tensor.

The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability.

```python
import torch
from torch import nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Upcast to fp32 so squaring cannot overflow the fp16 range
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.
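To illustrate the failure mode described above (a minimal sketch, not the fused kernel itself): FP16's largest finite value is 65504, so squaring any element larger than about 256 already overflows to `inf`, which then poisons the mean:

```python
import torch

# FP16 overflows above 65504, so squaring values > ~256 yields inf.
x = torch.full((4,), 300.0, dtype=torch.float16)

naive = x.pow(2).mean(-1, keepdim=True)           # computed in fp16
upcast = x.float().pow(2).mean(-1, keepdim=True)  # upcast to fp32 first

print(torch.isinf(naive).item())  # True: 300**2 = 90000 overflows fp16
print(upcast.item())              # 90000.0: exact in fp32
```

The inputs here are well within the normal fp16 range; it is only the intermediate square that overflows, which is why the upcast-before-squaring fix works.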

@pytorch-bot

pytorch-bot bot commented Aug 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133085

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6458021 with merge base b7bcfda:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Aug 9, 2024

CLA Not Signed

@kkontny kkontny requested a review from mruberry as a code owner August 9, 2024 12:27
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 9, 2024
Contributor

@mikaylagawarecki mikaylagawarecki left a comment


Could you sign the CLA in #133085 (comment), please?

@kkontny
Author

kkontny commented Aug 9, 2024

Could you sign the CLA in #133085 (comment), please?

I have to process the CLA through the company's legal team. Unfortunately, it will probably not be done until early next week.

@mayank31398
Contributor

@kkontny would you mind if I open this PR again instead?
It's blocking me as well, and I want to get it in as soon as possible.

@kkontny
Author

kkontny commented Aug 21, 2024

@kkontny would you mind if I open this PR again instead? It's blocking me as well, and I want to get it in as soon as possible.

@mayank31398 Please go ahead. I tried to hurry them along, but with no effect. It is taking much more time than it should...

@mayank31398 mayank31398 mentioned this pull request Aug 21, 2024
@mayank31398
Contributor

@kkontny I have created a PR here: #134106

pytorchmergebot pushed a commit that referenced this pull request Sep 11, 2024
This PR is a replacement for #133085 for pushing a quick fix for RMSNorm.
The original author is @kkontny

Previous PR summary:
Since FP16 has quite a small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and this happens in real-world computations.

I tried to use the fused `nn.RMSNorm` implementation in place of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started giving wrong answers in FP16 while still giving good ones in FP32. I figured out that this happens due to an overflow while computing the square of the input tensor.

The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability.

```python
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.

Pull Request resolved: #134106
Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
@mayank31398
Contributor

@kkontny I think we can close this.
It has now been fixed in #134106.

@kkontny
Author

kkontny commented Sep 18, 2024

Fixed via #134106

@kkontny kkontny closed this Sep 18, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
