Skip to content

Commit c8c3925

Browse files
Update
[ghstack-poisoned]
1 parent 0b437fc commit c8c3925

1 file changed

Lines changed: 5 additions & 10 deletions

File tree

aten/src/ATen/native/mps/kernels/ActivationKernel.metal

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,21 @@ REGISTER_BINARY_ALPHA_OP(hardshrink_backward, bfloat, bfloat, bfloat);
3333
struct hardsigmoid_functor {
3434
template <typename T>
3535
inline T operator()(const T x) {
36-
T zero(0);
37-
T three(3);
38-
T six(6);
39-
T result = min(max(x + three, zero), six) / six;
40-
return result;
36+
return static_cast<T>(min(max(x + 3.0f, .0f), 6.f) / 6.f);
4137
}
4238
};
4339

4440
struct hardsigmoid_backward_functor {
4541
template <typename T>
4642
inline T operator()(const T grad_output, const T self) {
47-
T zero(0);
48-
T one_sixth(T(1.0 / 6.0));
49-
T neg_three(-3);
50-
T three(3);
43+
constexpr T zero(0);
44+
constexpr T neg_three(-3);
45+
constexpr T three(3);
5146

5247
if (self < neg_three || self > three) {
5348
return zero;
5449
} else {
55-
return grad_output * one_sixth;
50+
return static_cast<T>(grad_output * (1.0f / 6.0f));
5651
}
5752
}
5853
};

0 commit comments

Comments
 (0)