Skip to content

Modernize FakeTensorMode, throw on non-fake inputs#78516

Closed
eellison wants to merge 6 commits intogh/eellison/298/basefrom
gh/eellison/298/head
Closed

Modernize FakeTensorMode, throw on non-fake inputs#78516
eellison wants to merge 6 commits intogh/eellison/298/basefrom
gh/eellison/298/head

Conversation

@eellison
Copy link
Contributor

@eellison eellison commented May 31, 2022

Stack from ghstack (oldest at bottom):

Modernizes FakeTensorMode by inheriting from TorchDispatchMode. FakeTensor and FakeTensorMode now call into a common helper function. I didn't see any existing idiomatic pattern for this so please feel free to comment if you think something else would be more ergonomic.

This also throws if a non-Fake tensor is an input to FakeTensorMode.. So it effectively only extends handling to constructors (for now, more general handling later in stack). This is because we need more careful logic to detect something like:

x = torch.rand([1, 1])
with FakeTensorMode()
    y = x.add_(3)
    x.resize_([4])  # y should be resized here as well, no way to support this, error

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented May 31, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit c7b545e (more details on the Dr. CI page):

Expand to see more

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@eellison eellison mentioned this pull request May 31, 2022
@eellison eellison requested review from ezyang and samdow May 31, 2022 16:21
Modernizes FakeTensorMode by inheriting from `TorchDispatchMode`. `FakeTensor` and `FakeTensorMode` now call into a common helper function. I didn't see any existing idiomatic pattern for this so please feel free to comment if you think something else would be more ergonomic. 

This also throws if a non-Fake tensor is an input to `FakeTensorMode`.. So it effectively only extends handling to constructors (for now, more general handling later in stack).

[ghstack-poisoned]

def test_constructor(self):
with enable_torch_dispatch_mode(FakeTensorMode):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just write all of these tests using FakeTensorMode.push()? Or maybe @samdow's patch has landed so with FakeTensorMode() works now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Patch not landed yet (Richard has been hunting down a functorch/dispatch key bug and then I was going to ask him if he had anything to add)--and this actually doesn't work yet because push uses push_torch_dispatch_mode which causes the error Elias saw late last week (Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to a python object of type Tensor)

On that note, I forgot to mention Richard debugged that and what's happening is that because push makes every mode have an inner mode (setting BaseTorchDispatchMode if there isn't one set), detach has a pyobj associated with it and we end up with this error). The patch fixes this because it removes BaseTorchDispatchMode but it will come up if we nest fake tensor mode with another mode. Long technical way of saying if we want FakeTensorMode to be composable, we should use push and add no_dispatch around all constructors in torch_dispatch. I'll try to flag everywhere this will come up

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have the constructors automatically apply no dispatch? Seems safer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do it in _make_subclass, sure. I don't think there's any case where we wouldn't want this? (cc @zou3519)

Copy link
Contributor Author

@eellison eellison May 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed here: #78565, also FakeTensorMode.push() still errors

return tree_map(partial(wrap, device=common_device), r)
def run_fn(func, types, args, kwargs):
return torch.Tensor.__torch_dispatch__(func, types, args, kwargs)
return torch_dispatch_impl(cls, func, types, args, kwargs, run_fn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samdow and I discussed this on Thursday and we think the right way to do code reuse is you put the real implementation in the mode, and then in the subclass implementation you (1) store the mode that allocated the subclass and (2) enable that mode before redispatching on the tensors as is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to leave to follow up because changing this causes a few different errors (which i'll file issues / smaller repros for and/or debug) and i would like to unblock dynamo

Copy link
Contributor

@samdow samdow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! All of my points are small/cleanup stuff

Comment on lines +77 to +84
conversion_made = False

def check_non_fake_tensor(x):
nonlocal conversion_made
conversion_made = conversion_made or (isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor))

tree_map(check_non_fake_tensor, args)
tree_map(check_non_fake_tensor, kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: not sure which is more idiomatic but we could do any(tree_flatten(tree_map...)[0]) or even something like

conversion_made = False
for x in tree_flatten(args)[0] + tree_flatten(kwargs)[0]:
  if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
    raise ...

this just saves us from using a nonlocal and lets us short circuit in the case that we run into the error

Copy link
Contributor Author

@eellison eellison May 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it's not worth optimizing for failure modes for execution speed or when sacrificing readability..... but maybe this is more readable, will consider it. thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mentally filing an issue to do tree_iterate separately or something along those lines because there are a lot of places in the codebase where tree_map is used where the return value isnt used

Comment on lines +165 to +166
def run_fn(func, types, args, kwargs):
return torch.Tensor.__torch_dispatch__(func, types, args, kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be okay running func(*args, **kwargs) in all cases. This way we can also remove the run_function argument

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually causes an error:

x = FakeTensor.from_tensor(torch.tensor(0.0))
y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
out = x + y
...
ERROR: test_zero_dim (__main__.FakeTensorTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_fake_tensor.py", line 32, in test_zero_dim
    out = x + y
TypeError: unsupported operand type(s) for +: 'FakeTensor' and 'FakeTensor'

Although it's a mute point because I'm going to do the refactoring of tensors storing mode and calling into it


def test_constructor(self):
with enable_torch_dispatch_mode(FakeTensorMode):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Patch not landed yet (Richard has been hunting down a functorch/dispatch key bug and then I was going to ask him if he had anything to add)--and this actually doesn't work yet because push uses push_torch_dispatch_mode which causes the error Elias saw late last week (Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to a python object of type Tensor)

On that note, I forgot to mention Richard debugged that and what's happening is that because push makes every mode have an inner mode (setting BaseTorchDispatchMode if there isn't one set), detach has a pyobj associated with it and we end up with this error). The patch fixes this because it removes BaseTorchDispatchMode but it will come up if we nest fake tensor mode with another mode. Long technical way of saying if we want FakeTensorMode to be composable, we should use push and add no_dispatch around all constructors in torch_dispatch. I'll try to flag everywhere this will come up

Modernizes FakeTensorMode by inheriting from `TorchDispatchMode`. `FakeTensor` and `FakeTensorMode` now call into a common helper function. I didn't see any existing idiomatic pattern for this so please feel free to comment if you think something else would be more ergonomic. 

This also throws if a non-Fake tensor is an input to `FakeTensorMode`.. So it effectively only extends handling to constructors (for now, more general handling later in stack). This is because we need more careful logic to detect something like:
```
x = torch.rand([1, 1])
with FakeTensorMode()
    y = x.add_(3)
    x.resize_([4])  # y should be resized here as well, no way to support this, error
```

[ghstack-poisoned]
@eellison
Copy link
Contributor Author

eellison commented Jun 1, 2022

@pytorchbot merge this please

@github-actions
Copy link
Contributor

github-actions bot commented Jun 1, 2022

Hey @eellison.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 2, 2022
Summary:
Pull Request resolved: #78516

Approved by: https://github.com/samdow

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/6671b504f7e3934bb26df93fd9a02d4081ba1713

Reviewed By: b0noI

Differential Revision: D36854261

Pulled By: eellison

fbshipit-source-id: 41c3f0d74b592561a2f9a3262ae2bc421a6c43af
@facebook-github-bot facebook-github-bot deleted the gh/eellison/298/head branch June 5, 2022 14:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants