File tree Expand file tree Collapse file tree
aten/src/ATen/native/mps/operations Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -548,12 +548,12 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
548548 key = " mps_3d_convolution_backward_weights:" + to_string (stride[0 ]) + " :" + to_string (stride[1 ]) + " :" + to_string (stride[2 ]) + " :" +
549549 to_string (dilation[0 ]) + " :" + to_string (dilation[1 ]) + " :" + to_string (dilation[2 ]) + " :" + to_string (padding[0 ]) + " :" +
550550 to_string (padding[1 ]) + " :" + to_string (padding[2 ]) + " :" + to_string (groups) + " :" + mem_format_key +
551- getTensorsStringKey ({grad_output_t , input_t }) + " :" + string ([ns_shape_key UTF8String ]);
551+ getTensorsStringKey ({grad_output_t , input_t , grad_weight_t }) + " :" + string ([ns_shape_key UTF8String ]);
552552 } else {
553553 key = " mps_convolution_backward_weights:" + to_string (stride[0 ]) + " :" + to_string (stride[1 ]) + " :" +
554554 to_string (dilation[0 ]) + " :" + to_string (dilation[1 ]) + " :" + to_string (padding[0 ]) + " :" +
555555 to_string (padding[1 ]) + " :" + to_string (groups) + " :" + mem_format_key +
556- getTensorsStringKey ({grad_output_t , input_t }) + " :" + string ([ns_shape_key UTF8String ]);
556+ getTensorsStringKey ({grad_output_t , input_t , grad_weight_t }) + " :" + string ([ns_shape_key UTF8String ]);
557557 }
558558 auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
559559
You can’t perform that action at this time.
0 commit comments