@@ -18,73 +18,106 @@ constant ushort ushort_arg_9[[function_constant(9)]];
// Float function constants bound to function-constant indices 10 and 11.
// NOTE(review): what each value means depends on the kernel that reads it,
// which is not visible from these declarations — confirm against the callers.
constant float float_arg_0 [[function_constant(10)]];
constant float float_arg_1 [[function_constant(11)]];
2020
21-
// Integer ceiling division: returns x / y rounded up, computed as
// (x + y - 1) / y. Requires y > 0; y == 0 divides by zero and is not guarded.
inline constexpr ushort divRoundUp(ushort x, ushort y) {
    return (x + (y - 1)) / y;
}
2322
// Selects the binary arithmetic operation applied by the elementwise
// broadcast helpers. Kept as an unscoped enum because callers pass the
// enumerators unqualified (e.g. `Add`).
enum broadcastOp {
    Add,  // in0 + in1
    Sub,  // in0 - in1
    Mul,  // in0 * in1
    Div,  // in0 / in1
};
29+
// Applies `op` texel-wise to two 2D textures and writes the result to `out`
// at `gid`.
//
// Broadcasting: along each of x and y, an input whose extent is 1 is always
// read at index 0 on that axis; an input whose extent is greater than 1 is
// read at the thread's own coordinate.
//
// in0, in1  operand textures (left and right operand respectively)
// out       destination texture; its width/height bound the valid threads
// gid       position of this thread in the grid
// op        which arithmetic operation to apply
void elementwise_broadcast_nonarray(texture2d<half, access::read> in0,
                                    texture2d<half, access::read> in1,
                                    texture2d<half, access::write> out,
                                    ushort2 gid,
                                    broadcastOp op) {
    // Threads that fall outside the output do nothing.
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }

    // Per-axis multiplier: 1 keeps the coordinate, 0 pins it to index 0.
    const ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    const ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    const ushort2 gid0 = gid.xy * in0_stride;
    const ushort2 gid1 = gid.xy * in1_stride;

    const half4 lhs = in0.read(gid0);
    const half4 rhs = in1.read(gid1);

    half4 result;
    switch (op) {
        case Add:
            result = lhs + rhs;
            break;
        case Sub:
            result = lhs - rhs;
            break;
        case Mul:
            result = lhs * rhs;
            break;
        case Div:
            result = lhs / rhs;
            break;
        default:
            // Unrecognized op: write nothing.
            return;
    }
    out.write(result, gid);
}
54+
// Applies `op` texel-wise to two 2D texture arrays and writes the result to
// `out` at (gid.xy, slice gid.z).
//
// Broadcasting happens only along x and y: an input whose extent is 1 on an
// axis is always read at index 0 on that axis. The array slice is NOT
// broadcast — both inputs are read at slice gid.z.
//
// in0, in1  operand texture arrays (left and right operand respectively)
// out       destination texture array; its width/height/array size bound the
//           valid threads
// gid       position of this thread in the grid (z selects the array slice)
// op        which arithmetic operation to apply
void elementwise_broadcast(texture2d_array<half, access::read> in0,
                           texture2d_array<half, access::read> in1,
                           texture2d_array<half, access::write> out,
                           ushort3 gid,
                           broadcastOp op) {
    // Threads that fall outside the output (including past the last array
    // slice) do nothing.
    if (gid.x >= out.get_width() || gid.y >= out.get_height() ||
        gid.z >= out.get_array_size()) {
        return;
    }

    // Per-axis multiplier: 1 keeps the coordinate, 0 pins it to index 0.
    const ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    const ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    const ushort2 gid0 = gid.xy * in0_stride;
    const ushort2 gid1 = gid.xy * in1_stride;

    const half4 lhs = in0.read(gid0, gid.z);
    const half4 rhs = in1.read(gid1, gid.z);

    half4 result;
    switch (op) {
        case Add:
            result = lhs + rhs;
            break;
        case Sub:
            result = lhs - rhs;
            break;
        case Mul:
            result = lhs * rhs;
            break;
        case Div:
            result = lhs / rhs;
            break;
        default:
            // Unrecognized op: write nothing.
            return;
    }
    out.write(result, gid.xy, gid.z);
}
80+
// Compute kernel: out = in0 + in1 for 2D textures bound at texture slots
// 0, 1 and 2. Bounds checking and x/y broadcasting are handled by
// elementwise_broadcast_nonarray.
kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Add);
}
3387
// Compute kernel: out = in0 + in1 for 2D texture arrays bound at texture
// slots 0, 1 and 2. Bounds checking and x/y broadcasting are handled by
// elementwise_broadcast.
kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Add);
}
4494
// Compute kernel: out = in0 - in1 for 2D textures bound at texture slots
// 0, 1 and 2. Bounds checking and x/y broadcasting are handled by
// elementwise_broadcast_nonarray.
kernel void elementwise_sub_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Sub);
}
55101
// Compute kernel: out = in0 - in1 for 2D texture arrays bound at texture
// slots 0, 1 and 2. Bounds checking and x/y broadcasting are handled by
// elementwise_broadcast.
kernel void elementwise_sub(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Sub);
}
108+
// Compute kernel: out = in0 * in1 for 2D textures bound at texture slots
// 0, 1 and 2. Bounds checking and x/y broadcasting are handled by
// elementwise_broadcast_nonarray.
kernel void elementwise_mul_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Mul);
}
77115
// Compute kernel: out = in0 * in1 for 2D texture arrays bound at texture
// slots 0, 1 and 2. Bounds checking and x/y broadcasting are handled by
// elementwise_broadcast.
kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Mul);
}
89122
90123kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
0 commit comments