Skip to content

Commit 525ae01

Browse files
Update
[ghstack-poisoned]
1 parent 5314497 commit 525ae01

1 file changed

Lines changed: 19 additions & 19 deletions

File tree

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -402,25 +402,25 @@ kernel void round_decimals_strided(
402402
rint(exp10(float(ndigits)) * input[input_offs]) * exp10(float(-ndigits)));
403403
}
404404

405-
#define INSTANTIATE_ROUND_DECIMALS(DTYPE) \
406-
template \
407-
[[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE)]] kernel void \
408-
round_decimals_dense( \
409-
device DTYPE* output [[buffer(0)]], \
410-
constant DTYPE* input [[buffer(1)]], \
411-
constant long& ndigits [[buffer(2)]], \
412-
uint index [[thread_position_in_grid]]); \
413-
template \
414-
[[host_name("round_decimals_strided_" #DTYPE "_" #DTYPE)]] kernel void \
415-
round_decimals_strided( \
416-
device DTYPE* output [[buffer(0)]], \
417-
constant DTYPE* input [[buffer(1)]], \
418-
constant long* sizes, \
419-
constant long* input_strides, \
420-
constant long* output_strides, \
421-
constant uint& ndim, \
422-
constant long& ndigits [[buffer(6)]], \
423-
uint index)
405+
#define INSTANTIATE_ROUND_DECIMALS(DTYPE) \
406+
template [[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE \
407+
"_long")]] kernel void \
408+
round_decimals_dense( \
409+
device DTYPE* output [[buffer(0)]], \
410+
constant DTYPE* input [[buffer(1)]], \
411+
constant long& ndigits [[buffer(2)]], \
412+
uint index [[thread_position_in_grid]]); \
413+
template [[host_name("round_decimals_strided_" #DTYPE "_" #DTYPE \
414+
"_long")]] kernel void \
415+
round_decimals_strided( \
416+
device DTYPE* output [[buffer(0)]], \
417+
constant DTYPE* input [[buffer(1)]], \
418+
constant long* sizes, \
419+
constant long* input_strides, \
420+
constant long* output_strides, \
421+
constant uint& ndim, \
422+
constant long& ndigits [[buffer(6)]], \
423+
uint index)
424424

425425
INSTANTIATE_ROUND_DECIMALS(float);
426426
INSTANTIATE_ROUND_DECIMALS(half);

0 commit comments

Comments
 (0)