11#include < ATen/Dispatch.h>
22#include < ATen/native/ForeachUtils.h>
33#include < ATen/native/cuda/ForeachFunctors.cuh>
4+ #include < ATen/native/BinaryOps.h>
45
56namespace at { namespace native {
67
@@ -16,7 +17,7 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> s
1617 tensor_lists.emplace_back (tensors.vec ());
1718 tensor_lists.emplace_back (vec_res);
1819
19- AT_DISPATCH_ALL_TYPES_AND2 (kBFloat16 , kHalf , tensors[0 ].scalar_type (), " foreach_binary_op_scalarlist_cuda" , [&]() {
20+ AT_DISPATCH_ALL_TYPES_AND3 (kBFloat16 , kHalf , kBool , tensors[0 ].scalar_type (), " foreach_binary_op_scalarlist_cuda" , [&]() {
2021 using opmath_t = get_opmath_t <scalar_t >::opmath_t ;
2122 multi_tensor_apply<2 , opmath_t >(tensor_lists,
2223 scalars,
@@ -35,7 +36,7 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
3536 std::vector<std::vector<at::Tensor>> tensor_lists;
3637 tensor_lists.emplace_back (tensors.vec ());
3738
38- AT_DISPATCH_ALL_TYPES_AND2 (kBFloat16 , kHalf , tensors[0 ].scalar_type (), " foreach_binary_op_scalarlist_cuda_" , [&]() {
39+ AT_DISPATCH_ALL_TYPES_AND3 (kBFloat16 , kHalf , kBool , tensors[0 ].scalar_type (), " foreach_binary_op_scalarlist_cuda_" , [&]() {
3940 using opmath_t = get_opmath_t <scalar_t >::opmath_t ;
4041 multi_tensor_apply<1 , opmath_t >(tensor_lists,
4142 scalars,
@@ -47,10 +48,10 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
4748 });
4849}
4950
50- #define FOREACH_BINARY_OP_SCALARLIST (NAME, OP ) \
51+ #define FOREACH_BINARY_OP_SCALARLIST (NAME, OP, DIV_OP ) \
5152void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
5253 check_foreach_api_restrictions (tensors, scalars); \
53- if (!can_use_fast_route (tensors, scalars)) { \
54+ if (!can_use_fast_route (tensors, scalars, DIV_OP )) { \
5455 return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_ (tensors, scalars); \
5556 } \
5657 \
@@ -59,16 +60,43 @@ void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::Arr
5960 \
6061std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
6162 check_foreach_api_restrictions (tensors, scalars); \
62- if (!can_use_fast_route (tensors, scalars)) { \
63+ if (!can_use_fast_route (tensors, scalars, DIV_OP )) { \
6364 return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow (tensors, scalars); \
6465 } \
6566 \
6667 return foreach_binary_op<OP>(tensors, scalars); \
6768}
6869
// Instantiate fast-path CUDA kernels for the foreach scalar-list binary ops.
// The third macro argument is forwarded to can_use_fast_route as DIV_OP;
// NOTE(review): only div passes true — presumably so division can be routed
// to the slow path when the fast route's type handling is insufficient
// (e.g. integral inputs needing promotion) — confirm against can_use_fast_route.
// sub is intentionally absent here: it is defined explicitly below because it
// must reject boolean scalars, matching torch.sub.
FOREACH_BINARY_OP_SCALARLIST(add, std::plus, /*div_op*/ false);
FOREACH_BINARY_OP_SCALARLIST(mul, std::multiplies, /*div_op*/ false);
FOREACH_BINARY_OP_SCALARLIST(div, std::divides, /*div_op*/ true);
73+
// In-place foreach subtraction with a per-tensor scalar.
// This does not use FOREACH_BINARY_OP_SCALARLIST because, following the
// torch.sub logic, subtraction does not allow the scalar to be boolean;
// sub_check enforces that per tensor/scalar pair before dispatch.
void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors, scalars);
  // Validate every tensor/scalar pair up front. size_t avoids the
  // signed/unsigned comparison against tensors.size() that `int i` caused.
  for (size_t i = 0; i < tensors.size(); i++) {
    sub_check(tensors[i], scalars[i]);
  }

  // Fall back to the reference (slow) in-place implementation when the
  // fast multi-tensor route cannot handle these inputs.
  if (!can_use_fast_route({tensors}, scalars)) {
    return at::native::foreach_tensor_sub_scalarlist_kernel_slow_(tensors, scalars);
  }

  // Fast path: fused in-place apply with std::minus.
  foreach_binary_op_<std::minus>(tensors, scalars);
}
88+
// Out-of-place foreach subtraction with a per-tensor scalar; returns the
// result tensors. Defined explicitly (not via FOREACH_BINARY_OP_SCALARLIST)
// because, following the torch.sub logic, the scalar must not be boolean;
// sub_check enforces that per tensor/scalar pair before dispatch.
std::vector<Tensor> foreach_tensor_sub_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors, scalars);
  // Validate every tensor/scalar pair up front. size_t avoids the
  // signed/unsigned comparison against tensors.size() that `int i` caused.
  for (size_t i = 0; i < tensors.size(); i++) {
    sub_check(tensors[i], scalars[i]);
  }

  // Fall back to the reference (slow) implementation when the fast
  // multi-tensor route cannot handle these inputs.
  if (!can_use_fast_route({tensors}, scalars)) {
    return at::native::foreach_tensor_sub_scalarlist_kernel_slow(tensors, scalars);
  }

  // Fast path: fused apply with std::minus, producing new result tensors.
  return foreach_binary_op<std::minus>(tensors, scalars);
}
73101
74102}} // namespace at::native
0 commit comments