Migrate FakeTensors to always call into FakeTensorMode and have them hold a reference #78677
eellison wants to merge 14 commits into gh/eellison/305/base
Conversation
…hold a reference [ghstack-poisoned]
✅ No Failures (0 Pending) as of commit f0b2b83 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
```
kwargs = kwargs if kwargs else {}

# TODO: apply as no_dispatch decorator
with no_dispatch():
```
I don't think you need this anymore; in the mode we automatically disable the current mode, so if you call another function it will only hit the underlying dispatch. In fact, doing it this way is wrong because it makes your mode non-compositional.
Would be great to have a test suite / demo / something to test whether your new tensor subclass is compositional, because it's a bit hard to reason about without actually going through tests/debugging/seeing how things work together. In subclass zoo, for example.
Without the `no_dispatch` I get infinite recursion: the `r = func(*args, **kwargs)` here just calls into `__torch_dispatch__` again.
If we want this to work and be compositional with other modes, we'll need to unwrap args and kwargs. Usually this would just be `arg.elem if isinstance(arg, FakeTensor) else arg`, but it seems like this will probably be more complicated before it can be a wrapper tensor? Not sure if @ezyang has any thoughts on how to do this.
You should only unwrap if the mode matches your mode.
Ah, looks like this should work--so basically we should be saving `elem` and then unwrapping to that `elem`, like Ed said, if the mode matches your mode (currently this should always happen, since we error if there are FakeTensors from more than one mode, but we can be extra careful to set ourselves up for nested FakeTensors if we want).
I got confused by this comment though: "elem does not need to be recorded, because FakeTensor *is a* elem". FakeTensor should only be taking a vanilla tensor as `elem`?
Ughhh, that comment is copy-pasted from here... maybe it's not applicable. Good catch.
Is there a way to unwrap a tensor that inherits from `torch.Tensor`?
Ah I see--that is Tracer/ProxyTensor specific, and kinda quirky.
Yeah, so the idea is that this is basically `FakeTensor(torch.Tensor(...))`, and we want to redispatch on `torch.Tensor` instead of on `FakeTensor` so that we don't get the infinite recursion you're seeing.
So if we save `elem`, we should be able to unwrap by basically doing `arg.elem if isinstance(arg, FakeTensor) and self == arg.fake_mode else arg` for every arg and kwarg (this adds in Ed's check that we only unwrap tensors in the same mode) and then run `func(*new_args, **new_kwargs)`.
There is no `elem`; since FakeTensor inherits, it's not compositional. I'm going to defer making this work to another PR.
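A minimal sketch of the unwrap-by-mode pattern discussed above, using a toy `Wrapper` class with a hypothetical `elem` field (FakeTensor has no `elem` today, which is why this is deferred); the names here are illustrative, not the real API:

```python
class Wrapper:
    """Toy stand-in for a wrapper tensor subclass with an `elem` field."""
    def __init__(self, elem, mode):
        self.elem = elem
        self.fake_mode = mode

def unwrap_for_mode(arg, mode):
    # Only unwrap wrappers that belong to *this* mode, per Ed's note;
    # wrappers from other modes pass through untouched, so the modes
    # stay compositional.
    if isinstance(arg, Wrapper) and arg.fake_mode is mode:
        return arg.elem
    return arg

mode_a, mode_b = object(), object()
mine = Wrapper(3.0, mode_a)
other = Wrapper(4.0, mode_b)
assert unwrap_for_mode(mine, mode_a) == 3.0       # same mode: unwrapped
assert unwrap_for_mode(other, mode_a) is other    # different mode: untouched
```

In a real `__torch_dispatch__`, this predicate would be applied to every arg and kwarg (e.g. via `tree_map`) before calling `func`.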
```
if fake_mode is None:
    fake_mode = arg.fake_mode
else:
    assert fake_mode is arg.fake_mode, "Mixing modes NYI"
```
@samdow we should probably write a utility for doing this.
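A hedged sketch of what such a utility might look like; `find_common_mode` and the duck-typed `fake_mode` attribute check are illustrative, not the real API:

```python
def find_common_mode(args):
    # Scan arguments for fake tensors (anything carrying a fake_mode
    # attribute here) and check that they all agree on a single mode,
    # mirroring the loop in the diff above.
    fake_mode = None
    for arg in args:
        if not hasattr(arg, "fake_mode"):
            continue
        if fake_mode is None:
            fake_mode = arg.fake_mode
        else:
            assert fake_mode is arg.fake_mode, "Mixing modes NYI"
    return fake_mode

class FakeArg:
    """Toy stand-in for a FakeTensor."""
    def __init__(self, mode):
        self.fake_mode = mode

mode = object()
assert find_common_mode([FakeArg(mode), 1, FakeArg(mode)]) is mode
assert find_common_mode([1, 2]) is None  # no fake tensors, no mode
```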
```
else:
    assert fake_mode is arg.fake_mode, "Mixing modes NYI"

with enable_torch_dispatch_mode(fake_mode):
```
@samdow This can probably be a little more compositional if we have some sort of push-many operation which lets you push an arbitrary mode so long as the current mode is an ancestor of the mode you're trying to push.
```
# FakeTensors store the FakeTensorMode which in turn stores a
# FakeTensor, so we need to hold a weak reference to the FakeTensor
# otherwise we would induce a circular reference
self.tensor_memo = weakref.WeakValueDictionary()
```
I actually thought this was going to be a weak key dictionary, but this is clearly what you need to prevent the reference cycle. Too bad there's no WeakKeyValueDictionary, lol.
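A small demo of why a `WeakValueDictionary` breaks the cycle: entries vanish once the fake tensor (the *value*) has no other strong references, so the memo never keeps FakeTensors, and through them the mode, alive. The `Fake`/`Mode` classes below are plain-object stand-ins for `FakeTensor`/`FakeTensorMode`:

```python
import weakref

class FakeTensorConverter:
    def __init__(self):
        # Values are held weakly; a memo entry disappears as soon as
        # the fake tensor it maps to is garbage-collected.
        self.tensor_memo = weakref.WeakValueDictionary()

class Mode:  # stand-in for FakeTensorMode
    def __init__(self):
        self.converter = FakeTensorConverter()

class Fake:  # stand-in for FakeTensor (plain object so weakrefs work)
    def __init__(self, mode):
        self.fake_mode = mode  # Fake -> Mode edge of the would-be cycle

mode = Mode()
fake = Fake(mode)
mode.converter.tensor_memo["some_real_tensor"] = fake
assert len(mode.converter.tensor_memo) == 1
del fake  # drop the only strong reference to the fake tensor
assert len(mode.converter.tensor_memo) == 0  # entry was collected
```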
torch/_subclasses/fake_tensor.py (Outdated)

```
def from_real_tensor(self, fake_mode, t):
    if self._get_memo(t) is not None:
        return self.tensor_memo[t]
```
This looks wrong: shouldn't you return the result of `_get_memo`?
Where's the walrus operator when you need it (doesn't work on Python 3.7).
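Both points can be addressed by binding the memo lookup once and returning that result, with no walrus needed. A toy sketch (the real converter builds an actual FakeTensor; a tuple stands in here):

```python
class Converter:
    def __init__(self):
        self.tensor_memo = {}

    def _get_memo(self, t):
        return self.tensor_memo.get(id(t))

    def from_real_tensor(self, fake_mode, t):
        # Bind the lookup once and return that result directly, rather
        # than indexing tensor_memo a second time; this also avoids the
        # walrus operator, which requires Python 3.8+.
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        fake = ("fake", id(t))  # stand-in for the real conversion
        self.tensor_memo[id(t)] = fake
        return fake

conv = Converter()
t = object()
first = conv.from_real_tensor(None, t)
assert conv.from_real_tensor(None, t) is first  # memo hit returns same object
```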
… have them hold a reference"
Migrates FakeTensors to always invoke FakeTensorMode instead of calling into a shared impl function. Because `FakeTensorMode` stores a `FakeTensorConverter`, which holds a map from real tensors to fake tensors, this also changes the `tensor_memo` to use a `WeakValueDictionary` to avoid a circular reference from `FakeTensorMode` -> `FakeTensor` -> `FakeTensorMode`.
When instantiating `FakeTensors` you now need a corresponding `FakeTensorMode` to own them. If creating new tensors, the idiom is similar to the existing `fake_mode` from `torchdistx`:
```
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
y = torch.rand([4], device="cpu")
```
When converting existing tensors to `FakeTensors`, the idiom is to instantiate a `FakeTensorMode` and then use that to do the conversion, with the new mode associated with the new tensor.
[ghstack-poisoned]
…mode is an ancestor" Discussed briefly [here](#78677 (comment)): this lets a user set a mode whose inner is already set, if that inner is the current mode or an ancestor of it. This will be necessary for some more complicated use cases of composition, and is nice for non-lexical scoping. [ghstack-poisoned]
…tor" Discussed briefly [here](#78677 (comment)) this lets a user set a mode whose inner is already set if that inner is the current mode or its an ancestor. This will be necessary for some more complicated use cases of composition and nice for non-lexical scoping [ghstack-poisoned]
…mode is an ancestor" Discussed briefly [here](#78677 (comment)) this lets a user set a mode whose inner is already set if that inner is the current mode or its an ancestor. This will be necessary for some more complicated use cases of composition and nice for non-lexical scoping [ghstack-poisoned]
…tor" Discussed briefly [here](#78677 (comment)) this lets a user set a mode whose inner is already set if that inner is the current mode or its an ancestor. This will be necessary for some more complicated use cases of composition and nice for non-lexical scoping [ghstack-poisoned]
@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.
Hey @eellison. |
…hold a reference
ghstack-source-id: c8f448a
Pull Request resolved: pytorch#78677
…hold a reference (#78677)
Summary: Pull Request resolved: #78677
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/290d0979f1abf9408b1f285660420183fa8809ed
Reviewed By: osalpekar
Differential Revision: D37025723
Pulled By: eellison
fbshipit-source-id: 5ad6683e29bf990b330fee05795d345b8c12f3c2
Stack from ghstack (oldest at bottom):
Migrates FakeTensors to always invoke FakeTensorMode instead of calling into a shared impl function. Because `FakeTensorMode` stores a `FakeTensorConverter`, which holds a map from real tensors to fake tensors, this also changes the `tensor_memo` to use a `WeakValueDictionary` to avoid a circular reference from `FakeTensorMode` -> `FakeTensor` -> `FakeTensorMode`.

When instantiating `FakeTensors` you now need a corresponding `FakeTensorMode` to own them. If creating new tensors, the idiom is similar to the existing `fake_mode` from `torchdistx`.

When converting existing tensors to `FakeTensors`, the idiom is to instantiate a `FakeTensorMode` and then use that to do the conversion, with the new mode associated with the new tensor.