Conversation
|
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failed. Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-13), trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-14). Details for Dev Infra team: Raised by workflow job |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failed. Reason: 3 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
|
@mikaylagawarecki can you merge this? |
|
@mikaylagawarecki pinging again for quick resolution for merge. |
|
@pytorchbot merge -r |
|
The failing tests look related to me |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
|
Successfully rebased |
Force-pushed from 0a3e470 to 93a13e6
|
@eqy @mikaylagawarecki pinging again |
|
guys can we get this merged? |
see e.g., |
Force-pushed from 38b7869 to fd44b4b
|
@eqy figured it out |
```python
weight = m.weight
dims = [ndim - i - 1 for i in range(len(normalized_shape))]
result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
upcasted_i = i.float()
```
This would fail if `i` is complex, and would reduce the precision if `i` is double.
Suggested change:
```diff
- upcasted_i = i.float()
+ upcasted_i = i.to(dtype=torch.float) if i.dtype == torch.half else i
```
I think `half` doesn't include bf16, right?
Also, this is just for test cases, which should pass.
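To illustrate the dtype concern raised above, here is a minimal sketch (a hypothetical helper, not PyTorch API and not the code in this PR) that upcasts only the reduced-precision float dtypes, covering bf16 as well, while leaving double and complex inputs untouched:

```python
import torch

def upcast_for_reduction(i: torch.Tensor) -> torch.Tensor:
    # Upcast only reduced-precision float dtypes to fp32;
    # double and complex inputs pass through unchanged.
    if i.dtype in (torch.half, torch.bfloat16):
        return i.float()
    return i

print(upcast_for_reduction(torch.ones(4, dtype=torch.half)).dtype)      # torch.float32
print(upcast_for_reduction(torch.ones(4, dtype=torch.bfloat16)).dtype)  # torch.float32
print(upcast_for_reduction(torch.ones(4, dtype=torch.double)).dtype)    # torch.float64
```

This avoids both failure modes: complex and double tensors are returned as-is, while half and bfloat16 get the extra mantissa bits for the reduction.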
Force-pushed from 0cb9b84 to fd44b4b
|
@eqy pinging again |
|
failing tests seem unrelated @eqy @mikaylagawarecki |
|
@eqy any updates on this? |
|
@eqy @mikaylagawarecki |
|
@pytorchmergebot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR is a replacement for pytorch#133085, pushing a quick fix for RMSNorm. The original author is @kkontny.

Previous PR summary: Since FP16 has quite a small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and it happens in real-world computation. I tried to use the fused `nn.RMSNorm` implementation instead of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in FP16 while still giving good results in FP32. I figured out this happens due to overflow while computing the square of the input tensor. The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability:

```python
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.

Pull Request resolved: pytorch#134106
Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
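The overflow described above is easy to reproduce: FP16 tops out around 65504, so squaring any value above roughly 256 already produces `inf` before the mean is taken. A minimal demonstration (not the PR's test, just an illustration):

```python
import torch

# fp16 max is ~65504; 300**2 = 90000 overflows when squared in fp16.
x = torch.full((8,), 300.0, dtype=torch.half)

print(x.pow(2))          # inf in every element when squared in fp16
print(x.float().pow(2))  # 90000.0, finite when upcast to fp32 first
```

This is exactly why `LlamaRMSNorm` upcasts before `pow(2).mean(...)`: the intermediate square can exceed FP16 range even when the final normalized output fits comfortably.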
Fixes #134106. This PR moves the `upcasted_result` down-casting after all computation is done. Since the multiplication with the `weight_opt` input is not done in half precision, the current code path does the following: fp16 -> fp32 -> fp16 -> fp32 -> fp16. What we want, though, is to avoid the intermediate down-cast, and this PR proposes: fp16 -> fp32 -> fp16. This results in better accuracy as it avoids truncation. Pull Request resolved: #147203. Approved by: https://github.com/eqy
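The fp16 -> fp32 -> fp16 chain can be sketched as follows. This is an illustrative reference implementation under the PR's description, not the actual ATen kernel; the point is that the weight multiply stays in fp32 and there is a single down-cast at the very end:

```python
import torch

def rms_norm_sketch(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Single upcast: all arithmetic, including the weight multiply,
    # happens in fp32 (fp16 -> fp32).
    upcast = x.float()
    rms = torch.rsqrt(upcast.pow(2).mean(-1, keepdim=True) + eps)
    out = upcast * rms * weight.float()
    # Single down-cast at the end (fp32 -> fp16), instead of the old
    # fp16 -> fp32 -> fp16 -> fp32 -> fp16 round trip.
    return out.to(x.dtype)

x = torch.randn(2, 8, dtype=torch.half)
w = torch.ones(8, dtype=torch.half)
print(rms_norm_sketch(x, w).dtype)  # torch.float16
```

Dropping the intermediate fp16 round trip avoids one truncation of the mantissa, which is where the accuracy gain comes from.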
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec