Skip to content

Commit cce84b5

Browse files
Iurii Zdebskyifacebook-github-bot
authored andcommitted
[WIP] Update foreach APIs to use scalar lists (#48223)
Summary: Pull Request resolved: #48223 Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D25074763 Pulled By: izdeby fbshipit-source-id: 155e3d2073a20d16bdbe358820170bf53f93c7a5
1 parent 506fdf9 commit cce84b5

8 files changed

Lines changed: 234 additions & 230 deletions

File tree

aten/src/ATen/native/ForeachOpsKernels.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ std::vector<Tensor> foreach_tensor_##OP##_scalar_kernel_slow(TensorList tensors,
2525
}
2626

2727
#define FOREACH_BINARY_OP_SCALARLIST(OP) \
28-
void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<double> scalars) { \
28+
void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
2929
check_foreach_api_restrictions(tensors, scalars); \
3030
\
3131
for (size_t i = 0; i < tensors.size(); i++) { \
3232
tensors[i].OP##_(scalars[i]); \
3333
} \
3434
} \
3535
\
36-
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<double> scalars) { \
36+
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
3737
check_foreach_api_restrictions(tensors, scalars); \
3838
std::vector<Tensor> result; \
3939
result.reserve(tensors.size()); \
@@ -128,7 +128,7 @@ void foreach_tensor_##OP##_scalar_slow_(TensorList input, TensorList tensors1, T
128128
} \
129129

130130
#define FOREACH_POINTWISE_OP_SCALARLIST(OP) \
131-
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
131+
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
132132
check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
133133
\
134134
std::vector<Tensor> result; \
@@ -139,7 +139,7 @@ std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, Tens
139139
return result; \
140140
} \
141141
\
142-
void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
142+
void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
143143
check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
144144
\
145145
for (size_t i = 0; i < input.size(); i++) { \

aten/src/ATen/native/ForeachUtils.h

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
namespace at {
55
namespace native {
66
namespace {
7-
7+
// Check foreach API restrictions
8+
// - Tensor lists must be non-empty.
9+
// - All tensors in all lists must have the same dtype.
10+
// - All TensorLists and ScalarLists must have the same number of elements.
11+
// - Corresponding tensors must have the same size.
812
void check_foreach_api_restrictions(TensorList tensors) {
913
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
1014
auto expected_dtype = tensors[0].dtype();
@@ -13,7 +17,7 @@ void check_foreach_api_restrictions(TensorList tensors) {
1317
}
1418
}
1519

16-
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
20+
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
1721
check_foreach_api_restrictions(tensors);
1822
TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
1923
}
@@ -49,7 +53,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
4953
}
5054
}
5155

52-
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
56+
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
5357
check_foreach_api_restrictions(tensors1, tensors2, tensors3);
5458
TORCH_CHECK(tensors1.size() == scalars.size(), "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size());
5559
}
@@ -85,21 +89,8 @@ bool has_same_attributes(Device expected_device, TensorList tensors) {
8589
}
8690

8791
bool will_promote_tensor(const Tensor& tensor, Scalar scalar) {
88-
// complex scalar + integral or boolean tensor will result in complex tensor
89-
if (scalar.isComplex() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
90-
return false;
91-
}
92-
93-
// float scalar + integral or boolean tensor will result in float tensor
94-
if (scalar.isFloatingPoint() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
95-
return false;
96-
}
97-
98-
// integral scalar + boolean tensor will result in integral tensor
99-
if (scalar.isIntegral(/*includeBool*/ false) && tensor.dtype() == at::kBool) {
100-
return false;
101-
}
102-
return true;
92+
auto result_dtype = at::result_type(tensor, scalar);
93+
return result_dtype != tensor.scalar_type();
10394
}
10495

10596
bool can_use_fast_route(TensorList tensors) {
@@ -128,7 +119,7 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
128119
return false;
129120
}
130121

131-
if (!will_promote_tensor(t, scalar)) {
122+
if (will_promote_tensor(t, scalar)) {
132123
return false;
133124
}
134125
}
@@ -137,8 +128,18 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
137128
#endif
138129
}
139130

