
Migrate FakeTensors to always call into FakeTensorMode and have them hold a reference #78677

Closed
eellison wants to merge 14 commits into gh/eellison/305/base from gh/eellison/305/head

Conversation

@eellison
Contributor

@eellison eellison commented Jun 1, 2022

Stack from ghstack (oldest at bottom):

Migrates FakeTensors to always invoke FakeTensorMode instead of calling into a shared impl function. Because `FakeTensorMode` stores a `FakeTensorConverter`, which holds a map from Real Tensors -> Fake Tensors, this also changes the `tensor_memo` to use a `WeakValueDictionary` to avoid a circular reference from `FakeTensorMode` -> `FakeTensor` -> `FakeTensorMode`.

When instantiating `FakeTensors` you now need a corresponding `FakeTensorMode` for them to hold. If creating new tensors, the idiom is similar to the existing `fake_mode` from `torchdistx`.

```
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
    y = torch.rand([4], device="cpu")
```

When converting existing tensors to `FakeTensors`, the idiom is to instantiate a `FakeTensorMode` and then use it for the conversion, with the new mode associated with the new tensor.

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 1, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit f0b2b83 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@eellison eellison requested review from ezyang and samdow June 1, 2022 21:38
kwargs = kwargs if kwargs else {}

# TODO: apply as no_dispatch decorator
with no_dispatch():
Contributor

I don't think you need this anymore: inside the mode we automatically disable the current mode, so if you call another function it will only hit the underlying dispatch. In fact, doing it this way is wrong because it makes your mode non-compositional.

Contributor Author

It would be great to have a test suite/demo/something to check whether your new tensor subclass is compositional, because it's a bit hard to reason about without actually going through tests/debugging/seeing how things work together. In subclass zoo, for example.

Contributor Author

Without the `no_dispatch` I get infinite recursion; the `r = func(*args, **kwargs)` here just calls into `__torch_dispatch__` again.

Contributor

If we want this to work and be compositional with other modes, we'll need to unwrap args and kwargs. Usually this would just be `arg.elem if isinstance(arg, FakeTensor) else arg`, but it seems like this will probably be more complicated before it can be a wrapper tensor? Not sure if @ezyang has any thoughts on how to do this.

Contributor

You should only unwrap if the mode matches your mode.

Contributor

Ah, looks like this should work--so basically we should save `elem` and then unwrap to that `elem`, like Ed said, if the mode matches your mode (currently this should always happen, since we error if there are FakeTensors from more than one mode, but we can be extra careful to set ourselves up for nested FakeTensors if we want).

I got confused by this comment though: "elem does not need to be recorded, because FakeTensor *is an* elem". FakeTensor should only be taking a vanilla tensor as `elem`?

Contributor Author

@eellison eellison Jun 2, 2022

Ughhh, that comment is copy-pasted from here... maybe it's not applicable. Good catch.

Is there a way to unwrap a tensor that inherits from `torch.Tensor`?

Contributor

Ah I see--that is Tracer/ProxyTensor specific. And kinda quirky.

Yeah, so the idea is that this is basically `FakeTensor(torch.Tensor(...))` and we want to redispatch on `torch.Tensor` instead of on `FakeTensor`, so that we don't get the infinite recursion you're seeing.

So if we save `elem`, we should be able to unwrap by basically doing `arg.elem if isinstance(arg, FakeTensor) and self == arg.fake_mode else arg` for every arg and kwarg (this adds Ed's check that we only unwrap tensors in the same mode) and then run `func(*new_args, **new_kwargs)`.
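The unwrapping idiom sketched in this thread can be illustrated without torch. The `FakeTensor`/`elem`/`fake_mode` names follow the discussion; the stand-in classes and the `unwrap_args` helper here are hypothetical, for illustration only:

```python
class Elem:
    """Stand-in for a plain torch.Tensor."""
    def __init__(self, value):
        self.value = value

class FakeTensor(Elem):
    """Wrapper-style stand-in holding the underlying elem and its mode."""
    def __init__(self, elem, fake_mode):
        super().__init__(elem.value)
        self.elem = elem
        self.fake_mode = fake_mode

def unwrap_args(mode, args, kwargs):
    """Unwrap only FakeTensors belonging to `mode`; leave everything else intact."""
    def unwrap(a):
        if isinstance(a, FakeTensor) and a.fake_mode is mode:
            return a.elem
        return a
    new_args = tuple(unwrap(a) for a in args)
    new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
    return new_args, new_kwargs
```

Note the identity check on the mode: a FakeTensor owned by a *different* mode passes through wrapped, which is exactly the "only unwrap if the mode matches your mode" rule from Ed's comment.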

Contributor Author

There is no `elem`; since `FakeTensor` inherits rather than wraps, it's not compositional. I'm going to defer making this work to another PR.

if fake_mode is None:
fake_mode = arg.fake_mode
else:
assert fake_mode is arg.fake_mode, "Mixing modes NYI"
Contributor

@samdow we should probably do a utility for doing this

else:
assert fake_mode is arg.fake_mode, "Mixing modes NYI"

with enable_torch_dispatch_mode(fake_mode):
Contributor

@samdow This probably can be a little more compositional if we have some sort of push-many operation which will let you push an arbitrary mode so long as the current mode is an ancestor of the mode you're trying to push.

# FakeTensors store the FakeTensorMode which in turn stores a
# FakeTensor, so we need to hold a weak reference to the FakeTensor
# otherwise we would induce a circular reference
self.tensor_memo = weakref.WeakValueDictionary()
Contributor

I actually thought this was going to be a weak key dictionary; but this is clearly what you need to prevent the reference cycle. Too bad there's not a WeakKeyValueDictionary lol
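A minimal, torch-free illustration of why a `WeakValueDictionary` breaks the cycle here: once the last strong reference to the fake tensor dies, its memo entry vanishes, so the converter never keeps the tensor alive (`FakeTensorStandIn` and the string key are illustrative assumptions; the real memo keys on the source tensor):

```python
import gc
import weakref

class FakeTensorStandIn:
    """Stand-in; real FakeTensors hold a strong reference to their mode."""
    pass

memo = weakref.WeakValueDictionary()

key = "real_tensor_id"
fake = FakeTensorStandIn()
memo[key] = fake

assert memo[key] is fake  # entry lives while a strong ref exists

del fake                  # drop the only strong reference
gc.collect()
assert key not in memo    # the memo entry disappears automatically
```

With a regular dict (or a WeakKeyDictionary, which only weakens the keys), the memo's strong reference to the FakeTensor would complete the FakeTensorMode -> FakeTensorConverter -> FakeTensor -> FakeTensorMode cycle.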


def from_real_tensor(self, fake_mode, t):
if self._get_memo(t) is not None:
return self.tensor_memo[t]
Contributor

This looks wrong; shouldn't you return the result of `_get_memo`?

Contributor Author

Where's the walrus operator when you need it (doesn't work in Python 3.7).
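The fix being discussed — returning the memoized result instead of re-indexing the weak dictionary — can be sketched with a plain dict standing in for `tensor_memo` (the `Converter` class and tuple "fake" value are illustrative assumptions; on Python 3.8+ the first two lines of `from_real_tensor` could be `if (memo := self._get_memo(t)) is not None:`):

```python
class Converter:
    def __init__(self):
        self.tensor_memo = {}

    def _get_memo(self, t):
        return self.tensor_memo.get(id(t))

    def from_real_tensor(self, t):
        # Python 3.7-compatible: call _get_memo once and reuse the result,
        # rather than looking the key up a second time.
        memo = self._get_memo(t)
        if memo is not None:
            return memo
        fake = ("fake", id(t))  # stand-in for constructing a FakeTensor
        self.tensor_memo[id(t)] = fake
        return fake
```

Reusing the single lookup matters with a `WeakValueDictionary`: between the membership check and a second `self.tensor_memo[t]`, the entry could be garbage-collected, turning the second lookup into a KeyError.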

eellison added 3 commits June 2, 2022 07:16
samdow pushed a commit that referenced this pull request Jun 3, 2022
…mode is an ancestor"


Discussed briefly [here](#78677 (comment)), this lets a user set a mode whose inner is already set, if that inner is the current mode or an ancestor of it. This will be necessary for some more complicated use cases of composition, and is nice for non-lexical scoping.

[ghstack-poisoned]
eellison added 4 commits June 7, 2022 09:16
@eellison eellison requested review from mruberry and ngimel as code owners June 8, 2022 01:13
eellison added 2 commits June 7, 2022 21:23
@eellison
Contributor Author

eellison commented Jun 8, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

github-actions bot commented Jun 8, 2022

Hey @eellison.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

ezyang pushed a commit to ezyang/pytorch that referenced this pull request Jun 9, 2022
…hold a reference

ghstack-source-id: c8f448a
Pull Request resolved: pytorch#78677
facebook-github-bot pushed a commit that referenced this pull request Jun 10, 2022
…hold a reference (#78677)

Summary:
Pull Request resolved: #78677

Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/290d0979f1abf9408b1f285660420183fa8809ed

Reviewed By: osalpekar

Differential Revision: D37025723

Pulled By: eellison

fbshipit-source-id: 5ad6683e29bf990b330fee05795d345b8c12f3c2
@facebook-github-bot facebook-github-bot deleted the gh/eellison/305/head branch June 12, 2022 14:20