Skip to content

Commit 530dc82

Browse files
xta0facebook-github-bot
authored andcommitted
[iOS GPU] Support element-wise broadcasting for binary ops in shaders (#53949)
Summary: Pull Request resolved: #53949 As title says ghstack-source-id: 123849745 Test Plan: `buck test pp-mac` ``` 2021-03-11 18:25:07.922375-0800 PyTorchPlayground[8324:5122672] [bool test_add()],[1 180 12 12 ],[SUCCEED] 2021-03-11 18:25:07.960812-0800 PyTorchPlayground[8324:5122672] [bool test_add_broadcast()],[2 17 58 67 ],[SUCCEED] 2021-03-11 18:25:07.978399-0800 PyTorchPlayground[8324:5122672] [bool test_add_broadcast2()],[2 17 1 67 ],[SUCCEED] 2021-03-11 18:25:08.021570-0800 PyTorchPlayground[8324:5122672] [bool test_sub()],[5 3 167 222 ],[SUCCEED] 2021-03-11 18:25:08.034218-0800 PyTorchPlayground[8324:5122672] [bool test_sub_broadcast()],[1 3 1 1 ],[SUCCEED] 2021-03-11 18:25:08.069419-0800 PyTorchPlayground[8324:5122672] [bool test_sub_broadcast2()],[3 3 192 192 ],[SUCCEED] 2021-03-11 18:25:08.112967-0800 PyTorchPlayground[8324:5122672] [bool test_mul()],[2 7 262 119 ],[SUCCEED] 2021-03-11 18:25:08.136691-0800 PyTorchPlayground[8324:5122672] [bool test_mul_broadcast()],[4 3 192 192 ],[SUCCEED] 2021-03-11 18:25:08.148920-0800 PyTorchPlayground[8324:5122672] [bool test_mul_broadcast2()],[1 3 192 192 ],[SUCCEED] ``` Reviewed By: SS-JIA Differential Revision: D27000487 fbshipit-source-id: f86fca5ac1960ca0a56636da17ae05020c1a4138
1 parent df7c0a0 commit 530dc82

4 files changed

Lines changed: 113 additions & 47 deletions

File tree

aten/src/ATen/native/metal/MetalShaders.h

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,73 +18,106 @@ constant ushort ushort_arg_9[[function_constant(9)]];
1818
constant float float_arg_0 [[function_constant(10)]];
1919
constant float float_arg_1 [[function_constant(11)]];
2020
21-
2221
inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
2322
23+
enum broadcastOp {
24+
Add,
25+
Sub,
26+
Mul,
27+
Div,
28+
};
29+
30+
void elementwise_broadcast_nonarray(texture2d<half, access::read> in0,
31+
texture2d<half, access::read> in1,
32+
texture2d<half, access::write> out,
33+
ushort2 gid,
34+
broadcastOp op) {
35+
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
36+
return;
37+
}
38+
ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
39+
ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);
40+
41+
ushort2 gid0 = gid.xy * in0_stride;
42+
ushort2 gid1 = gid.xy * in1_stride;
43+
44+
if(op == Add) {
45+
out.write(in0.read(gid0) + in1.read(gid1), gid);
46+
} else if(op == Sub) {
47+
out.write(in0.read(gid0) - in1.read(gid1), gid);
48+
} else if(op == Mul) {
49+
out.write(in0.read(gid0) * in1.read(gid1), gid);
50+
} else if(op == Div) {
51+
out.write(in0.read(gid0) / in1.read(gid1), gid);
52+
}
53+
}
54+
55+
void elementwise_broadcast(texture2d_array<half, access::read> in0,
56+
texture2d_array<half, access::read> in1,
57+
texture2d_array<half, access::write> out,
58+
ushort3 gid,
59+
broadcastOp op) {
60+
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
61+
return;
62+
}
63+
64+
ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
65+
ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);
66+
67+
ushort2 gid0 = gid.xy * in0_stride;
68+
ushort2 gid1 = gid.xy * in1_stride;
69+
70+
if(op == Add) {
71+
out.write(in0.read(gid0, gid.z) + in1.read(gid1, gid.z), gid.xy, gid.z);
72+
} else if(op == Sub) {
73+
out.write(in0.read(gid0, gid.z) - in1.read(gid1, gid.z), gid.xy, gid.z);
74+
} else if(op == Mul) {
75+
out.write(in0.read(gid0, gid.z) * in1.read(gid1, gid.z), gid.xy, gid.z);
76+
} else if(op == Div) {
77+
out.write(in0.read(gid0, gid.z) / in1.read(gid1, gid.z), gid.xy, gid.z);
78+
}
79+
}
80+
2481
kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
2582
texture2d<half, access::read> in1[[texture(1)]],
2683
texture2d<half, access::write> out[[texture(2)]],
2784
ushort2 gid[[thread_position_in_grid]]) {
28-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
29-
return;
30-
}
31-
out.write(in0.read(gid) + in1.read(gid), gid);
85+
elementwise_broadcast_nonarray(in0, in1, out, gid, Add);
3286
}
3387
3488
kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
3589
texture2d_array<half, access::read> in1[[texture(1)]],
3690
texture2d_array<half, access::write> out[[texture(2)]],
3791
ushort3 gid[[thread_position_in_grid]]) {
38-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
39-
return;
40-
}
41-
ushort2 gid_ = gid.xy;
42-
out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
92+
elementwise_broadcast(in0, in1, out, gid, Add);
4393
}
4494
4595
kernel void elementwise_sub_nonarray(texture2d<half, access::read> in0[[texture(0)]],
4696
texture2d<half, access::read> in1[[texture(1)]],
4797
texture2d<half, access::write> out[[texture(2)]],
4898
ushort2 gid[[thread_position_in_grid]]) {
49-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
50-
return;
51-
}
52-
ushort2 gid2{0,0};
53-
out.write(in0.read(gid) - in1.read(gid2), gid);
99+
elementwise_broadcast_nonarray(in0, in1, out, gid, Sub);
54100
}
55101
56102
kernel void elementwise_sub(texture2d_array<half, access::read> in0[[texture(0)]],
57103
texture2d_array<half, access::read> in1[[texture(1)]],
58104
texture2d_array<half, access::write> out[[texture(2)]],
59105
ushort3 gid[[thread_position_in_grid]]) {
60-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
61-
return;
62-
}
63-
ushort2 gid1 = gid.xy;
64-
ushort2 gid2{0,0};
65-
out.write(in0.read(gid1, gid.z) - in1.read(gid2, gid.z), gid1, gid.z);
106+
elementwise_broadcast(in0, in1, out, gid, Sub);
66107
}
108+
67109
kernel void elementwise_mul_nonarray(texture2d<half, access::read> in0[[texture(0)]],
68110
texture2d<half, access::read> in1[[texture(1)]],
69111
texture2d<half, access::write> out[[texture(2)]],
70112
ushort2 gid[[thread_position_in_grid]]) {
71-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
72-
return;
73-
}
74-
ushort2 gid2{0,0};
75-
out.write(in0.read(gid) * in1.read(gid2), gid);
113+
elementwise_broadcast_nonarray(in0, in1, out, gid, Mul);
76114
}
77115
78116
kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]],
79117
texture2d_array<half, access::read> in1[[texture(1)]],
80118
texture2d_array<half, access::write> out[[texture(2)]],
81119
ushort3 gid[[thread_position_in_grid]]) {
82-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
83-
return;
84-
}
85-
ushort2 gid1 = gid.xy;
86-
ushort2 gid2{0,0};
87-
out.write(in0.read(gid1, gid.z) * in1.read(gid2, gid.z), gid1, gid.z);
120+
elementwise_broadcast(in0, in1, out, gid, Mul);
88121
}
89122
90123
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],

aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ bool test_relu();
1313
bool test_addmm();
1414
bool test_add();
1515
bool test_add_broadcast();
16+
bool test_add_broadcast2();
1617
bool test_sub();
1718
bool test_sub_broadcast();
1819
bool test_sub_broadcast2();

aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,28 @@ bool TEST(const std::vector<int64_t>& sizes, std::string name, Func block) {
7676
return b;
7777
}
7878

79+
void PRINT_TENSOR(std::string name, const at::Tensor& tensor){
80+
std::string str = name + ": ";
81+
auto print = [&](const at::Tensor& t){
82+
for(int i=0; i<t.numel(); ++i){
83+
NSString* sf = [NSString stringWithFormat:@"%.2f",t.data_ptr<float>()[i]];
84+
str += sf.UTF8String;
85+
str += ", ";
86+
}
87+
std::cout<<str<<std::endl;
88+
};
89+
if(tensor.is_metal()){
90+
MPSImage* image = at::native::metal::imageFromTensor(tensor);
91+
auto t = at::native::metal::staticImageToTensor(image);
92+
print(t);
93+
} else {
94+
print(tensor);
95+
}
96+
}
97+
7998
}
8099

81-
using namespace at::native::metal;
100+
using namespace at::native::metal;
82101

83102
bool test_synchronization() {
84103
__block std::vector<int64_t> size{1, 3, 2, 2};
@@ -324,6 +343,21 @@ bool test_add_broadcast() {
324343
});
325344
}
326345

346+
bool test_add_broadcast2() {
347+
__block std::vector<int64_t> x1{2, 17, 1, 67};
348+
__block std::vector<int64_t> x2{2, 17, 58, 67};
349+
return TEST(x1, __PRETTY_FUNCTION__, ^bool {
350+
auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat));
351+
auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat));
352+
auto Y1 = at::add(X1, X2);
353+
auto MX1 = X1.metal();
354+
auto MX2 = X2.metal();
355+
auto Y2 = at::add(MX1, MX2).cpu();
356+
return almostEqual(Y1, Y2);
357+
});
358+
}
359+
360+
327361
bool test_sub() {
328362
__block std::vector<int64_t> x{5, 3, 167, 222};
329363
return TEST(x, __PRETTY_FUNCTION__, ^bool {
@@ -338,8 +372,8 @@ bool test_sub() {
338372
}
339373

340374
bool test_sub_broadcast() {
341-
__block std::vector<int64_t> x1{3, 3, 1, 1};
342-
__block std::vector<int64_t> x2{3, 3, 192, 192};
375+
__block std::vector<int64_t> x1{1, 3, 1, 1};
376+
__block std::vector<int64_t> x2{1, 3, 192, 192};
343377
return TEST(x1, __PRETTY_FUNCTION__, ^bool {
344378
auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat));
345379
auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat));
@@ -393,8 +427,8 @@ bool test_mul_broadcast() {
393427
}
394428

395429
bool test_mul_broadcast2() {
396-
__block std::vector<int64_t> x1{4, 3, 192, 1};
397-
__block std::vector<int64_t> x2{4, 3, 192, 192};
430+
__block std::vector<int64_t> x2{1, 3, 192, 1};
431+
__block std::vector<int64_t> x1{1, 3, 192, 192};
398432
return TEST(x1, __PRETTY_FUNCTION__, ^bool {
399433
auto X1 = at::rand(x1, at::TensorOptions(at::kCPU).dtype(at::kFloat));
400434
auto X2 = at::rand(x2, at::TensorOptions(at::kCPU).dtype(at::kFloat));

aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,12 @@ Tensor binaryElementwiseMPSCNNKernel(
165165
Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
166166
TORCH_CHECK(input1.is_metal());
167167
TORCH_CHECK(input1.dim() == input2.dim());
168+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
169+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
168170
auto input2_ = input2.is_metal() ? input2 : input2.metal();
169171
if (@available(iOS 11.3, *)) {
170172
return binaryElementwiseMPSCNNKernel<MPSCNNAdd>(input1, input2_);
171173
} else {
172-
// TODO: support broadcast in shader functions for iOS 10 users
173-
TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
174-
TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
175174
return binaryElementwiseShaderKernel(
176175
input1, input2_, @"elementwise_add", @"elementwise_add_nonarray");
177176
}
@@ -180,13 +179,12 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
180179
Tensor& add__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) {
181180
TORCH_CHECK(input1.is_metal());
182181
TORCH_CHECK(input1.dim() == input2.dim());
182+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
183+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
183184
auto input2_ = input2.is_metal() ? input2 : input2.metal();
184185
if (@available(iOS 11.3, *)) {
185186
return binaryElementwiseMPSCNNKernel_<MPSCNNAdd>(input1, input2_);
186187
} else {
187-
// TODO: support broadcast in for iOS 10 users
188-
TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
189-
TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
190188
return binaryElementwiseShaderKernel_(
191189
input1, input2_, @"elementwise_add", @"elementwise_add_nonarray");
192190
}
@@ -195,12 +193,12 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
195193
Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
196194
TORCH_CHECK(input1.is_metal());
197195
TORCH_CHECK(input1.dim() == input2.dim());
196+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
197+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
198198
auto input2_ = input2.is_metal() ? input2 : input2.metal();
199199
if (@available(iOS 11.3, *)) {
200200
return binaryElementwiseMPSCNNKernel<MPSCNNSubtract>(input1, input2_);
201201
} else {
202-
// TODO: support non-broadcast for iOS 10 users
203-
TORCH_CHECK(input2.sizes()[2] == input2.sizes()[3] == 1);
204202
return binaryElementwiseShaderKernel(
205203
input1, input2_, @"elementwise_sub", @"elementwise_sub_nonarray");
206204
}
@@ -209,12 +207,12 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
209207
Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
210208
TORCH_CHECK(input1.is_metal());
211209
TORCH_CHECK(input1.dim() == input2.dim());
210+
TORCH_CHECK(input1.sizes()[0] == input2.sizes()[0]);
211+
TORCH_CHECK(input1.sizes()[1] == input2.sizes()[1]);
212212
auto input2_ = input2.is_metal() ? input2 : input2.metal();
213213
if (@available(iOS 11.3, *)) {
214214
return binaryElementwiseMPSCNNKernel<MPSCNNMultiply>(input1, input2_);
215215
} else {
216-
// TODO: support non-broadcast for iOS 10 users
217-
TORCH_CHECK(input2.sizes()[2] == input2.sizes()[3] == 1);
218216
return binaryElementwiseShaderKernel(
219217
input1, input2_, @"elementwise_mul", @"elementwise_mul_nonarray");
220218
}

0 commit comments

Comments
 (0)