
Commit 7eade66

crcrpar authored and facebook-github-bot committed
[PyTorch] Reduce errors of foreach functions (#56993)
Summary:
This is based on #48224. To make `foreach` more flexible, this PR pushes unsupported cases to the slow path. It also adds tests to verify that:
- `foreach` functions work with tensors of different dtypes and/or memory layouts (7bd4b2c)
- `foreach` functions work with tensors on different devices within a list, as long as corresponding tensors (those sharing an index across lists) are on the same device (def4b9b)

Future plans:
1. Improve unit test coverage by using the `ops` decorator, updating `foreach_unary_op_db`, and creating `foreach_(binary|pointwise|minmax)_db`.
2. Support broadcasting in the slow path. Ref: #52448
3. Support type promotion in the fast path. Ref: #52449

CC: ngimel mcarilli ptrblck

Pull Request resolved: #56993
Reviewed By: zou3519
Differential Revision: D28630580
Pulled By: ngimel
fbshipit-source-id: e26ee74a39a591025e18c1ead48948cb7ec53c19
1 parent 8a28bbe commit 7eade66

6 files changed

Lines changed: 176 additions & 64 deletions
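As a rough, user-level sketch of what this change enables (not part of the diff; assumes a CUDA build of PyTorch at this revision, where `torch._foreach_add` is the exposed binary foreach op):

    import torch

    # A tensor list mixing dtypes used to fail the up-front dtype check;
    # after this change it is routed to the slow (per-tensor) path instead.
    xs = [torch.ones(3, device="cuda", dtype=torch.float32),
          torch.ones(3, device="cuda", dtype=torch.float16)]
    ys = [torch.ones(3, device="cuda", dtype=torch.float32),
          torch.ones(3, device="cuda", dtype=torch.float16)]

    out = torch._foreach_add(xs, ys)  # no error; each pair handled per tensor
    assert [t.dtype for t in out] == [torch.float32, torch.float16]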


aten/src/ATen/native/ForeachUtils.h

Lines changed: 32 additions & 24 deletions
@@ -6,17 +6,23 @@
 namespace at {
 namespace native {
 namespace {
+// Check if the tensor list has either a boolean tensor or an integer tensor
+bool has_integral_tensor(TensorList tensors, const bool includeBool) {
+  return std::any_of(tensors.begin(), tensors.end(),
+    [&includeBool](const auto & t) { return at::isIntegralType(t.scalar_type(), includeBool); });
+}
+// Check if the tensor list has bool tensors
+bool has_bool_tensor(TensorList tensors) {
+  return std::any_of(tensors.begin(), tensors.end(),
+    [](const auto & t) -> bool { return t.scalar_type() == ScalarType::Bool; });
+}
+
 // Check foreach API restrictions
 // - Tensor lists must be non-empty.
-// - All tensors in all lists must have the same dtype.
 // - All TensorLists and ScalarLists must have the same number of elements.
 // - Corresponding tensors must have the same size.
 void check_foreach_api_restrictions(TensorList tensors) {
   TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
-  auto expected_dtype = tensors[0].dtype();
-  for (const auto& t : tensors) {
-    TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
-  }
 }
 
 void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
@@ -29,11 +35,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
   TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
   TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
 
-  auto expected_dtype = tensors1[0].dtype();
-
   for (const auto i : c10::irange(tensors1.size())) {
-    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
-    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
     TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
   }
 }
@@ -45,11 +47,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
   TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
   TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size());
 
-  auto expected_dtype = tensors1[0].dtype();
-
   for (const auto i : c10::irange(tensors1.size())) {
-    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
-    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
     TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
     TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes());
   }
@@ -61,20 +59,24 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
 }
 
 // To go via 'fast' path, several conditions must be satisfied
+// - All tensors in all lists must have the same dtype.
 // - All tensors must be on the same device
 // - All tensors must have strided layout
 // - All tensors must be non-overlapping and dense
 // - Resulting tensor must have the same dtype as the input one
 
+// TODO(mkozuki): Consider whether we really need this function or not.
+// Note that there is a possibility that the foreach fast path supports type promotion in the future,
+// which might complicate the functionality this function should provide.
+// However, as of now, the check of division ops with integer inputs is duplicated:
+// `check_fast_path_restrictions` does the same thing before calling this function.
 bool will_promote_tensor(const Tensor& tensor, const Scalar& scalar, bool does_op_promote_integer_inputs_to_float = false) {
   // In case of division, integer inputs will result in float
-  if (does_op_promote_integer_inputs_to_float) {
-    if (at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
-      return true;
-    }
+  if (does_op_promote_integer_inputs_to_float &&
+      at::isIntegralType(tensor.scalar_type(), /* includeBool */ true)) {
+    return true;
   }
-  auto result_dtype = at::result_type(tensor, scalar);
-  return result_dtype != tensor.scalar_type();
+  return tensor.scalar_type() != at::native::result_type(scalar, tensor);
 }
 
 // Please, make sure to call check_foreach_api_restrictions before calling this method.
@@ -83,10 +85,12 @@ bool check_fast_path_restrictions(
     ArrayRef<TensorList> tensorLists,
     ArrayRef<Scalar> scalarList = {},
     bool does_op_promote_integer_inputs_to_float = false) {
-  auto expected_device = tensorLists[0][0].device();
+  const auto expected_dtype = tensorLists[0][0].dtype();
+  const auto expected_device = tensorLists[0][0].device();
 
   auto is_tensor_okay = [&](const Tensor& tensor) {
-    return tensor.device() == expected_device &&
+    return tensor.dtype() == expected_dtype &&
+           tensor.device() == expected_device &&
            tensor.layout() == at::kStrided &&
            tensor.is_non_overlapping_and_dense();
   };
@@ -108,9 +112,11 @@ bool check_fast_path_restrictions(
     }
   }
 
-  // For all j, tensorList[j][0] have the same shape and dtype. (this was a precondition
-  // checked by `check_foreach_api_restrictions`). This means we only need to check if
-  // {tensorList[0][0], tensorList[0][1], tensorList[0][2], ...} do type promotion with scalarList.
+  // This function has already checked that `tensorList[j][i]` has the same dtype for all j, i
+  // using the `is_tensor_okay` lambda above
+  // (shape equality was checked earlier by `check_foreach_api_restrictions`).
+  // This means we only need to check if {tensorList[0][0], tensorList[0][1], tensorList[0][2], ...}
+  // do type promotion with scalarList.
   for (int i=0; i < tensorLists[0].size(); i++) {
     if (does_op_promote_integer_inputs_to_float) {
       if (at::isIntegralType(tensorLists[0][i].scalar_type(), /*includeBool*/ true)) {
@@ -123,6 +129,8 @@ bool check_fast_path_restrictions(
         return false;
       }
     } else if (scalarList.size() > 1) {
+      // FIXME(mkozuki): Consider specializing `TensorListScalarListMetadata` for complex dtypes
+      // to address the following comment.
       // Complex scalar list is not supported due to the limit for kernel launch argument (4KB)
       if (scalarList[i].isComplex()) {
        return false;
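The `will_promote_tensor` helper above has a direct Python analogue via `torch.result_type`. A minimal sketch (a hypothetical re-implementation for illustration, not code from this commit; `op_promotes_int_to_float` mirrors `does_op_promote_integer_inputs_to_float`):

    import torch

    def will_promote_tensor(tensor, scalar, op_promotes_int_to_float=False):
        # Division-like ops turn integer (and bool) inputs into floats,
        # so such inputs can never stay on the fast path.
        if op_promotes_int_to_float and not (tensor.dtype.is_floating_point or
                                             tensor.dtype.is_complex):
            return True
        # Otherwise the fast path requires that combining with the scalar
        # leaves the tensor's dtype unchanged.
        return torch.result_type(tensor, scalar) != tensor.dtype

    t = torch.ones(2, dtype=torch.int64)
    print(will_promote_tensor(t, 1.5))  # True: int64 + float scalar -> float32
    print(will_promote_tensor(t, 1))    # False: dtype is preserved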

aten/src/ATen/native/cuda/AmpKernels.cu

Lines changed: 3 additions & 1 deletion
@@ -113,6 +113,7 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
   // - all scaled_grads are strided
   // - all scaled_grads are non overlapping and dense
   // - all scaled_grads are on the same device
+  // - all scaled_grads are of the same dtype
   TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors.");
   // Sets up MTA launch to use scaled_grads as-is.
   tensor_lists.emplace_back(scaled_grads.vec());
@@ -126,12 +127,13 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
     tensor_lists.resize(1);
     tensor_lists[0].reserve(scaled_grads.size());
     auto expected_device = scaled_grads[0].device();
+    const auto expected_dtype = scaled_grads[0].scalar_type();
     for (const Tensor& t : scaled_grads) {
       // Ensures GradScaler filtered scaled_grads by device.
       TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor.");
       TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
       TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
-      if (!t.is_non_overlapping_and_dense()) {
+      if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) {
        // t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel.
        _amp_non_finite_check_and_unscale_cuda_(const_cast<Tensor&>(t),
                                                found_inf,
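In Python terms, the new fallback behaves roughly like this (a sketch mirroring the updated test further below; assumes a CUDA device):

    import torch

    found_inf = torch.zeros((1,), device="cuda")
    inv_scale = torch.full((1,), 0.25, device="cuda")

    # fp32 and fp16 grads in one list: the mismatched-dtype tensor is no
    # longer an error; it is unscaled by the single-tensor fallback kernel.
    grads = [torch.full((4,), 4.0, device="cuda"),
             torch.full((4,), 4.0, device="cuda", dtype=torch.float16)]
    torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
    assert all(torch.allclose(g.float(), torch.ones(4, device="cuda"))
               for g in grads)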

aten/src/ATen/native/cuda/ForeachPointwiseOp.cu

Lines changed: 9 additions & 5 deletions
@@ -105,7 +105,7 @@ std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1,
 std::vector<Tensor> foreach_tensor_##NAME##_scalar_cuda(TensorList input, TensorList tensors1, TensorList tensors2, const Scalar& scalar) { \
   check_foreach_api_restrictions(input, tensors1, tensors2); \
 \
-  if (!can_use_fast_route({input, tensors1, tensors2}, scalar)) { \
+  if (!can_use_fast_route({input, tensors1, tensors2}, scalar) || has_integral_tensor(input, /* includeBool */ true)) { \
     return at::native::foreach_tensor_##NAME##_scalar_slow(input, tensors1, tensors2, scalar); \
   } \
 \
@@ -115,7 +115,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_scalar_cuda(TensorList input, Tensor
 void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, const Scalar& scalar) { \
   check_foreach_api_restrictions(input, tensors1, tensors2); \
 \
-  if (!can_use_fast_route({input, tensors1, tensors2}, scalar)) { \
+  if (!can_use_fast_route({input, tensors1, tensors2}, scalar) || has_integral_tensor(input, /* includeBool */ true)) { \
     return at::native::foreach_tensor_##NAME##_scalar_slow_(input, tensors1, tensors2, scalar); \
   } \
 \
@@ -127,7 +127,7 @@ void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1,
 std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
   check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
 \
-  if (!can_use_fast_route({input, tensors1, tensors2}, scalars)) { \
+  if (!can_use_fast_route({input, tensors1, tensors2}, scalars) || has_integral_tensor(input, /* includeBool */ true)) { \
     return at::native::foreach_tensor_##NAME##_scalarlist_slow(input, tensors1, tensors2, scalars); \
   } \
 \
@@ -137,7 +137,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, Te
 void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
   check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
 \
-  if (!can_use_fast_route({input, tensors1, tensors2}, scalars)) { \
+  if (!can_use_fast_route({input, tensors1, tensors2}, scalars) || has_integral_tensor(input, /* includeBool */ true)) { \
     return at::native::foreach_tensor_##NAME##_scalarlist_slow_(input, tensors1, tensors2, scalars); \
   } \
 \
@@ -149,10 +149,14 @@ FOREACH_POINTWISE_OP_SCALAR(addcdiv, std::divides);
 FOREACH_POINTWISE_OP_SCALARLIST(addcmul, std::multiplies);
 FOREACH_POINTWISE_OP_SCALARLIST(addcdiv, std::divides);
 
+
+// Why are bool tensors pushed to the slow path?
+// Because `AT_DISPATCH_ALL_TYPES_AND` is used below.
+// TODO(mkozuki): Check whether it's possible to handle bool tensors in the fast path.
 #define FOREACH_MAXIMUM_MINIMUM_OP(NAME, OP) \
 std::vector<Tensor> foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList tensors2) { \
   check_foreach_api_restrictions(tensors1, tensors2); \
-  if (!can_use_fast_route({tensors1, tensors2})) { \
+  if (!can_use_fast_route({tensors1, tensors2}) || has_bool_tensor(tensors1)) { \
    return at::native::foreach_tensor_##NAME##_slow(tensors1, tensors2); \
  } \
 \
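Sketched in Python (assumes a CUDA device and the behavior at this revision), integral and bool inputs to these ops now reach the slow path rather than hitting the fast-path kernels:

    import torch

    # Integral inputs to pointwise ops take the per-tensor slow path.
    ints = [torch.ones(3, dtype=torch.int64, device="cuda")]
    out = torch._foreach_addcmul(ints, ints, ints, 2)  # 1 + 2 * 1 * 1
    assert out[0].tolist() == [3, 3, 3]

    # Bool tensors skip the fast path because the CUDA kernels are
    # instantiated via AT_DISPATCH_ALL_TYPES_AND without bool; the slow
    # path dispatches to at::maximum per tensor pair.
    a = [torch.tensor([True, False], device="cuda")]
    b = [torch.tensor([False, True], device="cuda")]
    print(torch._foreach_maximum(a, b))  # [tensor([True, True], ...)]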

aten/src/ATen/native/cuda/ForeachUnaryOp.cu

Lines changed: 8 additions & 16 deletions
@@ -133,14 +133,14 @@ struct functor_name { \
 #define OP_CUSTOM_FUNCTOR(function, op_name, functor_name) \
 std::vector<Tensor> foreach_tensor_##op_name##_cuda(TensorList tensors) { \
   check_foreach_api_restrictions(tensors); \
-  if (!can_use_fast_route(tensors)) { \
+  if (!can_use_fast_route(tensors) || has_integral_tensor(tensors, /* includeBool */ true)) { \
     return at::native::foreach_tensor_##op_name##_slow(tensors); \
   } \
   return function<functor_name>(tensors); \
 } \
 void foreach_tensor_##op_name##_cuda_(TensorList tensors) { \
   check_foreach_api_restrictions(tensors); \
-  if (!can_use_fast_route(tensors)) { \
+  if (!can_use_fast_route(tensors) || has_integral_tensor(tensors, /* includeBool */ true)) { \
     return at::native::foreach_tensor_##op_name##_slow_(tensors); \
   } \
 \
@@ -247,13 +247,9 @@ struct Abs {
 
 std::vector<Tensor> foreach_tensor_abs_cuda(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
-  bool has_complex = false;
-  for (auto t : tensors) {
-    if (at::isComplexType(t.scalar_type())) {
-      has_complex = true;
-    }
-  }
-
+  const bool has_complex = std::any_of(
+      tensors.begin(), tensors.end(),
+      [](const auto & t) { return at::isComplexType(t.scalar_type()); });
   if (!can_use_fast_route(tensors) || has_complex) {
     return at::native::foreach_tensor_abs_slow(tensors);
   }
@@ -263,13 +259,9 @@ std::vector<Tensor> foreach_tensor_abs_cuda(TensorList tensors) {
 
 void foreach_tensor_abs_cuda_(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
-  bool has_complex = false;
-  for (auto t : tensors) {
-    if (at::isComplexType(t.scalar_type())) {
-      has_complex = true;
-    }
-  }
-
+  const bool has_complex = std::any_of(
+      tensors.begin(), tensors.end(),
+      [](const auto & t) { return at::isComplexType(t.scalar_type()); });
  if (!can_use_fast_route(tensors) || has_complex) {
    return at::native::foreach_tensor_abs_slow_(tensors);
  }
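For example (a sketch, assuming a CUDA device), integral and complex inputs now go through the per-tensor fallback:

    import torch

    # Integer input to a unary op: the slow path computes at::sqrt per
    # tensor, which promotes int64 to the default float dtype.
    out = torch._foreach_sqrt([torch.tensor([4, 9], device="cuda")])
    print(out[0])  # tensor([2., 3.], device='cuda:0')

    # abs of a complex tensor is real-valued, so it cannot satisfy the
    # same-dtype fast-path restriction and falls back as well.
    z = [torch.tensor([3 + 4j], device="cuda")]
    print(torch._foreach_abs(z))  # [tensor([5.], device='cuda:0')]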

test/test_cuda.py

Lines changed: 9 additions & 6 deletions
@@ -1908,13 +1908,16 @@ def test_grad_scaling_unscale(self, dtype=torch.float):
         for grad in grads:
             self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7))
 
-        # Passing lists with mismatched devices or dtypes to a raw
+        # When passing lists with mismatched dtypes to a raw
+        # _amp_foreach_non_finite_check_and_unscale_ call,
+        # it's expected to fall back to the single-tensor TensorIterator kernel.
+        grads = [g.clone(), g.to(dtype=torch.float16)]
+        torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
+        for grad in grads:
+            self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7))
+
+        # Passing lists with mismatched devices to a raw
         # _amp_foreach_non_finite_check_and_unscale_ call should raise errors.
-        with self.assertRaisesRegex(RuntimeError, r"must have the same dtype"):
-            torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(dtype=torch.float16)],
-                                                             found_inf,
-                                                             inv_scale)
-
         if TEST_MULTIGPU:
             with self.assertRaisesRegex(RuntimeError, r"Expected all tensors to be on the same device"):
                 torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")],