140-
bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
141-
return can_use_fast_route(tensors);
131+
bool can_use_fast_route(TensorList tensors, ArrayRef<Scalar> scalars) {
132+
#ifdef __HIP_PLATFORM_HCC__
133+
return false;
134+
#else
135+
for (int i = 0; i < tensors.size(); i++) {
136+
if (will_promote_tensor(tensors[i], scalars[i])) {
137+
return false;
138+
}
139+
}
140+
141+
return true;
142+
#endif
142143
}
143144

144145
bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
@@ -166,7 +167,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, Scalar scalar)
166167
return false;
167168
}
168169

169-
if (!will_promote_tensor(tensors1[i], scalar)) {
170+
if (will_promote_tensor(tensors1[i], scalar)) {
170171
return false;
171172
}
172173
}
@@ -200,7 +201,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
200201
return false;
201202
}
202203

203-
if (!will_promote_tensor(tensors1[i], scalar)) {
204+
if (will_promote_tensor(tensors1[i], scalar)) {
204205
return false;
205206
}
206207
}
@@ -209,7 +210,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
209210
#endif
210211
}
211212

212-
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
213+
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
213214
return can_use_fast_route(tensors1, tensors2, tensors3);
214215
}
215216

aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace at { namespace native {
66

77
template<template<class> class Op>
8-
std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
8+
std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> scalars) {
99
std::vector<std::vector<at::Tensor>> tensor_lists;
1010
std::vector<at::Tensor> vec_res;
1111
vec_res.reserve(tensors.size());
@@ -18,52 +18,51 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> s
1818

1919
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
2020
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
21-
multi_tensor_apply<2>(tensor_lists,
22-
scalars,
23-
BinaryOpScalarListFunctor<scalar_t,
24-
/* depth */ 2,
25-
/* r_args_depth */ 1,
26-
/* res_arg_index */ 1>(),
27-
28-
Op<opmath_t>());
21+
multi_tensor_apply<2, opmath_t>(tensor_lists,
22+
scalars,
23+
BinaryOpScalarListFunctor<scalar_t,
24+
/* depth */ 2,
25+
/* r_args_depth */ 1,
26+
/* res_arg_index */ 1>(),
27+
Op<opmath_t>());
2928
});
3029
return tensor_lists[1];
3130
}
3231

3332
template<template<class> class Op>
34-
void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
33+
void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
3534
std::vector<std::vector<at::Tensor>> tensor_lists;
3635
tensor_lists.emplace_back(tensors.vec());
3736

3837
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
3938
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
40-
multi_tensor_apply<1>(tensor_lists,
41-
scalars,
42-
BinaryOpScalarListFunctor<scalar_t,
43-
/* depth */ 1,
44-
/* r_args_depth */ 1,
45-
/* res_arg_index */ 0>(),
46-
Op<opmath_t>());
39+
multi_tensor_apply<1, opmath_t>(tensor_lists,
40+
scalars,
41+
BinaryOpScalarListFunctor<scalar_t,
42+
/* depth */ 1,
43+
/* r_args_depth */ 1,
44+
/* res_arg_index */ 0>(),
45+
Op<opmath_t>());
4746
});
4847
}
4948

50-
#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
51-
void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<double> scalars) { \
52-
check_foreach_api_restrictions(tensors, scalars); \
53-
if (!can_use_fast_route(tensors, scalars)) { \
54-
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
55-
} \
56-
\
57-
foreach_binary_op_<OP>(tensors, scalars); \
58-
} \
59-
\
60-
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<double> scalars) { \
61-
check_foreach_api_restrictions(tensors, scalars); \
62-
if (!can_use_fast_route(tensors, scalars)) { \
63-
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
64-
} \
65-
\
66-
return foreach_binary_op<OP>(tensors, scalars); \
49+
#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
50+
void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
51+
check_foreach_api_restrictions(tensors, scalars); \
52+
if (!can_use_fast_route(tensors, scalars)) { \
53+
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
54+
} \
55+
\
56+
foreach_binary_op_<OP>(tensors, scalars); \
57+
} \
58+
\
59+
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
60+
check_foreach_api_restrictions(tensors, scalars); \
61+
if (!can_use_fast_route(tensors, scalars)) { \
62+
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
63+
} \
64+
\
65+
return foreach_binary_op<OP>(tensors, scalars); \
6766
}
6867

6968
FOREACH_BINARY_OP_SCALARLIST(add, std::plus);

0 commit comments

Comments
 (0)