[PyTorch] Reduce errors of foreach functions #56993
crcrpar wants to merge 21 commits into pytorch:master
Conversation
Apply skipMeta to reduce warnings
💊 CI failures summary and remediations

As of commit 4cc29a5 (more details on the Dr. CI page):

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

| Job | Step | Action |
|---|---|---|
| Run clang-format | | 🔁 rerun |
aten/src/ATen/native/ForeachUtils.h (Outdated)

```cpp
if (check_dtype) {
  const auto expected_dtype = tensors[0].dtype();
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
```
nit: you can use std::all_of here
```diff
  check_foreach_api_restrictions(input, tensors1, tensors2);      \
                                                                  \
- if (!can_use_fast_route({input, tensors1, tensors2}, scalar)) { \
+ if (!can_use_fast_route({input, tensors1, tensors2}, scalar) || has_int_or_bool_tensor(input)) { \
```

out of curiosity, why are we getting rid of int tensors? Because of type promotion?

Yes, type promotion.
```diff
  for (const auto& t : tensors) {
    if (at::isComplexType(t.scalar_type())) {
-     has_complex = true;
+     has_complex_or_integer = true;
```

this is misleading - how did integer get in the name? Also, std::any_of.

ah, sorry, forgot to fix the name
also, use `std::any_of`
Stop providing a special path for AMP kernels.
Upon the removal of dtype checks from `check_foreach_api_restrictions`, non_finite_check_and_unscale_cuda_ internally checks each tensor's dtype and pushes it to the single-tensor TensorIterator kernel. Also, update the comment about `can_use_fast_route`.
`_amp_foreach_non_finite_check_and_unscale_` is expected to:
- handle tensors of various dtypes
- raise a RuntimeError if devices are different

check_foreach_api_restrictions functions dropped the option of dtype checks and …
ngimel left a comment:

This looks good, I've left minor comments.
aten/src/ATen/native/ForeachUtils.h (Outdated)

```cpp
bool will_promote_tensor(const Tensor& tensor, const Scalar& scalar, bool does_op_promote_integer_inputs_to_float = false) {
  // complex scalar + float/int/bool tensor will result in complex tensor
  if (scalar.isComplex() && (at::isIntegralType(tensor.scalar_type(), /* includeBool= */ true) || at::isFloatingType(tensor.scalar_type()))) {
```

It would be better to call `result_type` here.
```cpp
FOREACH_POINTWISE_OP_SCALARLIST(addcdiv, std::divides);

bool has_bool_tensor(TensorList tensors) {
```

this has to live next to has_int_or_bool_tensor, or, better, be a single function with an argument indicating whether to include bool: `has_integral_tensor(TensorList tensors, bool include_bool = false)`
That makes sense, I'll rename has_int_or_bool_tensor to has_integral_tensor and add an includeBool argument.
I also will juxtapose has_bool_tensor with has_integral_tensor, but I want to keep them as separate functions because the caller of has_bool_tensor can handle int tensors.
Or, is it better to write `std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool { return t.scalar_type() == at::ScalarType::Bool; })` inside the caller?
```diff
  std::vector<Tensor> foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList tensors2) { \
    check_foreach_api_restrictions(tensors1, tensors2);  \
-   if (!can_use_fast_route({tensors1, tensors2})) {     \
+   if (!can_use_fast_route({tensors1, tensors2}) || has_bool_tensor(tensors1)) { \
```

can you please leave a comment here why bool tensors can't be handled.
test/test_foreach.py (Outdated)

```python
        else:
            self.assertEqual(tensors1, expected)

    @skipMeta
```

out of curiosity, how is it working now, without skipMeta?
I don't remember the details, but without this I saw a bunch of messages that looked like a ton of failures, though they were somehow treated as warnings or something like that.
The reason I added this decorator is that those messages were like "this op does not support meta tensor", and I thought getting that message from every single foreach function test was meaningless, as this decorator says it all.
If that's the case (the tests are actually passing), please don't skip them, so that if coverage is extended to meta they don't have to be explicitly re-enabled. For your local runs you could still filter to only the CPU or only the CUDA device (not both together, I'm afraid).
Yes, please don't do this. If you want to reduce the spamminess, edit the code in torch/testing/_internal/common_utils.py:

