44namespace at {
55namespace native {
66namespace {
// Check foreach API restrictions
// - Tensor lists must be non-empty.
// - All tensors in all lists must have the same dtype.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
812void check_foreach_api_restrictions (TensorList tensors) {
913 TORCH_CHECK (tensors.size () > 0 , " Tensor list must have at least one tensor." );
1014 auto expected_dtype = tensors[0 ].dtype ();
@@ -13,7 +17,7 @@ void check_foreach_api_restrictions(TensorList tensors) {
1317 }
1418}
1519
16- void check_foreach_api_restrictions (TensorList tensors, ArrayRef<double > scalars) {
20+ void check_foreach_api_restrictions (TensorList tensors, ArrayRef<Scalar > scalars) {
1721 check_foreach_api_restrictions (tensors);
1822 TORCH_CHECK (tensors.size () == scalars.size (), " Tensor list must have same number of elements as scalar list." );
1923}
@@ -49,7 +53,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
4953 }
5054}
5155
52- void check_foreach_api_restrictions (TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double > scalars) {
56+ void check_foreach_api_restrictions (TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar > scalars) {
5357 check_foreach_api_restrictions (tensors1, tensors2, tensors3);
5458 TORCH_CHECK (tensors1.size () == scalars.size (), " Tensor list must have same number of elements as scalar list, got " , tensors1.size (), " and " , scalars.size ());
5559}
@@ -85,21 +89,8 @@ bool has_same_attributes(Device expected_device, TensorList tensors) {
8589}
8690
8791bool will_promote_tensor (const Tensor& tensor, Scalar scalar) {
88- // complex scalar + integral or boolean tensor will result in complex tensor
89- if (scalar.isComplex () && at::isIntegralType (tensor.scalar_type (), /* includeBool*/ true )) {
90- return false ;
91- }
92-
93- // float scalar + integral or boolean tensor will result in float tensor
94- if (scalar.isFloatingPoint () && at::isIntegralType (tensor.scalar_type (), /* includeBool*/ true )) {
95- return false ;
96- }
97-
98- // integral scalar + boolean tensor will result in integral tensor
99- if (scalar.isIntegral (/* includeBool*/ false ) && tensor.dtype () == at::kBool ) {
100- return false ;
101- }
102- return true ;
92+ auto result_dtype = at::result_type (tensor, scalar);
93+ return result_dtype != tensor.scalar_type ();
10394}
10495
10596bool can_use_fast_route (TensorList tensors) {
@@ -128,7 +119,7 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
128119 return false ;
129120 }
130121
131- if (! will_promote_tensor (t, scalar)) {
122+ if (will_promote_tensor (t, scalar)) {
132123 return false ;
133124 }
134125 }
@@ -137,8 +128,18 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
137128#endif
138129}
139130
140- bool can_use_fast_route (TensorList tensors, ArrayRef<double > scalars) {
141- return can_use_fast_route (tensors);
131+ bool can_use_fast_route (TensorList tensors, ArrayRef<Scalar> scalars) {
132+ #ifdef __HIP_PLATFORM_HCC__
133+ return false ;
134+ #else
135+ for (int i = 0 ; i < tensors.size (); i++) {
136+ if (will_promote_tensor (tensors[i], scalars[i])) {
137+ return false ;
138+ }
139+ }
140+
141+ return true ;
142+ #endif
142143}
143144
144145bool can_use_fast_route (TensorList tensors1, TensorList tensors2) {
@@ -166,7 +167,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, Scalar scalar)
166167 return false ;
167168 }
168169
169- if (! will_promote_tensor (tensors1[i], scalar)) {
170+ if (will_promote_tensor (tensors1[i], scalar)) {
170171 return false ;
171172 }
172173 }
@@ -200,7 +201,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
200201 return false ;
201202 }
202203
203- if (! will_promote_tensor (tensors1[i], scalar)) {
204+ if (will_promote_tensor (tensors1[i], scalar)) {
204205 return false ;
205206 }
206207 }
@@ -209,7 +210,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
209210#endif
210211}
211212
212- bool can_use_fast_route (TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double > scalars) {
213+ bool can_use_fast_route (TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar > scalars) {
213214 return can_use_fast_route (tensors1, tensors2, tensors3);
214215}
215216
0 commit comments