Update foreach binary ops with a scalar list#49249
Update foreach binary ops with a scalar list#49249izdeby wants to merge 40 commits intogh/izdeby/71/basefrom
Conversation
[ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
zou3519
left a comment
There was a problem hiding this comment.
Questions about behavior
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check. [ghstack-poisoned]
…ils.h" Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests - Update ForeachUtils.h with division check and refactored `can_use_fast_route` methods. [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
In this PR, we are updating the sub/div logic for APIs that work with scalar lists and in the next one APIs with a scalar and/or tensor list. Updating names and description for both PRs |
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for 2 set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for 2 set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
zou3519
left a comment
There was a problem hiding this comment.
This looks pretty reasonable, it's great we are deleting all the special cases in the testing logic. I had some comments around the tests
| // In the case of division, integer inputs will result in float. | ||
| // Currently multi tensor apply can only return result of the same type as input. |
There was a problem hiding this comment.
nit: It would be good to mention "This does not use FOREACH_BINARY_OP_SCALARLIST because ". It's easy to make the inference but just to clarify
|
|
||
| // In the case of division, integer inputs will result in float. | ||
| // Currently multi tensor apply can only return result of the same type as input. | ||
| void foreach_tensor_div_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { |
There was a problem hiding this comment.
Also nit: is it possible to update FOREACH_BINARY_OP_SCALARLIST to take three arguments, where the last argument is div_op? Then we would be able to use FOREACH_BINARY_OP_SCALARLIST with division instead of rewriting the division logic
| return foreach_binary_op<std::divides>(tensors, scalars); | ||
| } | ||
|
|
||
| // In the case of subtraction, we dont allow scalar to be boolean following the torch.sub logic |
There was a problem hiding this comment.
It would be good to mention "This does not use FOREACH_BINARY_OP_SCALARLIST because ". It's easy to make the inference but just to clarify that this is a special case
test/test_foreach.py
Outdated
| foreach_bin_op_(tensors, scalars) | ||
| with self.assertRaisesRegex(RuntimeError, "not implemented for"): | ||
| foreach_bin_op(tensors, scalars) | ||
| return |
There was a problem hiding this comment.
Should this be a continue?
test/test_foreach.py
Outdated
| scalars = [True for _ in range(N)] | ||
| scalars[0] = 1 | ||
| scalars[1] = 1.1 | ||
| scalars[2] = 3 + 5j |
There was a problem hiding this comment.
nit: can be done succinctly in one line: scalars = [1, 1.1, 3 + 5j] + [True for _ in range(N - 3)]
| expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)] | ||
| res = foreach_bin_op(tensors, scalars) | ||
| self.assertEqual(expected, res) |
There was a problem hiding this comment.
This only tests correctness for the out-of-place variant, right? Should it also test correctness for the in-place variant?
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
test/test_foreach.py
Outdated
| else: | ||
| with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): | ||
| foreach_bin_op_(tensors, scalars) | ||
| continue |
There was a problem hiding this comment.
Is the continue necessary here?
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Differential Revision: [D25502939](https://our.internmc.facebook.com/intern/diff/D25502939) -------- Update logic for a set of foreach APIs with a scalar list - Update binary ops to support boolean type - Update sub ops to throw an error in case of bool subtraction - Update division op to check for type promotion. - Update tests [ghstack-poisoned]
Stack from ghstack:
Differential Revision: D25502939
Update logic for a set of foreach APIs with a scalar list