[PyTorch] Reduce errors of foreach functions (#56993)
Summary:
This is based on #48224.
To make `foreach` more flexible, this PR routes unsupported cases to the slow path instead of erroring.
Also, this adds some tests to verify that
- `foreach` functions work with tensors of different dtypes and/or memory layouts: 7bd4b2c
- `foreach` functions work when the tensors in a list span multiple devices, as long as corresponding tensors (the same index across lists) share a device: def4b9b
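The fast-path/slow-path dispatch these tests exercise can be sketched as follows. This is a hypothetical illustration, not the real PyTorch implementation: `FakeTensor`, `can_use_fast_path`, and `foreach_launch_count` are made-up names, and only the dtype/device conditions are modeled (the real check also covers layout and denseness).

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Minimal stand-in for a tensor (hypothetical; not at::Tensor):
// only the properties the dispatch decision inspects.
struct FakeTensor {
  std::string dtype;
  std::string device;
};

// Fast path requires every tensor in the list to share dtype and device.
bool can_use_fast_path(const std::vector<FakeTensor>& tensors) {
  for (const auto& t : tensors) {
    if (t.dtype != tensors[0].dtype || t.device != tensors[0].device) {
      return false;
    }
  }
  return true;
}

// Number of "kernel launches" a foreach op would issue: one fused
// launch on the fast path, one per tensor on the slow-path fallback.
std::size_t foreach_launch_count(const std::vector<FakeTensor>& tensors) {
  return can_use_fast_path(tensors) ? 1 : tensors.size();
}
```

With this PR's change, a mixed-dtype list no longer raises; it simply takes the per-tensor slow path.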
Future plans:
1. Improve unit-test coverage by using the `ops` decorator, updating `foreach_unary_op_db`, and creating `foreach_(binary|pointwise|minmax)_db`.
2. Support broadcasting in the slow path. Ref: #52448
3. Support type promotion in the fast path. Ref: #52449
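For plan 3, type promotion picks a common result dtype for mixed-dtype inputs. The toy lattice below is an illustrative assumption covering only a few dtypes; PyTorch's real rules live in `c10::promoteTypes`, and `dtype_rank`/`promote` here are hypothetical helper names.

```cpp
#include <map>
#include <string>

// Toy promotion lattice: the higher-ranked dtype wins. This is an
// assumption for illustration, not PyTorch's full promotion matrix.
int dtype_rank(const std::string& dtype) {
  static const std::map<std::string, int> ranks = {
      {"int", 0}, {"half", 1}, {"float", 2}, {"double", 3}};
  return ranks.at(dtype);
}

std::string promote(const std::string& a, const std::string& b) {
  return dtype_rank(a) >= dtype_rank(b) ? a : b;
}
```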
CC: ngimel mcarilli ptrblck
Pull Request resolved: #56993
Reviewed By: zou3519
Differential Revision: D28630580
Pulled By: ngimel
fbshipit-source-id: e26ee74a39a591025e18c1ead48948cb7ec53c19
```diff
   TORCH_CHECK(tensors2.size() > 0, "Tensor list must have at least one tensor.");
   TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
-  auto expected_dtype = tensors1[0].dtype();
-
   for (const auto i : c10::irange(tensors1.size())) {
-    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
-    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
     TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
   }
 }
@@ -45,11 +47,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
   TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors2.size());
   TORCH_CHECK(tensors1.size() == tensors3.size(), "Tensor lists must have the same number of tensors, got ", tensors1.size(), " and ", tensors3.size());
-  auto expected_dtype = tensors1[0].dtype();
-
   for (const auto i : c10::irange(tensors1.size())) {
-    TORCH_CHECK(tensors1[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
-    TORCH_CHECK(tensors2[i].dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
     TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors2[i].sizes());
     TORCH_CHECK(tensors1[i].sizes() == tensors3[i].sizes(), "Corresponding tensors in lists must have the same size, got ", tensors1[i].sizes(), " and ", tensors3[i].sizes());
   }
@@ -61,20 +59,24 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
 }

 // To go via 'fast' path, several conditions must be satisfied
+// - All tensors in all lists must have the same dtype.
 // - All tensors must be on the same device
 // - All tensors must have strided layout
 // - All tensors must be non-overlapping and dense
 // - Resulting tensor must have the same dtype as the input one

+// TODO(mkozuki): Consider whether we really need this function or not.
+// Note that, there is a possibility that foreach fastpath supports type promotion in the future,
+// which might complicate the functionality this function should provides.
+// However, as of now, the check of division op with integer inputs is duplicated.
+// `check_fast_path_restrictions` does the same thing in it before calling this function.
```
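After this change, the API-level check only enforces list lengths and per-index shapes; dtype uniformity is left to the fast-path check, so mismatches fall back to the slow path instead of throwing. A hypothetical mirror of the relaxed check, with `Shape` and `check_api_restrictions` as made-up stand-ins for `TensorList` and the real function:

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Stand-in for a tensor's shape; a list of shapes models a TensorList.
using Shape = std::vector<long>;

// Sketch of the relaxed check: no dtype comparison remains here.
void check_api_restrictions(const std::vector<Shape>& tensors1,
                            const std::vector<Shape>& tensors2) {
  if (tensors1.empty())
    throw std::invalid_argument("Tensor list must have at least one tensor.");
  if (tensors1.size() != tensors2.size())
    throw std::invalid_argument("Tensor lists must have the same number of tensors.");
  for (std::size_t i = 0; i < tensors1.size(); ++i) {
    if (tensors1[i] != tensors2[i])
      throw std::invalid_argument("Corresponding tensors in lists must have the same size.");
  }
}
```

Mixed-dtype lists now pass this check and are only rejected (and rerouted) later, by `check_fast_path_restrictions`.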