Add Caching of Conversion to Fake/Meta tensors in FakeTensorMode #78090
eellison wants to merge 23 commits into gh/eellison/296/base from
Conversation
❌ 2 New Failures as of commit 05ff3ff (more details on the Dr. CI page): 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
Update on "Add Caching of Conversion to Fake/Meta tensors in FakeTensorMode": This PR introduces a `FakeTensorConverter`, similar to [MetaTensorConverter](https://github.com/pytorch/pytorch/blob/master/test/test_meta.py#L77), which caches conversions of Tensors to `FakeTensors` and uses `MetaTensorConverter` under the hood so that newly allocated FakeTensors will share the same storage.
I'm not sure we want to do that. These assertions are important, and the view behavior must be followed. While I agree that they are annoying, user code might get silently wrong results if you do not respect these.
```python
# multiple tensors into fake tensors which share the same view/storage
# structure. Like `MetaConverter`, it will keep alive all
# tensors that are converted to FakeTensors.
class FakeTensorConverter(MetaConverter):
```
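To make the caching idea concrete, here is a minimal, hypothetical sketch in plain Python (not the actual PyTorch implementation; all names are illustrative): the converter memoizes each source object by identity and keeps it alive, so converting the same tensor twice yields the same fake object.

```python
# Hypothetical sketch of conversion caching; not the real FakeTensorConverter.
class CachingConverter:
    def __init__(self):
        # Keyed by id(); storing the source object in the memo keeps it
        # alive, mirroring the "keep alive all tensors" behavior above.
        self._memo = {}

    def convert(self, source):
        key = id(source)
        if key not in self._memo:
            self._memo[key] = (source, self._make_fake(source))
        return self._memo[key][1]

    def _make_fake(self, source):
        # Placeholder for the real to-meta/fake conversion.
        return {"fake_of": id(source)}
```

Because the memo holds a reference to every converted source, `id()` keys cannot be reused while the converter is alive, which is what makes identity-keyed memoization sound here.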
thanks, the code reuse here means a lot to me
Sorry, it just disables them when a torch dispatch mode is set. As is, they prevent a bunch of different things you might want to do in torch dispatch. We could also make the active torch dispatch mode explicitly opt into or out of these assertions.
```cpp
"_storage_id",
[](const at::Tensor& ten) -> int64_t {
  return reinterpret_cast<int64_t>(
      ten.storage().unsafeGetStorageImpl());
```
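The binding above returns the raw `StorageImpl` pointer as an integer, so two tensors viewing the same buffer report the same id. A toy Python model of that contract (class and function names are illustrative stand-ins, not PyTorch API):

```python
class Storage:
    """Stand-in for an underlying storage buffer."""

class View:
    """Stand-in for a tensor that aliases a storage."""
    def __init__(self, storage):
        self.storage = storage

def storage_id(view) -> int:
    # Analogous to reinterpret_cast'ing the StorageImpl pointer: the
    # identity of the live storage object serves as a stable integer id.
    return id(view.storage)

buf = Storage()
a, b = View(buf), View(buf)
assert storage_id(a) == storage_id(b)  # two views of one buffer agree
```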
I'm a bit surprised that you needed the asserts to be disabled for this patch; it feels like part of the point of conversion caching (ugh, I hate this name, it implies that you can drop the caching and the result will be sound, but that's not the case here) is to make sure the storage relationships are set up appropriately so the asserts don't fail.
In this call, if you wanted to do something simple like hold onto every tensor that gets run through the backward and save it with its corresponding op usages, you would fail the assertions. IMO, at the very least, the current dispatch mode should be able to opt out of particular types of assertions. These are only run in debug builds, so doing the extra querying etc. shouldn't matter perf-wise.
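The opt-out being proposed could look something like this sketch (all names hypothetical): the active mode declares which debug-assertion categories it wants skipped, and the check site queries that set before firing.

```python
class Mode:
    # Categories of debug assertions this mode opts out of (hypothetical names).
    skip_debug_asserts = frozenset()

class FakeMode(Mode):
    skip_debug_asserts = frozenset({"storage_use_count", "same_storage"})

def should_run_assert(active_mode, category: str) -> bool:
    # Run the assert unless an active mode explicitly opted out of it.
    if active_mode is None:
        return True
    return category not in active_mode.skip_debug_asserts
```

With no mode set, every assertion still runs, so the default debug behavior is unchanged.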
Isn't a simple fix to create the Tensor inside your context?
This doesn't mimic the existing
Something like
Link me? I don't want to duplicate the conversion to meta storage that already exists, so we should get alignment with @ezyang if this is the approach we want to go down.
Okay, so I limited the scope of this PR to make FakeTensorMode only handle constructors and throw on non-fake inputs. This unblocks the TorchDynamo use case. I'm planning on adding the rest in subsequent PRs; for now, in-place operators that mutate metadata on non-fake inputs are not supported. Even for just caching here, some of the debug variants will break, and we've seen other issues with them for other workstreams (sym ints, functionalization). I still think it makes sense to disable the invariants when a torch_dispatch mode is set. I'm happy to make this a queryable property of the current mode if that's what we want to do. I tried switching to a modern-style mode but ran into issues, which I will document. Please take another look.
```cpp
namespace {

bool torch_dispatch_set() {
  return static_cast<bool>(at::impl::TorchDispatchModeTLS::get_state());
}
```
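For intuition, a Python analogue of the TLS query above (illustrative only, not the real API): a thread-local slot holds the active dispatch mode, and the guard simply checks whether it is set.

```python
import threading

_state = threading.local()

def set_dispatch_mode(mode):
    _state.mode = mode

def torch_dispatch_set() -> bool:
    # Mirrors static_cast<bool>(TorchDispatchModeTLS::get_state()):
    # true iff a mode is installed on the current thread.
    return getattr(_state, "mode", None) is not None
```

Using `threading.local` means each thread sees only its own mode, matching the thread-local nature of the C++ TLS state.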
This would be more general if we also tested for the Python key on the argument tensors as well, since the assert disabling here only works for modes.
Just to confirm we agree on this: these asserts are failing for a valid reason here, right? The view semantics are not properly implemented by this mode (a view of an outside Tensor will not be properly updated if an in-place op happens inside). So this disable is only intended to unblock experimentation while we work on a fix for that?
See also #78519, and @bdhirsh says that he comments out the debug checks when developing.
> The view semantic is not properly implemented by this mode (view of an outside Tensor will not be properly updated if an inplace happens inside)
If you memoize a Tensor's conversion from non-fake to Fake, and also check on any subsequent use of its storage/TensorImpl that the metadata around the original Tensor hasn't changed, the view semantics would be properly implemented.
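That argument can be sketched as follows (a hypothetical toy, not PyTorch code): memoize the conversion, but on every cache hit verify that the source's metadata is unchanged before trusting the cached alias.

```python
class Src:
    """Stand-in for a non-fake tensor with mutable metadata."""
    def __init__(self, shape):
        self.shape = tuple(shape)

class CheckedCache:
    def __init__(self):
        # id(src) -> (src kept alive, metadata snapshot, cached fake)
        self._memo = {}

    def convert(self, src):
        key = id(src)
        if key in self._memo:
            _, meta, fake = self._memo[key]
            if meta != self._metadata(src):
                raise RuntimeError("source metadata changed since conversion")
            return fake
        fake = {"fake_shape": src.shape}
        self._memo[key] = (src, self._metadata(src), fake)
        return fake

    def _metadata(self, src):
        return src.shape
```

The staleness check is what distinguishes this from plain caching: a mutation of the original tensor's metadata is detected instead of silently serving a stale alias.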
As above with the other Tensor subclass issues, I think these checks are overly restrictive when trying to extend behavior. I think we should make this a queryable property on TensorModes/TensorSubclasses specifying which checks to disable.
I definitely agree that they are very restrictive and can be annoying when trying things out. But that doesn't make them wrong, and it's certainly not a good reason to just disable them.
Maybe we want to add an env variable that disables them, to ease local development?
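The env-variable escape hatch suggested here could be as simple as the sketch below (the variable name is hypothetical, not an existing PyTorch flag):

```python
import os

def debug_invariants_enabled(environ=os.environ) -> bool:
    # Asserts stay on by default; setting the (hypothetical) variable to
    # "1" disables them for local development only.
    return environ.get("TORCH_DISABLE_DISPATCH_DEBUG_ASSERTS") != "1"
```

Defaulting to enabled keeps CI and normal debug builds strict while letting a developer opt out per-shell.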
Update on "Add Caching of Conversion to Fake/Meta tensors in FakeTensorMode":

This PR does a few things to allow caching of conversion to Fake/Meta Tensors so that the output FakeTensors share storage and have accurate aliasing relationships (happy to break up PRs as needed, but they are all pretty intertwined):

- Invokes `setup_mode` and `cleanup_mode` in the invocation of `_enable_mode`. This is needed to set up a cache of Fake/Meta Tensor conversions.
- Disables the debug invariant checking in `VariableType` that checks things like storage/tensor ptr counts == 1, or that the input and output share the same storage, when a torch dispatch mode is set. These assertions limit what you can do with `torch_dispatch_mode`, and are only called in DEBUG builds anyway. Maybe we could also have modes opt into or out of these assertions.
- Introduces a `FakeTensorConverter`, similar to [MetaTensorConverter](https://github.com/pytorch/pytorch/blob/master/test/test_meta.py#L77), which caches conversions of Tensors to `FakeTensors` and uses `MetaTensorConverter` under the hood so that newly allocated FakeTensors have the same storage. There is one active `FakeTensorConverter` for the duration of `FakeTensorMode`. Since all newly allocated tensors are on the `meta` device, memory should not significantly increase.

The end result of the PR is that you can do things like:

```python
x = torch.rand([4, 4])
with enable_torch_dispatch_mode(FakeTensorMode):
    # conversion from x to Meta/Fake cached for duration of the `FakeTensorMode` call
    y = x[0]
    z = x[1]
    self.assertEqual(torch._C._storage_id(y), torch._C._storage_id(z))
```
@pytorchbot merge this please

❌ 🤖 pytorchbot command failed: Try

@pytorchbot merge
Hey @eellison. |
Summary: Pull Request resolved: #78090
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/26d273959c197e59d9d3e4246d4c1ad63d690137
Reviewed By: b0noI
Differential Revision: D36897423
fbshipit-source-id: c220ec59bf02c455160bdf058b00ff5d2c667f27
Stack from ghstack (oldest at bottom):
This PR does a few things to allow caching of conversion to Fake/Meta Tensors so that the output FakeTensors share storage and have accurate aliasing relationships (happy to break up PRs as needed, but they are all pretty intertwined):

- Invokes `setup_mode` and `cleanup_mode` in the invocation of `_enable_mode`. This is needed to set up a cache of Fake/Meta Tensor conversions.
- Disables the debug invariant checking in `VariableType` that checks things like storage/tensor ptr counts == 1, or that the input and output share the same storage, when a torch dispatch mode is set. These assertions limit what you can do with `torch_dispatch_mode`, and are only called in DEBUG builds anyway. Maybe we could also have modes opt into or out of these assertions.
- Introduces a `FakeTensorConverter`, similar to MetaTensorConverter, which caches conversions of Tensors to `FakeTensors` and uses `MetaTensorConverter` under the hood so that newly allocated FakeTensors have the same storage. There is one active `FakeTensorConverter` for the duration of `FakeTensorMode`. Since all newly allocated tensors are on the `meta` device, memory should not significantly increase.

The end result of the PR is that you can do things like:

```python
x = torch.rand([4, 4])
with enable_torch_dispatch_mode(FakeTensorMode):
    # conversion from x to Meta/Fake cached for duration of the `FakeTensorMode` call
    y = x[0]
    z = x[1]
    self.assertEqual(torch._C._storage_id(y), torch._C._storage_id(z))
```