Skip to content

Commit 047c823

Browse files
committed
Fix failling MacOS 12 test
1 parent be120dd commit 047c823

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

aten/src/ATen/native/mps/operations/Convolution.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,12 +548,12 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
548548
key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) + ":" +
549549
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" +
550550
to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key +
551-
getTensorsStringKey({grad_output_t, input_t}) + ":" + string([ns_shape_key UTF8String]);
551+
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
552552
} else {
553553
key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
554554
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
555555
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
556-
getTensorsStringKey({grad_output_t, input_t}) + ":" + string([ns_shape_key UTF8String]);
556+
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
557557
}
558558
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
559559

0 commit comments

Comments
 (0)