@@ -402,25 +402,25 @@ kernel void round_decimals_strided(
402402 rint (exp10 (float (ndigits)) * input[input_offs]) * exp10 (float (-ndigits)));
403403}
404404
405- #define INSTANTIATE_ROUND_DECIMALS (DTYPE ) \
406- template \
407- [[host_name( " round_decimals_dense_ " #DTYPE " _ " #DTYPE )]] kernel void \
408- round_decimals_dense ( \
409- device DTYPE* output [[buffer(0 )]], \
410- constant DTYPE* input [[buffer(1 )]], \
411- constant long& ndigits [[buffer(2 )]], \
412- uint index [[thread_position_in_grid]]); \
413- template \
414- [[host_name( " round_decimals_strided_ " #DTYPE " _ " #DTYPE )]] kernel void \
415- round_decimals_strided ( \
416- device DTYPE* output [[buffer(0 )]], \
417- constant DTYPE* input [[buffer(1 )]], \
418- constant long* sizes, \
419- constant long* input_strides, \
420- constant long* output_strides, \
421- constant uint& ndim, \
422- constant long& ndigits [[buffer(6 )]], \
423- uint index)
405+ #define INSTANTIATE_ROUND_DECIMALS (DTYPE ) \
406+ template [[host_name( " round_decimals_dense_ " #DTYPE " _ " #DTYPE \
407+ " _long " )]] kernel void \
408+ round_decimals_dense ( \
409+ device DTYPE* output [[buffer(0 )]], \
410+ constant DTYPE* input [[buffer(1 )]], \
411+ constant long& ndigits [[buffer(2 )]], \
412+ uint index [[thread_position_in_grid]]); \
413+ template [[host_name( " round_decimals_strided_ " #DTYPE " _ " #DTYPE \
414+ " _long " )]] kernel void \
415+ round_decimals_strided ( \
416+ device DTYPE* output [[buffer(0 )]], \
417+ constant DTYPE* input [[buffer(1 )]], \
418+ constant long* sizes, \
419+ constant long* input_strides, \
420+ constant long* output_strides, \
421+ constant uint& ndim, \
422+ constant long& ndigits [[buffer(6 )]], \
423+ uint index)
424424
425425INSTANTIATE_ROUND_DECIMALS(float );
426426INSTANTIATE_ROUND_DECIMALS (half);
0 commit comments