Skip to content

Commit 5ac55f9

Browse files
committed
Update on "[MPS][BE][EZ] Aggregate macros"
Refactor `INSTANTIATE_UPSAMPLE_BILINEAR2D(DTYPE)`, `INSTANTIATE_UPSAMPLE_BICUBIC2D(DTYPE)` and `INSTANTIATE_UPSAMPLE_BILINEAR2DAA(DTYPE)` use common `INSTANTIATE_UPSAMPLE2D` Then combine multiple invocations into `INSTANTIATE_UPSAMPLE_ALL` I.e. functionally it's a no-op, but achieves the same with fewer lines of code [ghstack-poisoned]
2 parents f88f025 + 01ebaa8 commit 5ac55f9

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

aten/src/ATen/native/mps/kernels/UpSample.metal

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -456,17 +456,17 @@ kernel void upsample_bicubic2d_backward(
456456
constant bool& align_corners [[buffer(7)]], \
457457
uint thread_index [[thread_position_in_grid]])
458458

459-
#define INSTANTIATE_UPSAMPLE_2D_BACKWARD(NAME, DTYPE) \
460-
template [[host_name("upsample" #NAME "_backward_" #DTYPE)]] kernel void \
461-
upsample_##NAME##_backward<DTYPE>( \
462-
device AtomicType_t<DTYPE> * gradInputData [[buffer(0)]], \
463-
constant DTYPE * gradOutputData [[buffer(1)]], \
464-
constant ulong4 & input_strides [[buffer(2)]], \
465-
constant ulong4 & output_strides [[buffer(3)]], \
466-
constant long4 & input_sizes [[buffer(4)]], \
467-
constant long4 & output_sizes [[buffer(5)]], \
468-
constant float2 & scales [[buffer(6)]], \
469-
constant bool& align_corners [[buffer(7)]], \
459+
#define INSTANTIATE_UPSAMPLE_2D_BACKWARD(NAME, DTYPE) \
460+
template [[host_name("upsample_" #NAME "_backward_" #DTYPE)]] kernel void \
461+
upsample_##NAME##_backward<DTYPE>( \
462+
device AtomicType_t<DTYPE> * gradInputData [[buffer(0)]], \
463+
constant DTYPE * gradOutputData [[buffer(1)]], \
464+
constant ulong4 & input_strides [[buffer(2)]], \
465+
constant ulong4 & output_strides [[buffer(3)]], \
466+
constant long4 & input_sizes [[buffer(4)]], \
467+
constant long4 & output_sizes [[buffer(5)]], \
468+
constant float2 & scales [[buffer(6)]], \
469+
constant bool& align_corners [[buffer(7)]], \
470470
uint thread_index [[thread_position_in_grid]])
471471

472472
#define INSTANTIATE_UPSAMPLE_LINEAR(DTYPE) \

0 commit comments

Comments
 (0)