Skip to content

Commit bb21aea

Browse files
xta0facebook-github-bot
authored andcommitted
[iOS GPU] Add the reset of binary ops (#53950)
Summary: Pull Request resolved: #53950 Add four binary ops to Metal - `aten::mul_` - `aten::sub_` - `aten::div` - `aten::div_` ghstack-source-id: 123850577 Test Plan: - `buck test pp-mac` ``` 2021-03-11 20:36:47.151139-0800 PyTorchPlayground[8469:5169786] [bool test_sub()],[5 3 167 222 ],[SUCCEED] 2021-03-11 20:36:47.157638-0800 PyTorchPlayground[8469:5169786] [bool test_sub_broadcast()],[1 3 1 1 ],[SUCCEED] 2021-03-11 20:36:47.170640-0800 PyTorchPlayground[8469:5169786] [bool test_sub_broadcast2()],[3 3 192 192 ],[SUCCEED] 2021-03-11 20:36:47.194009-0800 PyTorchPlayground[8469:5169786] [bool test_mul()],[2 7 262 119 ],[SUCCEED] 2021-03-11 20:36:47.210344-0800 PyTorchPlayground[8469:5169786] [bool test_mul_broadcast()],[4 3 192 192 ],[SUCCEED] 2021-03-11 20:36:47.216610-0800 PyTorchPlayground[8469:5169786] [bool test_mul_broadcast2()],[1 3 192 192 ],[SUCCEED] 2021-03-11 20:36:47.224471-0800 PyTorchPlayground[8469:5169786] [bool test_div()],[1 3 192 192 ],[SUCCEED] 2021-03-11 20:36:47.240817-0800 PyTorchPlayground[8469:5169786] [bool test_div_broadcast()],[4 3 192 192 ],[SUCCEED] 2021-03-11 20:36:47.246816-0800 PyTorchPlayground[8469:5169786] [bool test_div_broadcast2()],[1 3 192 192 ],[SUCCEED] ``` Reviewed By: SS-JIA Differential Revision: D27003417 fbshipit-source-id: 290f7e524eef4c444f8884fc1315151752e5ac31
1 parent 530dc82 commit bb21aea

4 files changed

Lines changed: 118 additions & 0 deletions

File tree

aten/src/ATen/native/metal/MetalShaders.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]
120120
elementwise_broadcast(in0, in1, out, gid, Mul);
121121
}
122122
123+
kernel void elementwise_div_nonarray(texture2d<half, access::read> in0[[texture(0)]],
124+
texture2d<half, access::read> in1[[texture(1)]],
125+
texture2d<half, access::write> out[[texture(2)]],
126+
ushort2 gid[[thread_position_in_grid]]) {
127+
elementwise_broadcast_nonarray(in0, in1, out, gid, Div);
128+
}
129+
130+
kernel void elementwise_div(texture2d_array<half, access::read> in0[[texture(0)]],
131+
texture2d_array<half, access::read> in1[[texture(1)]],
132+
texture2d_array<half, access::write> out[[texture(2)]],
133+
ushort3 gid[[thread_position_in_grid]]) {
134+
elementwise_broadcast(in0, in1, out, gid, Div);
135+
}
136+
123137
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
124138
texture2d_array<half, access::write> out[[texture(0)]],
125139
ushort3 gid[[thread_position_in_grid]]) {

aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ bool test_sub_broadcast2();
2020
bool test_mul();
2121
bool test_mul_broadcast();
2222
bool test_mul_broadcast2();
23+
bool test_div();
24+
bool test_div_broadcast();
25+
bool test_div_broadcast2();
2326
bool test_t();
2427
bool test_view();
2528
bool test_cat_dim0();

aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,47 @@ bool test_mul_broadcast2() {
440440
});
441441
}
442442