```python
@contextmanager
def skip_exception_type(exc_type):
    try:
        yield
    except exc_type as e:
        raise unittest.SkipTest(f"not implemented: {e}") from e
```

You could make it stop printing the entire exception text (though I've personally found it useful for collecting information from XML logs).
Sure, thank you @ezyang for the explanation
```python
    @skipMeta
    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_list_slow_path(self, device, dtype):
        # 0-strides
```

Nit: it's "implicit broadcast", not 0-strides.
Sorry, maybe I'm just wrong, but is it still the case even if I deliberately call expand_as(tensor1)?
If sizes don't match, the current foreach functions should raise.
Side note: I want to support broadcasting via the slow path after unittest refactoring.
Rel: #52448
If you deliberately expand, then it's explicit broadcast, but either way, it's just a comment, so it doesn't matter.
```python
        tensor2 = torch.randn(5, 2, 1, 3 * 7, device=device).to(dtype)[:, :, :, ::7]
        res = torch._foreach_add([tensor1], [tensor2])
        torch._foreach_add_([tensor1], [tensor2])
        self.assertEqual(res, [tensor1])
```

Note for the future: this test is not actually testing correctness, it's testing in-place vs out-of-place. But that can wait till a future refactor.
test/test_foreach.py (Outdated)

```python
        # `tensors2`: ['cuda', 'cpu']
        _cuda_tensors = self._get_test_data(device, dtype, 2)
        _cpu_tensors = self._get_test_data('cpu', dtype, 2)
        tensors1, tensors2 = [[_cuda_tensors[i], _cpu_tensors[i]] for i in range(2)]
```
nit: `tensors1, tensors2 = [list(t) for t in zip(_cuda_tensors, _cpu_tensors)]` is slightly more pythonic
Oh, I like this, thank you.
test/test_foreach.py (Outdated)

```python
        _cpu_tensors = self._get_test_data('cpu', dtype, 2)
        tensors1, tensors2 = [[_cuda_tensors[i], _cpu_tensors[i]] for i in range(2)]

        if dtype == torch.bool and native_op == torch.sub:
```

why do you need this special case? It should be handled by the following try-except block (both for_each and native should raise the same error)?
You are absolutely right, and that is what I expected, but it unfortunately fails.
If I merge this special handling into the following try-except block then, in my understanding, because the error message includes `^`, one of the regex special characters, I get the following:
```
Traceback (most recent call last):
  File "/workspace/test/test_foreach.py", line 890, in test_binary_op_tensors_on_different_devices
    actual = foreach_op(tensors1, tensors2)
RuntimeError: Subtraction, the `-` operator, with two bool tensors is not supported. Use the `^` or `logical_xor()` operator instead.

During handling of the above exception, another exception occurred:

RuntimeError: Subtraction, the `-` operator, with two bool tensors is not supported. Use the `^` or `logical_xor()` operator instead.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/test/test_foreach.py", line 893, in test_binary_op_tensors_on_different_devices
    [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
  File "/opt/conda/lib/python3.8/unittest/case.py", line 240, in __exit__
    self._raiseFailure('"{}" does not match "{}"'.format(
  File "/opt/conda/lib/python3.8/unittest/case.py", line 164, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: "Subtraction, the `-` operator, with two bool tensors is not supported. Use the `^` or `logical_xor()` operator instead." does not match "Subtraction, the `-` operator, with two bool tensors is not supported. Use the `^` or `logical_xor()` operator instead."
```
Can you try whether wrapping in re.escape solves this? e.g.

```python
with self.assertRaisesRegex(type(e), re.escape(str(e))):
```

re.search seems to be happy:

```python
In [10]: aa = "pytorch `^`"

In [11]: re.search(re.escape(aa), bb)
Out[11]: <re.Match object; span=(0, 11), match='pytorch `^`'>

In [12]: re.search(aa, bb)
```
test/test_foreach.py (Outdated)

```python
        tensors1, tensors2, tensors3 = [[_cuda_tensors[i], _cpu_tensors[i]] for i in range(3)]

        if native_op == torch.addcdiv:
            if dtype in torch.testing.get_all_int_dtypes() + [torch.bool]:
```

can you please leave a comment here why this special case is necessary?
Okay, I think the reason is similar to that of the sub-and-bool case: the message has one or more special characters.
```
Traceback (most recent call last):
  File "/workspace/test/test_foreach.py", line 931, in test_pointwise_op_tensors_on_different_devices
    actual = foreach_op(tensors1, tensors2, tensors3)
RuntimeError: Integer division with addcdiv is no longer supported, and in a future release addcdiv will perform a true division of tensor1 and tensor2. The historic addcdiv behavior can be implemented as (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. The future addcdiv behavior is just the latter implementation: (input + value * tensor1 / tensor2), for all dtypes.

During handling of the above exception, another exception occurred:

RuntimeError: Integer division with addcdiv is no longer supported, and in a future release addcdiv will perform a true division of tensor1 and tensor2. The historic addcdiv behavior can be implemented as (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. The future addcdiv behavior is just the latter implementation: (input + value * tensor1 / tensor2), for all dtypes.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/test/test_foreach.py", line 934, in test_pointwise_op_tensors_on_different_devices
    expected = [native_op(t1, t2, t3) for t1, t2, t3 in zip(tensors1, tensors2, tensors3)]
  File "/opt/conda/lib/python3.8/unittest/case.py", line 240, in __exit__
    self._raiseFailure('"{}" does not match "{}"'.format(
  File "/opt/conda/lib/python3.8/unittest/case.py", line 164, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: "Integer division with addcdiv is no longer supported, and in a future release addcdiv will perform a true division of tensor1 and tensor2. The historic addcdiv behavior can be implemented as (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. The future addcdiv behavior is just the latter implementation: (input + value * tensor1 / tensor2), for all dtypes." does not match "Integer division with addcdiv is no longer supported, and in a future release addcdiv will perform a true division of tensor1 and tensor2. The historic addcdiv behavior can be implemented as (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. The future addcdiv behavior is just the latter implementation: (input + value * tensor1 / tensor2), for all dtypes."
```

note: `has_int_or_bool_tensor` is renamed to `has_integral_tensor` and a new argument `includeBool` is added.
Binary foreach functions cannot be compiled for inputs of one tensorlist and one scalarlist if the scalarlist consists of complex128.
Leave a comment on the test which requires the decorator.
This obviates the special cases for `_foreach_sub` and `_foreach_addcdiv`.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This is based on pytorch#48224. To make `foreach` more flexible, this PR pushes unsupported cases to the slow path. Also, this adds some tests to verify that:

- `foreach` functions work with tensors of different dtypes and/or memory layouts in pytorch@7bd4b2c
- `foreach` functions work with tensors on different devices in a list, but are on the same device if the indices are the same: pytorch@def4b9b

Future plans:
1. Improve the coverage of unittests using the `ops` decorator & updating `foreach_unary_op_db` and creating `foreach_(binary|pointwise|minmax)_db`.
2. Support broadcasting in the slow path. Ref: pytorch#52448
3. Support type promotion in the fast path. Ref: pytorch#52449

CC: ngimel mcarilli ptrblck

Pull Request resolved: pytorch#56993
Reviewed By: zou3519
Differential Revision: D28630580
Pulled By: ngimel
fbshipit-source-id: e26ee74a39a591025e18c1ead48948cb7ec53c19
This is based on #48224.

To make `foreach` more flexible, this PR pushes unsupported cases to the slow path. Also, this adds some tests to verify that:

- `foreach` functions work with tensors of different dtypes and/or memory layouts in 7bd4b2c
- `foreach` functions work with tensors on different devices in a list, but are on the same device if the indices are the same: def4b9b

Future plans:
1. Improve the coverage of unittests using the `ops` decorator & updating `foreach_unary_op_db` and creating `foreach_(binary|pointwise|minmax)_db`.
2. Support broadcasting in the slow path. Ref: #52448
3. Support type promotion in the fast path. Ref: #52449

CC: @ngimel @mcarilli @ptrblck