@@ -204,6 +204,20 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
204204 }
205205}
206206
207+ Tensor& sub__Tensor (Tensor& input1, const Tensor& input2, Scalar alpha) {
208+ TORCH_CHECK (input1.is_metal ());
209+ TORCH_CHECK (input1.dim () == input2.dim ());
210+ TORCH_CHECK (input1.sizes ()[0 ] == input2.sizes ()[0 ]);
211+ TORCH_CHECK (input1.sizes ()[1 ] == input2.sizes ()[1 ]);
212+ auto input2_ = input2.is_metal () ? input2 : input2.metal ();
213+ if (@available (iOS 11.3 , *)) {
214+ return binaryElementwiseMPSCNNKernel_<MPSCNNSubtract>(input1, input2_);
215+ } else {
216+ return binaryElementwiseShaderKernel_ (
217+ input1, input2_, @" elementwise_sub" , @" elementwise_sub_nonarray" );
218+ }
219+ }
220+
207221Tensor mul_Tensor (const Tensor& input1, const Tensor& input2) {
208222 TORCH_CHECK (input1.is_metal ());
209223 TORCH_CHECK (input1.dim () == input2.dim ());
@@ -218,11 +232,57 @@ Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
218232 }
219233}
220234
235+ Tensor& mul__Tensor (Tensor& input1, const Tensor& input2) {
236+ TORCH_CHECK (input1.is_metal ());
237+ TORCH_CHECK (input1.dim () == input2.dim ());
238+ TORCH_CHECK (input1.sizes ()[0 ] == input2.sizes ()[0 ]);
239+ TORCH_CHECK (input1.sizes ()[1 ] == input2.sizes ()[1 ]);
240+ auto input2_ = input2.is_metal () ? input2 : input2.metal ();
241+ if (@available (iOS 11.3 , *)) {
242+ return binaryElementwiseMPSCNNKernel_<MPSCNNMultiply>(input1, input2_);
243+ } else {
244+ return binaryElementwiseShaderKernel_ (
245+ input1, input2_, @" elementwise_mul" , @" elementwise_mul_nonarray" );
246+ }
247+ }
248+
249+ Tensor div_Tensor (const Tensor& input1, const Tensor& input2) {
250+ TORCH_CHECK (input1.is_metal ());
251+ TORCH_CHECK (input1.dim () == input2.dim ());
252+ TORCH_CHECK (input1.sizes ()[0 ] == input2.sizes ()[0 ]);
253+ TORCH_CHECK (input1.sizes ()[1 ] == input2.sizes ()[1 ]);
254+ auto input2_ = input2.is_metal () ? input2 : input2.metal ();
255+ if (@available (iOS 11.3 , *)) {
256+ return binaryElementwiseMPSCNNKernel<MPSCNNDivide>(input1, input2_);
257+ } else {
258+ return binaryElementwiseShaderKernel (
259+ input1, input2_, @" elementwise_div" , @" elementwise_div_nonarray" );
260+ }
261+ }
262+
263+ Tensor& div__Tensor (Tensor& input1, const Tensor& input2) {
264+ TORCH_CHECK (input1.is_metal ());
265+ TORCH_CHECK (input1.dim () == input2.dim ());
266+ TORCH_CHECK (input1.sizes ()[0 ] == input2.sizes ()[0 ]);
267+ TORCH_CHECK (input1.sizes ()[1 ] == input2.sizes ()[1 ]);
268+ auto input2_ = input2.is_metal () ? input2 : input2.metal ();
269+ if (@available (iOS 11.3 , *)) {
270+ return binaryElementwiseMPSCNNKernel_<MPSCNNDivide>(input1, input2_);
271+ } else {
272+ return binaryElementwiseShaderKernel_ (
273+ input1, input2_, @" elementwise_div" , @" elementwise_div_nonarray" );
274+ }
275+ }
276+
221277TORCH_LIBRARY_IMPL (aten, Metal, m) {
222278 m.impl (" add.Tensor" , TORCH_FN (add_Tensor));
223279 m.impl (" add_.Tensor" , TORCH_FN (add__Tensor));
224280 m.impl (" mul.Tensor" , TORCH_FN (mul_Tensor));
281+ m.impl (" mul_.Tensor" , TORCH_FN (mul__Tensor));
225282 m.impl (" sub.Tensor" , TORCH_FN (sub_Tensor));
283+ m.impl (" sub_.Tensor" , TORCH_FN (sub__Tensor));
284+ m.impl (" div.Tensor" , TORCH_FN (div_Tensor));
285+ m.impl (" div_.Tensor" , TORCH_FN (div__Tensor));
226286};
227287
228288}
0 commit comments