443+
bool test_div() {
444+
__block std::vector<int64_t> x{1, 3, 24, 24};
445+
return TEST(x, __PRETTY_FUNCTION__, ^bool {
446+
auto X1 = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat));
447+
auto X2 = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat));
448+
auto Y1 = at::div(X1, X2);
449+
auto MX1 = X1.metal();
450+
auto MX2 = X2.metal();
451+
auto Y2 = at::div(MX1, MX2).cpu();
452+
return almostEqual(Y1, Y2);
453+
});
454+
}
455+
456+
bool test_div_broadcast() {
457+
__block std::vector<int64_t> x1{4, 3, 24, 24};
458+
__block std::vector<int64_t> x2{4, 3, 1, 1};
459+
return TEST(x1, __PRETTY_FUNCTION__, ^bool {
460+
auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat));
461+
auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat));
462+
auto Y1 = at::div(X1, X2);
463+
auto MX1 = X1.metal();
464+
auto MX2 = X2.metal();
465+
auto Y2 = at::div(MX1, MX2).cpu();
466+
return almostEqual(Y1, Y2);
467+
});
468+
}
469+
470+
bool test_div_broadcast2() {
471+
__block std::vector<int64_t> x2{1, 3, 24, 1};
472+
__block std::vector<int64_t> x1{1, 3, 24, 24};
473+
return TEST(x1, __PRETTY_FUNCTION__, ^bool {
474+
auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat));
475+
auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat));
476+
auto Y1 = at::div(X1, X2);
477+
auto MX1 = X1.metal();
478+
auto MX2 = X2.metal();
479+
auto Y2 = at::div(MX1, MX2).cpu();
480+
return almostEqual(Y1, Y2);
481+
});
482+
}
483+
443484
bool test_t() {
444485
bool result = true;
445486
for (int i = 0; i < ITER_COUNT; ++i) {

aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,20 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
204204
}
205205
}
206206

207+
Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) {
208+
TORCH_CHECK(input1.is_metal());
209+
TORCH_CHECK(input1.dim() == input2.dim());
210+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
211+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
212+
auto input2_ = input2.is_metal() ? input2 : input2.metal();
213+
if (@available(iOS 11.3, *)) {
214+
return binaryElementwiseMPSCNNKernel_<MPSCNNSubtract>(input1, input2_);
215+
} else {
216+
return binaryElementwiseShaderKernel_(
217+
input1, input2_, @"elementwise_sub", @"elementwise_sub_nonarray");
218+
}
219+
}
220+
207221
Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
208222
TORCH_CHECK(input1.is_metal());
209223
TORCH_CHECK(input1.dim() == input2.dim());
@@ -218,11 +232,57 @@ Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
218232
}
219233
}
220234

235+
Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) {
236+
TORCH_CHECK(input1.is_metal());
237+
TORCH_CHECK(input1.dim() == input2.dim());
238+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
239+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
240+
auto input2_ = input2.is_metal() ? input2 : input2.metal();
241+
if (@available(iOS 11.3, *)) {
242+
return binaryElementwiseMPSCNNKernel_<MPSCNNMultiply>(input1, input2_);
243+
} else {
244+
return binaryElementwiseShaderKernel_(
245+
input1, input2_, @"elementwise_mul", @"elementwise_mul_nonarray");
246+
}
247+
}
248+
249+
Tensor div_Tensor(const Tensor& input1, const Tensor& input2) {
250+
TORCH_CHECK(input1.is_metal());
251+
TORCH_CHECK(input1.dim() == input2.dim());
252+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
253+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
254+
auto input2_ = input2.is_metal() ? input2 : input2.metal();
255+
if (@available(iOS 11.3, *)) {
256+
return binaryElementwiseMPSCNNKernel<MPSCNNDivide>(input1, input2_);
257+
} else {
258+
return binaryElementwiseShaderKernel(
259+
input1, input2_, @"elementwise_div", @"elementwise_div_nonarray");
260+
}
261+
}
262+
263+
Tensor& div__Tensor(Tensor& input1, const Tensor& input2) {
264+
TORCH_CHECK(input1.is_metal());
265+
TORCH_CHECK(input1.dim() == input2.dim());
266+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
267+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
268+
auto input2_ = input2.is_metal() ? input2 : input2.metal();
269+
if (@available(iOS 11.3, *)) {
270+
return binaryElementwiseMPSCNNKernel_<MPSCNNDivide>(input1, input2_);
271+
} else {
272+
return binaryElementwiseShaderKernel_(
273+
input1, input2_, @"elementwise_div", @"elementwise_div_nonarray");
274+
}
275+
}
276+
221277
TORCH_LIBRARY_IMPL(aten, Metal, m) {
222278
m.impl("add.Tensor", TORCH_FN(add_Tensor));
223279
m.impl("add_.Tensor", TORCH_FN(add__Tensor));
224280
m.impl("mul.Tensor", TORCH_FN(mul_Tensor));
281+
m.impl("mul_.Tensor", TORCH_FN(mul__Tensor));
225282
m.impl("sub.Tensor", TORCH_FN(sub_Tensor));
283+
m.impl("sub_.Tensor", TORCH_FN(sub__Tensor));
284+
m.impl("div.Tensor", TORCH_FN(div_Tensor));
285+
m.impl("div_.Tensor", TORCH_FN(div__Tensor));
226286
};
227287

228288
}

0 commit comments

Comments
 (0)