Refactor foreach binary ops tests with scalars to use OpInfo#51058
Refactor foreach binary ops tests with scalars to use OpInfo#51058izdeby wants to merge 29 commits intogh/izdeby/75/basefrom
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 6338190 (more details on the Dr. CI page):
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
[ghstack-poisoned]
[ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
There was a problem hiding this comment.
Logic looks correct. I had two comments:
- I'm not sure what is happening with the float16 testing (added an in-line comment)
- In the old code, we had some checks for various error messages that would be raised. Those checks don't exist anymore in the OpInfo tests. Do you think it would be worth it to add some new, non-OpInfo test that tests for the various error message? I am thinking of something like:
def test_error_messages(self, device):
# test subtraction with booleans raises a nice error message
# test addition with LongTensors but with floating-point alpha raises a nice error message
# ...
There are probably better ways to organize this. I'm OK if you want to handle adding tests for the error messages in a follow-up later to unblock progress here
| """Early version of a specialized OpInfo for foreach binary functions""" | ||
| def __init__(self, | ||
| name, | ||
| dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), |
There was a problem hiding this comment.
(no action required) Err, it's a little weird that "all_types" doesn't mean "all possible dtypes". But I suppose that's a pre-existing problem
There was a problem hiding this comment.
Yeah, i agree, Its super confusing
There was a problem hiding this comment.
It is incredibly confusing, sorry. It corresponds to the internal dispatch macro 1:1, though.
| # Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16. | ||
| dtype = torch.float32 if (self.device_type == 'cuda' and | ||
| (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype |
There was a problem hiding this comment.
Wouldn't we want to test that torch.float16 tensors are accepted by these APIs? What does "Mimics cuda kernel dtype flow." mean? (Maybe this just needs a better explanation)
There was a problem hiding this comment.
Hmm I see, thanks for the references. Wouldn't it be better to
- Run the foreach op in with inputs of the original dtype (float16)
- Convert inputs to the updated dtype (float32), run the reference (PyTorch op in a for-loop), then convert outputs back to float 16
- Compare the result of 1 and 2
This way we can still test that the foreach APIs accept float16 and bfloat16 tensors without erroring out.
----- Updated foreach binary ops tests with scalars to use OpInfo Differential Revision: [D26103905](https://our.internmc.facebook.com/intern/diff/D26103905) [ghstack-poisoned]
|
Hi @izdeby! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks! |
Stack from ghstack:
Updated foreach binary ops tests with scalars to use OpInfo
Differential Revision: D26103905