Skip to content

Commit 7fd4753

Browse files
riccardofellugapytorchmergebot
authored andcommitted
keep rms computation in full precision
1 parent 81847d0 commit 7fd4753

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

aten/src/ATen/native/layer_norm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,16 +315,16 @@ Tensor rms_norm_symint(
315315
rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
316316
}
317317

318-
Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
318+
Tensor upcasted_result = upcasted_input.mul(rqrst_input);
319319

320320
if (weight_opt.has_value()) {
321-
result = result.mul(weight_opt.value());
321+
upcasted_result = upcasted_result.mul(weight_opt.value());
322322
}
323323

324-
return result;
324+
return upcasted_result;
325325
});
326326

327-
return result;
327+
return result.type_as(input);
328328

329329
}
330330
} // namespace at::native

torch/testing/_internal/common_modules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,10 +1945,9 @@ def rms_norm_reference_fn(m, p, i):
19451945
dims = [ndim - i - 1 for i in range(len(normalized_shape))]
19461946
upcasted_i = i.float()
19471947
result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
1948-
result = result.type_as(i)
19491948
if weight is not None:
19501949
result *= weight
1951-
return result
1950+
return result.type_as(i)
19521951

19531952
return [
19541953
ModuleInput(

0 commit comments

Comments
 (0)