
fix for fp16 #134106

Closed
mayank31398 wants to merge 13 commits into pytorch:main from mayank31398:fix-rmsnorm

Conversation

@mayank31398
Contributor

@mayank31398 mayank31398 commented Aug 21, 2024

This PR is a replacement for #133085 for pushing a quick fix for RMSNorm.
The original author is @kkontny

Previous PR summary:
Since FP16 has quite a small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and this happens in real-world computation.

I've tried to use the fused `nn.RMSNorm` implementation instead of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in FP16 while still giving good ones in FP32. I figured out this happens due to overflow while computing the square of the input tensor.

The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability.

```
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.
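To see why squaring overflows, here is a minimal sketch using NumPy's float16 (which shares FP16's ~65504 maximum); this only illustrates the numerical point, not the PR's actual C++ change:

```python
import numpy as np

# float16 overflows just above 65504, so squaring values above ~256 yields inf.
x = np.full(8, 300.0, dtype=np.float16)

naive = np.mean(x ** 2)                      # squares computed in fp16 -> inf
upcast = np.mean(x.astype(np.float32) ** 2)  # upcast first, like LlamaRMSNorm

print(naive, upcast)  # prints: inf 90000.0
```

The upcast path produces the exact mean of squares, while the naive fp16 path is already infinite before the mean is taken.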

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

@mayank31398 mayank31398 requested a review from mruberry as a code owner August 21, 2024 13:41
@pytorch-bot

pytorch-bot bot commented Aug 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134106

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (9 Unrelated Failures)

As of commit 04c9c16 with merge base d7b57c4:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@mikaylagawarecki mikaylagawarecki left a comment


Thanks

@mikaylagawarecki
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 21, 2024
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 21, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-13), trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-14)

Details for Dev Infra team Raised by workflow job

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@mayank31398
Contributor Author

@mikaylagawarecki can you merge this?
I don't think the failing tests are coming from my PR.

@mayank31398
Contributor Author

@mikaylagawarecki pinging again for a quick resolution on the merge.
Unsure about the failing tests; pretty sure they are unrelated to this PR.

@mikaylagawarecki
Contributor

@pytorchbot merge -r

@mikaylagawarecki
Contributor

mikaylagawarecki commented Aug 23, 2024

The failing tests look related to me (`test_modules.py::TestModuleMPS::test_forward_nn_RMSNorm_mps_float16`); rebasing to check.

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix-rmsnorm onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix-rmsnorm && git pull --rebase)

@mayank31398
Contributor Author

@eqy @mikaylagawarecki pinging again

@mayank31398
Contributor Author

guys can we get this merged?

@eqy
Collaborator

eqy commented Sep 2, 2024

nit: could we use OpMath rather than hardcoding the half/float case here? e.g.,

`struct OpMathType {`

@eqy I don't understand how to use this. Can you give an example?

see e.g.,

`using opmath_t = at::opmath_type<scalar_t>;`

@mayank31398
Contributor Author

mayank31398 commented Sep 2, 2024

@eqy @kkontny my knowledge of C++ isn't great.

```
using opmath_t = opmath_type<scalar_t>;
upcasted_input = input.to(opmath_t())
```

This doesn't compile, can you push a fix?

@mayank31398
Contributor Author

@eqy figured it out.
I think it's fixed now.

```
weight = m.weight
dims = [ndim - i - 1 for i in range(len(normalized_shape))]
result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
upcasted_i = i.float()
```
Contributor

@malfet malfet Sep 2, 2024


This would fail if `i` is complex and would reduce the precision if `i` is double.

Suggested change:

```
- upcasted_i = i.float()
+ upcasted_i = i.to(dtype=torch.float) if i.dtype == torch.half else i
```

Contributor Author


I think half doesn't include bf16, right?

Contributor Author


Also, this is just for test cases, which should pass.
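Putting the review thread together, the conditional upcast in the Python test reference could look like the sketch below. The function name `reference_rms_norm` is hypothetical, and extending the dtype check to bfloat16 (raised in the comment above) is an assumption, not what the PR necessarily merged:

```python
import torch

def reference_rms_norm(i, weight, normalized_shape, eps=1e-6):
    # Upcast only low-precision floats; leave double/complex untouched,
    # per the review suggestion (including bfloat16 here is an assumption).
    if i.dtype in (torch.half, torch.bfloat16):
        upcasted_i = i.to(dtype=torch.float)
    else:
        upcasted_i = i
    # Reduce over the trailing dims covered by normalized_shape.
    dims = [i.ndim - k - 1 for k in range(len(normalized_shape))]
    result = upcasted_i * torch.rsqrt(
        upcasted_i.pow(2).mean(dim=dims, keepdim=True) + eps
    )
    # Downcast back to the input dtype only at the end.
    return (weight * result).to(i.dtype)

# fp16 input large enough that squaring would overflow without the upcast.
x = torch.full((2, 4), 300.0, dtype=torch.float16)
w = torch.ones(4, dtype=torch.float16)
out = reference_rms_norm(x, w, normalized_shape=(4,))
print(out)  # finite values close to 1.0, still float16
```

With all inputs equal, RMS normalization returns ones, so any inf/NaN here would indicate the fp16 overflow the PR guards against.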

@mayank31398
Contributor Author

@eqy pinging again

@mayank31398
Contributor Author

failing tests seem unrelated @eqy @mikaylagawarecki

@mayank31398
Contributor Author

@eqy any updates on this?

@mayank31398
Contributor Author

@eqy @mikaylagawarecki
pinging again

@eqy
Collaborator

eqy commented Sep 11, 2024

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@mayank31398 mayank31398 deleted the fix-rmsnorm branch September 12, 2024 03:35
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024

Pull Request resolved: pytorch#134106
Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
pytorchmergebot pushed a commit that referenced this pull request Mar 8, 2025
Fixes #134106. This PR moves the `upcasted_result` down-casting after all computation is done.

Since the multiplication with the `weight_opt` input is not done in half precision, the current code path does the following: fp16 -> fp32 -> fp16 -> fp32 -> fp16. What we want, though, is to avoid the intermediate down-casting, and this PR proposes: fp16 -> fp32 -> fp16. This results in better accuracy as it avoids truncation.
Pull Request resolved: #147203
Approved by: https://github.com/eqy
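The effect of the two casting chains can be sketched numerically. This is a NumPy illustration of the extra rounding error, not the kernel code; all names here are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1024).astype(np.float32)
w = rng.standard_normal(1024).astype(np.float32)

# Normalized value, computed in fp32 as both paths do.
norm = x / np.sqrt(np.mean(x * x) + 1e-6)

# Old path sketch: fp16 -> fp32 -> fp16 -> fp32 -> fp16; the extra fp16
# round-trip before multiplying by the weight truncates the mantissa twice.
old = (norm.astype(np.float16).astype(np.float32) * w).astype(np.float16)

# New path sketch: fp16 -> fp32 -> fp16; a single downcast at the very end.
new = (norm * w).astype(np.float16)

# Compare both against a float64 reference.
ref = norm.astype(np.float64) * w.astype(np.float64)
err_old = np.abs(old.astype(np.float64) - ref).mean()
err_new = np.abs(new.astype(np.float64) - ref).mean()
print(err_old, err_new)
```

On random data the mean error of the double-round-trip path exceeds that of the single-downcast path, which is the accuracy gain the commit message describes.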

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: cpu (CPU specific problem, e.g. perf, algorithm)
module: dynamo
module: inductor
module: mkldnn (Related to Intel IDEEP or oneDNN, a.k.a. mkldnn, integration)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
open source
release notes: quantization (release notes category)
topic: not user facing (topic category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


8 participants