
Commit ec277ac

Update on "Make storage access error NotImplementedError"
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D27036573](https://our.internmc.facebook.com/intern/diff/D27036573)
[ghstack-poisoned]
2 parents 4b1083e + d408369 commit ec277ac

9 files changed

Lines changed: 258 additions & 180 deletions


aten/src/ATen/native/BinaryOps.h

Lines changed: 9 additions & 0 deletions
@@ -25,6 +25,15 @@ inline void sub_check(const Tensor& self, const Tensor& other) {
               "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
 }
 
+inline void sub_check(const Tensor& self, const Scalar& scalar) {
+  TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
+              "Subtraction, the `-` operator, with two bool tensors is not supported. "
+              "Use the `^` or `logical_xor()` operator instead.");
+  TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
+              "Subtraction, the `-` operator, with a bool tensor is not supported. "
+              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
+}
+
 using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, Scalar alpha);
 
 using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
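
Note on the new overload: the two checks are ordered so the more specific error fires first. If both the tensor and the scalar are boolean, the first TORCH_CHECK throws the "two bool tensors" message; if exactly one operand is boolean, the first check passes and the second throws. A minimal sketch of the resulting behavior (assumes an ATen build where this internal header is reachable; demo_sub_check is a hypothetical helper, not part of this commit):

#include <ATen/ATen.h>
#include <ATen/native/BinaryOps.h>

void demo_sub_check() {
  at::Tensor int_t  = at::zeros({2}, at::kInt);
  at::Tensor bool_t = at::zeros({2}, at::kBool);
  at::native::sub_check(int_t, at::Scalar(1));        // passes: no bool operand
  // at::native::sub_check(bool_t, at::Scalar(true)); // throws: "two bool tensors"
  // at::native::sub_check(bool_t, at::Scalar(1));    // throws: "a bool tensor"
  // at::native::sub_check(int_t, at::Scalar(true));  // throws: "a bool tensor"
}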

aten/src/ATen/native/ForeachUtils.h

Lines changed: 21 additions & 7 deletions
@@ -64,7 +64,13 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
 // - All tensors must be non-overlapping and dense
 // - Resulting tensor must have the same dtype as the input one
 
-bool will_promote_tensor(const Tensor& tensor, Scalar scalar) {
+bool will_promote_tensor(const Tensor& tensor, Scalar scalar, bool does_op_promote_integer_inputs_to_float = false) {
+  // In the case of division, integer inputs will result in float.
+  if (does_op_promote_integer_inputs_to_float) {
+    if (at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
+      return true;
+    }
+  }
   auto result_dtype = at::result_type(tensor, scalar);
   return result_dtype != tensor.scalar_type();
 }
@@ -73,7 +79,8 @@ bool will_promote_tensor(const Tensor& tensor, Scalar scalar) {
 // There is a set of preconditions that have to be satisfied.
 bool check_fast_path_restrictions(
     ArrayRef<TensorList> tensorLists,
-    ArrayRef<Scalar> scalarList = {}) {
+    ArrayRef<Scalar> scalarList = {},
+    bool does_op_promote_integer_inputs_to_float = false) {
   auto expected_device = tensorLists[0][0].device();
 
   auto is_tensor_okay = [&](const Tensor& tensor) {
@@ -103,12 +110,18 @@ bool check_fast_path_restrictions(
   // checked by `check_foreach_api_restrictions`). This means we only need to check if
   // {tensorList[0][0], tensorList[0][1], tensorList[0][2], ...} do type promotion with scalarList.
   for (int i = 0; i < tensorLists[0].size(); i++) {
+    if (does_op_promote_integer_inputs_to_float) {
+      if (at::isIntegralType(tensorLists[0][i].scalar_type(), /*includeBool*/ true)) {
+        return false;
+      }
+    }
     if (scalarList.size() == 1) {
       if (will_promote_tensor(tensorLists[0][i], scalarList[0])) {
         return false;
       }
     } else if (scalarList.size() > 1) {
-      // Complex scalar list is not supported.
+      // A complex scalar list is not supported, due to the 4KB limit on kernel launch arguments.
       if (scalarList[i].isComplex()) {
         return false;
       }
@@ -123,19 +136,20 @@ bool check_fast_path_restrictions(
 }
 
 bool can_use_fast_route(ArrayRef<TensorList> tensorLists,
-                        ArrayRef<Scalar> scalarList = {}) {
+                        ArrayRef<Scalar> scalarList = {},
+                        bool does_op_promote_integer_inputs_to_float = false) {
 #ifdef __HIP_PLATFORM_HCC__
   return false;
 #else
-  return check_fast_path_restrictions(tensorLists, scalarList);
+  return check_fast_path_restrictions(tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
 #endif
 }
 
-bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
+bool can_use_fast_route(TensorList tensors1, TensorList tensors2, bool does_op_promote_integer_inputs_to_float = false) {
 #ifdef __HIP_PLATFORM_HCC__
   return false;
 #else
-  return can_use_fast_route({tensors1, tensors2}, {});
+  return can_use_fast_route({tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
 #endif
 }
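
Why the extra flag exists: for true division, at::result_type(int_tensor, int_scalar) is still an integral type, but the op produces a float result, so the dtype comparison alone would wrongly admit integer inputs to the fast path, which must write results of the input dtype. A standalone sketch of the test (would_promote is a hypothetical name mirroring will_promote_tensor above):

#include <ATen/ATen.h>

bool would_promote(const at::Tensor& t, const at::Scalar& s,
                   bool does_op_promote_integer_inputs_to_float) {
  // Division turns any integral (or bool) input into a float result.
  if (does_op_promote_integer_inputs_to_float &&
      at::isIntegralType(t.scalar_type(), /*includeBool=*/true)) {
    return true;
  }
  // Otherwise promotion happens only if the scalar widens the tensor dtype.
  return at::result_type(t, s) != t.scalar_type();
}

// would_promote(int_tensor, 2, /*div*/ true)    -> true  (int / 2 yields float)
// would_promote(float_tensor, 2, /*div*/ true)  -> false (float stays float)
// would_promote(int_tensor, 2.5, /*add*/ false) -> true  (int + 2.5 yields float)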

aten/src/ATen/native/cuda/ForeachBinaryOpList.cu

Lines changed: 5 additions & 5 deletions
@@ -49,10 +49,10 @@ void foreach_tensor_list_op_(TensorList tensors1, TensorList tensors2, Scalar al
     });
 }
 
-#define FOREACH_BINARY_OP_LIST(NAME, OP) \
+#define FOREACH_BINARY_OP_LIST(NAME, OP, DIVISION_OP) \
 void foreach_tensor_##NAME##_list_kernel_cuda_(TensorList tensors1, TensorList tensors2) { \
     check_foreach_api_restrictions(tensors1, tensors2); \
-    if (!can_use_fast_route({tensors1, tensors2})) { \
+    if (!can_use_fast_route(tensors1, tensors2, DIVISION_OP)) { \
         return at::native::foreach_tensor_##NAME##_list_kernel_slow_(tensors1, tensors2); \
     } \
@@ -61,7 +61,7 @@ void foreach_tensor_##NAME##_list_kernel_cuda_(TensorList tensors1, TensorList t
 \
 std::vector<Tensor> foreach_tensor_##NAME##_list_kernel_cuda(TensorList tensors1, TensorList tensors2) { \
     check_foreach_api_restrictions(tensors1, tensors2); \
-    if (!can_use_fast_route({tensors1, tensors2})) { \
+    if (!can_use_fast_route(tensors1, tensors2, DIVISION_OP)) { \
         return at::native::foreach_tensor_##NAME##_list_kernel_slow(tensors1, tensors2); \
     } \
@@ -89,7 +89,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_list_kernel_cuda(TensorList tensors1
 
 FOREACH_BINARY_OP_LIST_ALPHA(add, std::plus);
 FOREACH_BINARY_OP_LIST_ALPHA(sub, std::minus);
-FOREACH_BINARY_OP_LIST(mul, std::multiplies);
-FOREACH_BINARY_OP_LIST(div, std::divides);
+FOREACH_BINARY_OP_LIST(mul, std::multiplies, /*division_op*/ false);
+FOREACH_BINARY_OP_LIST(div, std::divides, /*division_op*/ true);
 
 }} // namespace at::native
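
For reference, a rough expansion of FOREACH_BINARY_OP_LIST(div, std::divides, /*division_op*/ true) under the new signature (line-continuation backslashes elided; in-place variant only):

void foreach_tensor_div_list_kernel_cuda_(TensorList tensors1, TensorList tensors2) {
    check_foreach_api_restrictions(tensors1, tensors2);
    // DIVISION_OP == true: integer inputs would promote to float, so they are
    // routed to the slow per-tensor path instead of the fused kernel.
    if (!can_use_fast_route(tensors1, tensors2, /*does_op_promote_integer_inputs_to_float=*/true)) {
        return at::native::foreach_tensor_div_list_kernel_slow_(tensors1, tensors2);
    }
    foreach_tensor_list_op_<std::divides>(tensors1, tensors2);
}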

aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu

Lines changed: 33 additions & 8 deletions
@@ -1,7 +1,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/native/cuda/ForeachFunctors.cuh>
-
+#include <ATen/native/BinaryOps.h>
 namespace at { namespace native {
 
 template<template<class> class Op>
@@ -46,10 +46,10 @@ void foreach_binary_op_(TensorList tensors, Scalar scalar) {
     });
 }
 
-#define FOREACH_BINARY_OP_SCALAR(NAME, OP) \
+#define FOREACH_BINARY_OP_SCALAR(NAME, OP, DIVISION_OP) \
 void foreach_tensor_##NAME##_scalar_kernel_cuda_(TensorList tensors, Scalar scalar) { \
     check_foreach_api_restrictions(tensors); \
-    if (!can_use_fast_route(tensors, scalar)) { \
+    if (!can_use_fast_route(tensors, scalar, DIVISION_OP)) { \
         return at::native::foreach_tensor_##NAME##_scalar_kernel_slow_(tensors, scalar); \
     } \
@@ -58,16 +58,41 @@ void foreach_tensor_##NAME##_scalar_kernel_cuda_(TensorList tensors, Scalar scal
 \
 std::vector<Tensor> foreach_tensor_##NAME##_scalar_kernel_cuda(TensorList tensors, Scalar scalar) { \
     check_foreach_api_restrictions(tensors); \
-    if (!can_use_fast_route(tensors, scalar)) { \
+    if (!can_use_fast_route(tensors, scalar, DIVISION_OP)) { \
         return at::native::foreach_tensor_##NAME##_scalar_kernel_slow(tensors, scalar); \
     } \
 \
     return foreach_binary_op<OP>(tensors, scalar); \
 }
 
-FOREACH_BINARY_OP_SCALAR(add, std::plus);
-FOREACH_BINARY_OP_SCALAR(sub, std::minus);
-FOREACH_BINARY_OP_SCALAR(mul, std::multiplies);
-FOREACH_BINARY_OP_SCALAR(div, std::divides);
+FOREACH_BINARY_OP_SCALAR(add, std::plus, false);
+FOREACH_BINARY_OP_SCALAR(mul, std::multiplies, false);
+
+// In the case of division, integer inputs will result in float.
+// Currently, multi-tensor apply can only return a result of the same type as the input.
+FOREACH_BINARY_OP_SCALAR(div, std::divides, true);
+
+// In the case of subtraction, we don't allow a boolean scalar, following the torch.sub logic.
+void foreach_tensor_sub_scalar_kernel_cuda_(TensorList tensors, Scalar scalar) {
+    check_foreach_api_restrictions(tensors);
+    at::native::sub_check(tensors[0], scalar);
+
+    if (!can_use_fast_route(tensors, scalar)) {
+        return at::native::foreach_tensor_sub_scalar_kernel_slow_(tensors, scalar);
+    }
+
+    foreach_binary_op_<std::minus>(tensors, scalar);
+}
+
+std::vector<Tensor> foreach_tensor_sub_scalar_kernel_cuda(TensorList tensors, Scalar scalar) {
+    check_foreach_api_restrictions(tensors);
+    at::native::sub_check(tensors[0], scalar);
+
+    if (!can_use_fast_route(tensors, scalar)) {
+        return at::native::foreach_tensor_sub_scalar_kernel_slow(tensors, scalar);
+    }
+
+    return foreach_binary_op<std::minus>(tensors, scalar);
+}
 
 }} // namespace at::native
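
What the routing means for callers, as a hedged sketch against the ATen C++ surface (assumes a CUDA build; which route runs is an internal detail, but the observable dtype and error behavior is as below):

#include <ATen/ATen.h>

void demo_foreach_scalar_ops() {
  std::vector<at::Tensor> xs = {
      at::ones({4}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA))};

  // Division promotes integer inputs to float; since the fused kernel must
  // produce the input dtype, these tensors take the slow route, and the
  // result comes back as float.
  auto out = at::_foreach_div(xs, 2);
  // out[0].scalar_type() == at::kFloat

  // Subtraction with a boolean scalar is rejected up front by sub_check:
  // at::_foreach_sub(xs, at::Scalar(true));  // would throw
}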

aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu

Lines changed: 37 additions & 9 deletions
@@ -1,6 +1,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/native/cuda/ForeachFunctors.cuh>
+#include <ATen/native/BinaryOps.h>
 
 namespace at { namespace native {
 
@@ -16,7 +17,7 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> s
     tensor_lists.emplace_back(tensors.vec());
     tensor_lists.emplace_back(vec_res);
 
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
         using opmath_t = get_opmath_t<scalar_t>::opmath_t;
         multi_tensor_apply<2, opmath_t>(tensor_lists,
                                         scalars,
@@ -35,7 +36,7 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
     std::vector<std::vector<at::Tensor>> tensor_lists;
     tensor_lists.emplace_back(tensors.vec());
 
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
         using opmath_t = get_opmath_t<scalar_t>::opmath_t;
         multi_tensor_apply<1, opmath_t>(tensor_lists,
                                         scalars,
@@ -47,10 +48,10 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
     });
 }
 
-#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
+#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP, DIV_OP) \
 void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
     check_foreach_api_restrictions(tensors, scalars); \
-    if (!can_use_fast_route(tensors, scalars)) { \
+    if (!can_use_fast_route(tensors, scalars, DIV_OP)) { \
         return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
     } \
@@ -59,16 +60,43 @@ void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::Arr
 \
 std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
     check_foreach_api_restrictions(tensors, scalars); \
-    if (!can_use_fast_route(tensors, scalars)) { \
+    if (!can_use_fast_route(tensors, scalars, DIV_OP)) { \
         return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
     } \
 \
     return foreach_binary_op<OP>(tensors, scalars); \
 }
 
-FOREACH_BINARY_OP_SCALARLIST(add, std::plus);
-FOREACH_BINARY_OP_SCALARLIST(sub, std::minus);
-FOREACH_BINARY_OP_SCALARLIST(mul, std::multiplies);
-FOREACH_BINARY_OP_SCALARLIST(div, std::divides);
+FOREACH_BINARY_OP_SCALARLIST(add, std::plus, /*div_op*/ false);
+FOREACH_BINARY_OP_SCALARLIST(mul, std::multiplies, /*div_op*/ false);
+FOREACH_BINARY_OP_SCALARLIST(div, std::divides, /*div_op*/ true);
+
+// This does not use FOREACH_BINARY_OP_SCALARLIST because, in the case of subtraction,
+// we don't allow a boolean scalar, following the torch.sub logic.
+void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
+    check_foreach_api_restrictions(tensors, scalars);
+    for (int i = 0; i < tensors.size(); i++) {
+        sub_check(tensors[i], scalars[i]);
+    }
+
+    if (!can_use_fast_route({tensors}, scalars)) {
+        return at::native::foreach_tensor_sub_scalarlist_kernel_slow_(tensors, scalars);
+    }
+
+    foreach_binary_op_<std::minus>(tensors, scalars);
+}
+
+std::vector<Tensor> foreach_tensor_sub_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) {
+    check_foreach_api_restrictions(tensors, scalars);
+    for (int i = 0; i < tensors.size(); i++) {
+        sub_check(tensors[i], scalars[i]);
+    }
+
+    if (!can_use_fast_route({tensors}, scalars)) {
+        return at::native::foreach_tensor_sub_scalarlist_kernel_slow(tensors, scalars);
+    }
+
+    return foreach_binary_op<std::minus>(tensors, scalars);
+}
 
 }} // namespace at::native
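
Unlike the single-scalar variant above, which validates only tensors[0] against the one scalar, the scalarlist path pairs tensor i with scalar i, so sub_check must run once per pair. A hedged caller-level sketch (assumes a CUDA build):

#include <ATen/ATen.h>

void demo_foreach_sub_scalarlist() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  std::vector<at::Tensor> xs = {at::ones({2}, opts.dtype(at::kFloat)),
                                at::ones({2}, opts.dtype(at::kInt))};
  std::vector<at::Scalar> ss = {1.5, 2};
  auto out = at::_foreach_sub(xs, ss);  // every (tensor, scalar) pair passes sub_check

  // A boolean scalar anywhere in the list trips the per-pair check:
  // at::_foreach_sub(xs, {at::Scalar(1.5), at::Scalar(true)});  // would throw
}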
