Skip to content

Commit c79577d

Browse files
committed
Fixing Pytorch RMS norm implementation
1 parent b7bcfda commit c79577d

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

aten/src/ATen/native/layer_norm.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -294,8 +294,13 @@ Tensor rms_norm(
294294
} else {
295295
eps_val = eps.value();
296296
}
297+
at::Tensor result;
297298

298-
auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));
299+
if constexpr (std::is_same_v<scalar_t, c10::Half>) {
300+
result = input.mul(at::rsqrt(at::pow(input.to(at::ScalarType::Float), 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)).to(at::ScalarType::Half));
301+
} else {
302+
result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));
303+
}
299304

300305
if (weight_opt.has_value()) {
301306
result = result.mul(weight_opt.value());

0 commit comments

Comments
 (0)