Modernize FakeTensorMode, throw on non-fake inputs #78516
eellison wants to merge 6 commits into gh/eellison/298/base
Conversation
[ghstack-poisoned]
✅ No Failures (0 Pending) as of commit c7b545e (Dr. CI).
Modernizes FakeTensorMode by inheriting from `TorchDispatchMode`. `FakeTensor` and `FakeTensorMode` now call into a common helper function. I didn't see any existing idiomatic pattern for this, so please feel free to comment if you think something else would be more ergonomic. This also throws if a non-fake tensor is an input to `FakeTensorMode`, so it effectively only extends handling to constructors (for now; more general handling comes later in the stack).
```diff
     def test_constructor(self):
-        with enable_torch_dispatch_mode(FakeTensorMode):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
```
Can we just write all of these tests using `FakeTensorMode.push()`? Or maybe @samdow's patch has landed, so `with FakeTensorMode()` works now.
Patch not landed yet (Richard has been hunting down a functorch/dispatch key bug, and then I was going to ask him if he had anything to add). This also doesn't work yet, because `push` uses `push_torch_dispatch_mode`, which causes the error Elias saw late last week (`Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to a python object of type Tensor`).
On that note, I forgot to mention that Richard debugged that. What's happening is that because `push` makes every mode have an inner mode (setting `BaseTorchDispatchMode` if there isn't one set), `detach` has a pyobj associated with it and we end up with this error. The patch fixes this because it removes `BaseTorchDispatchMode`, but it will come up if we nest fake tensor mode with another mode. Long technical way of saying: if we want `FakeTensorMode` to be composable, we should use `push` and add `no_dispatch` around all constructors in `__torch_dispatch__`. I'll try to flag everywhere this will come up.
Can we have the constructors automatically apply `no_dispatch`? Seems safer.
We can do it in `_make_subclass`, sure. I don't think there's any case where we wouldn't want this? (cc @zou3519)
Filed here: #78565. Also, `FakeTensorMode.push()` still errors.
```python
        return tree_map(partial(wrap, device=common_device), r)

    def run_fn(func, types, args, kwargs):
        return torch.Tensor.__torch_dispatch__(func, types, args, kwargs)
    return torch_dispatch_impl(cls, func, types, args, kwargs, run_fn)
```
@samdow and I discussed this on Thursday and we think the right way to do code reuse is you put the real implementation in the mode, and then in the subclass implementation you (1) store the mode that allocated the subclass and (2) enable that mode before redispatching on the tensors as is.
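A minimal, torch-free sketch of that reuse pattern (names like `FakeMode` and `FakeValue` are invented for illustration, and `torch_dispatch` stands in for `__torch_dispatch__`): the real implementation lives on the mode, and the subclass just remembers which mode allocated it and re-enters that mode.

```python
import operator

class FakeMode:
    def dispatch(self, func, args):
        # The single real implementation, shared by the mode entry point
        # and the subclass entry point.
        unwrapped = [a.data if isinstance(a, FakeValue) else a for a in args]
        return FakeValue(self, func(*unwrapped))

class FakeValue:
    def __init__(self, mode, data):
        self.mode = mode   # (1) store the mode that allocated this value
        self.data = data

    def torch_dispatch(self, func, *args):
        # (2) re-enter the stored mode instead of duplicating its logic here
        return self.mode.dispatch(func, (self, *args))

mode = FakeMode()
x = FakeValue(mode, 2)
y = FakeValue(mode, 3)
out = x.torch_dispatch(operator.add, y)
print(out.data)  # 5
```

The point of the pattern is that there is exactly one place (`FakeMode.dispatch`) where behavior is defined, so mode-based and subclass-based use can't drift apart.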
I'm going to leave this to a follow-up, because changing it causes a few different errors (which I'll file issues / smaller repros for, and/or debug), and I would like to unblock dynamo.
samdow left a comment:
LGTM! All of my points are small/cleanup stuff.
```python
conversion_made = False

def check_non_fake_tensor(x):
    nonlocal conversion_made
    conversion_made = conversion_made or (isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor))

tree_map(check_non_fake_tensor, args)
tree_map(check_non_fake_tensor, kwargs)
```
nit: not sure which is more idiomatic, but we could do `any(tree_flatten(tree_map...)[0])` or even something like

```python
conversion_made = False
for x in tree_flatten(args)[0] + tree_flatten(kwargs)[0]:
    if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
        raise ...
```

This just saves us from using a nonlocal, and lets us short-circuit in the case that we run into the error.
IMO it's not worth optimizing failure paths for execution speed, or sacrificing readability to do so... but maybe this is more readable; I'll consider it. Thanks!
I'm mentally filing an issue to do `tree_iterate` separately, or something along those lines, because there are a lot of places in the codebase where `tree_map` is used but the return value isn't used.
```python
def run_fn(func, types, args, kwargs):
    return torch.Tensor.__torch_dispatch__(func, types, args, kwargs)
```
We should be okay running `func(*args, **kwargs)` in all cases. This way we can also remove the `run_function` argument.
This actually causes an error:

```python
x = FakeTensor.from_tensor(torch.tensor(0.0))
y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
out = x + y
```

```
ERROR: test_zero_dim (__main__.FakeTensorTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_fake_tensor.py", line 32, in test_zero_dim
    out = x + y
TypeError: unsupported operand type(s) for +: 'FakeTensor' and 'FakeTensor'
```

Although it's a moot point, because I'm going to do the refactoring of tensors storing the mode and calling into it.
Modernizes FakeTensorMode by inheriting from `TorchDispatchMode`. `FakeTensor` and `FakeTensorMode` now call into a common helper function. I didn't see any existing idiomatic pattern for this, so please feel free to comment if you think something else would be more ergonomic.
This also throws if a non-fake tensor is an input to `FakeTensorMode`, so it effectively only extends handling to constructors (for now; more general handling comes later in the stack). This is because we need more careful logic to detect something like:
```python
x = torch.rand([1, 1])
with FakeTensorMode():
    y = x.add_(3)
    x.resize_([4])  # y should be resized here as well; no way to support this, so error
```
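For intuition on why the mode errors instead of silently converting non-fake inputs, here is a torch-free sketch (`RealBuf` and `to_fake` are invented stand-ins, not PyTorch API): converting a live real tensor snapshots it, so later in-place mutation of the original is invisible to the fake copy.

```python
class RealBuf:
    # Stand-in for a real tensor: just tracks a size.
    def __init__(self, size):
        self.size = size

def to_fake(buf):
    # Conversion snapshots the metadata; it does not alias the original.
    return RealBuf(buf.size)

x = RealBuf(1)
y = to_fake(x)   # what silently converting a non-fake input would do
x.size = 4       # in-place mutation of the real buffer, like resize_
print(y.size)    # 1 -- the fake copy is stale, hence the hard error
```

Raising on non-fake inputs sidesteps this staleness entirely until more careful aliasing logic exists.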
@pytorchbot merge this please

Hey @eellison.
Summary: Pull Request resolved: #78516
Approved by: https://github.com/samdow
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/6671b504f7e3934bb26df93fd9a02d4081ba1713
Reviewed By: b0noI
Differential Revision: D36854261
Pulled By: eellison
fbshipit-source-id: 41c3f0d74b592561a2f9a3262ae2bc421a6c43af