
Add Caching of Conversion to Fake/Meta tensors in FakeTensorMode#78090

Closed
eellison wants to merge 23 commits into gh/eellison/296/base from gh/eellison/296/head

Conversation

Contributor

@eellison eellison commented May 23, 2022

Stack from ghstack (oldest at bottom):

This PR does a few things to allow caching of conversion to Fake/Meta tensors so that the output FakeTensors share storage and have accurate aliasing relationships (happy to break this up into separate PRs as needed, but they are all pretty intertwined):

  • Invokes `setup_mode` and `cleanup_mode` in the invocation of `_enable_mode`. This is needed to set up a cache of Fake/Meta tensor conversions.

  • Disables the debug invariant checking in `VariableType` that checks things like storage/tensor pointer use counts == 1, or that the input and output share the same storage, when a torch dispatch mode is set. These assertions limit what you can do with `torch_dispatch_mode`, and are only called in DEBUG builds anyway. Maybe we could also have modes opt into/out of these assertions.

  • Introduces a `FakeTensorConverter`, similar to `MetaTensorConverter`, which caches conversions of Tensors to FakeTensors and uses `MetaTensorConverter` under the hood so that newly allocated FakeTensors will have the same storage. There is one active `FakeTensorConverter` for the duration of `FakeTensorMode`. Since all newly allocated tensors will be on the `meta` device, memory should not significantly increase.

The end result of the PR is that you can do things like:

```python
x = torch.rand([4, 4])

with enable_torch_dispatch_mode(FakeTensorMode):
    # conversion from x to Meta/Fake cached for duration of the `FakeTensorMode` call
    y = x[0]
    z = x[1]

self.assertEqual(torch._C._storage_id(y), torch._C._storage_id(z))
```
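A minimal sketch of the conversion cache in plain Python (all names below are illustrative stand-ins, not the real PyTorch classes): converting two views of the same real tensor yields fake tensors that share one fake storage.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeStorage:
    """Illustrative stand-in for a shared meta storage."""
    sid: int

@dataclass
class FakeTensor:
    """Illustrative stand-in for a fake tensor wrapping a FakeStorage."""
    storage: FakeStorage

class ConverterSketch:
    """Memoizes real-storage -> fake-storage, so conversions of tensors
    that alias each other produce FakeTensors sharing one FakeStorage."""

    def __init__(self):
        self._memo = {}      # real storage id -> FakeStorage
        self._next_sid = 0

    def convert(self, real_storage_id):
        if real_storage_id not in self._memo:
            self._memo[real_storage_id] = FakeStorage(self._next_sid)
            self._next_sid += 1
        return FakeTensor(self._memo[real_storage_id])

conv = ConverterSketch()
# x[0] and x[1] are views of the same x, so they carry the same real storage id
y = conv.convert(real_storage_id=0)
z = conv.convert(real_storage_id=0)
assert y.storage is z.storage       # aliasing relationship preserved
```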

@facebook-github-bot
Contributor

facebook-github-bot commented May 23, 2022


❌ 2 New Failures

As of commit 05ff3ff (more details on the Dr. CI page):

  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge) (1/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-06-02T20:02:25.7108037Z RuntimeError: /var...sor_impl.cpp:163 : XLA tensors do not have storage
2022-06-02T20:02:25.7098559Z ----------------------------------------------------------------------
2022-06-02T20:02:25.7099055Z Traceback (most recent call last):
2022-06-02T20:02:25.7099794Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 390, in instantiated_test
2022-06-02T20:02:25.7100301Z     raise rte
2022-06-02T20:02:25.7100925Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
2022-06-02T20:02:25.7101785Z     result = test(self, **param_kwargs)
2022-06-02T20:02:25.7102286Z   File "/var/lib/jenkins/workspace/xla/test/../../test/test_nn.py", line 16495, in test_embedding_max_norm_fwd_AD
2022-06-02T20:02:25.7102971Z     dual_weight = torch.autograd.forward_ad.make_dual(weight, tangent)
2022-06-02T20:02:25.7106901Z   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/forward_ad.py", line 79, in make_dual
2022-06-02T20:02:25.7107467Z     return torch._VF._make_dual(tensor, tangent, level=level)
2022-06-02T20:02:25.7108037Z RuntimeError: /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.cpp:163 : XLA tensors do not have storage
2022-06-02T20:02:25.7108429Z 
2022-06-02T20:02:25.8581943Z ----------------------------------------------------------------------
2022-06-02T20:02:25.8582240Z Ran 912 tests in 721.466s
2022-06-02T20:02:25.8582346Z 
2022-06-02T20:02:25.8582477Z FAILED (errors=1, skipped=708, expected failures=3)
2022-06-02T20:02:25.8582624Z 
2022-06-02T20:02:25.8582710Z Generating XML reports...
2022-06-02T20:02:25.8583161Z Generated XML report: test-reports/python-unittest/test.......test.test_nn/TEST-TestNNDeviceTypeXLA-20220602195024.xml
2022-06-02T20:02:26.3746414Z + cleanup
2022-06-02T20:02:26.3746765Z + retcode=1

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu) (2/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-06-02T19:39:40.5955023Z RuntimeError: test_ops failed! Received signal: SIGIOT
2022-06-02T19:39:37.9782740Z   test_backward_baddbmm_cuda_float32 (__main__.TestCompositeComplianceCUDA) ... ok (0.126s)
2022-06-02T19:39:37.9884555Z   test_backward_bernoulli_cuda_float32 (__main__.TestCompositeComplianceCUDA) ... ok (0.010s)
2022-06-02T19:39:37.9989819Z   test_backward_bfloat16_cuda_float32 (__main__.TestCompositeComplianceCUDA) ... ok (0.010s)
2022-06-02T19:39:38.0727232Z   test_backward_block_diag_cuda_float32 (__main__.TestCompositeComplianceCUDA) ... ok (0.073s)
2022-06-02T19:39:38.0814963Z   test_backward_bmm_cuda_float32 (__main__.TestCompositeComplianceCUDA) ... ok (0.009s)
2022-06-02T19:39:40.5947722Z   test_backward_broadcast_tensors_cuda_float32 (__main__.TestCompositeComplianceCUDA) ... Traceback (most recent call last):
2022-06-02T19:39:40.5948238Z   File "test/run_test.py", line 1077, in <module>
2022-06-02T19:39:40.5951550Z     main()
2022-06-02T19:39:40.5952636Z   File "test/run_test.py", line 1055, in main
2022-06-02T19:39:40.5954566Z     raise RuntimeError(err_message)
2022-06-02T19:39:40.5955023Z RuntimeError: test_ops failed! Received signal: SIGIOT
2022-06-02T19:39:41.7139089Z + cleanup
2022-06-02T19:39:41.7139379Z + retcode=1
2022-06-02T19:39:41.7142543Z + set +x
2022-06-02T19:39:41.7185676Z ##[error]Process completed with exit code 1.
2022-06-02T19:39:41.7238001Z ##[group]Run pytorch/pytorch/.github/actions/get-workflow-job-id@master
2022-06-02T19:39:41.7238358Z with:
2022-06-02T19:39:41.7238908Z   github-token: ***
2022-06-02T19:39:41.7239157Z env:
2022-06-02T19:39:41.7239375Z   IN_CI: 1
2022-06-02T19:39:41.7239588Z   IS_GHA: 1

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot facebook-github-bot added the `oncall: jit` label (Add this issue/PR to JIT oncall triage queue) May 23, 2022
@eellison eellison mentioned this pull request May 23, 2022
eellison pushed a commit that referenced this pull request May 23, 2022
…orMode"


This PR introduces a `FakeTensorConverter` similar to [MetaTensorConverter](https://github.com/pytorch/pytorch/blob/master/test/test_meta.py#L77) which caches conversions of Tensors to `FakeTensors` and uses `MetaTensorConverter` under the hood so that newly allocated FakeTensors will have the same storage. 



[ghstack-poisoned]
eellison pushed a commit that referenced this pull request May 23, 2022
@eellison eellison requested a review from ezyang May 23, 2022 16:54
@albanD
Collaborator

albanD commented May 23, 2022

Disables the debug invariant checking in VariableType that checks things like storage/tensor pointer use counts == 1, or that the input and output share the same storage. These assertions limit what you can do with torch_dispatch_mode, and are only called in DEBUG builds anyway. Maybe we could also have modes opt into/out of these assertions.

I'm not sure we want to do that? These assertions are important and the view behavior must be followed. While I agree that they are annoying, user code might get silently wrong if you do not respect these.

```python
# multiple tensors into fake tensors which share the same view/storage
# structure. Like `MetaConverter`, it will keep alive all
# tensors that are converted to FakeTensors.
class FakeTensorConverter(MetaConverter):
```
Contributor


thanks, the code reuse here means a lot to me

@eellison
Contributor Author

eellison commented May 23, 2022

@albanD

I'm not sure we want to do that? These assertions are important and the view behavior must be followed. While I agree that they are annoying, user code might get silently wrong if you do not respect these.

Sorry, it just disables them when a torch dispatch mode is set. As is, they prevent a bunch of different things you might want to do in torch dispatch. We could also make the active torch dispatch mode explicitly opt into/out of these assertions.

```cpp
"_storage_id",
[](const at::Tensor& ten) -> int64_t {
  return reinterpret_cast<int64_t>(
      ten.storage().unsafeGetStorageImpl());
```
Contributor


This won't be necessary once #78008 lands

@ezyang
Contributor

ezyang commented May 23, 2022

I'm a bit surprised that you needed the asserts to be disabled for this patch; it feels like part of the point of conversion caching (ugh, I hate this name; it implies that you can drop the caching and the result will still be sound, but that's not the case here) is to make sure the storage relationships are set up appropriately so the asserts don't fail.

@eellison
Contributor Author

eellison commented May 23, 2022

I'm a bit surprised

In this call

```python
x = torch.rand([4, 4])

with enable_torch_dispatch_mode(FakeTensorMode):
    # conversion from x to Meta/Fake cached for duration of the `FakeTensorMode` call
    y = x[0]
```

`y` won't share a storage with `x`, because `y` will have a meta storage and `x` won't. So it will fail `ENFORCE_SAME_TENSOR_STORAGE`.

If you wanted to do something simple like hold onto every tensor that gets run through the backward and save it with its corresponding op usages, you would fail the `AT_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}");` call.

IMO, at the very least, the current dispatch mode should be able to opt out of particular types of assertions. These are only run during debug builds, so doing the extra querying etc. shouldn't matter perf-wise.

@albanD
Collaborator

albanD commented May 23, 2022

Isn't a simple fix to create the Tensor inside your context?
We might consider an option to disable them, but that should definitely not be the default. In your example, any inplace op will be just wrong.
Note that if the idea is that you are re-implementing the view semantics but don't use storage for that (which trips this test), then I think we are happy to add new API to fix that. Most likely you want to do something like functionalization and have a fake storage that you re-use to show that your Tensors share data. That fake storage is also a good place for you to add the logic that reproduces the view behavior, as it will be shared by all the Tensors that are supposed to be views of each other.
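A toy model of that suggestion (hypothetical names, not a real PyTorch API): views share a single fake-storage object, which is also where the "inplace writes are visible through all aliases" bookkeeping would live.

```python
class SharedFakeStorage:
    """Hypothetical object shared by all tensors that alias each other;
    a version counter stands in for 'writes are visible through views'."""
    def __init__(self):
        self.version = 0

class ViewTensor:
    def __init__(self, storage):
        self.storage = storage          # shared with every alias

    def add_(self, _value):
        self.storage.version += 1       # mutation recorded on the shared storage

base = SharedFakeStorage()
a = ViewTensor(base)
b = ViewTensor(base)                    # b models a view of a
a.add_(1)
assert b.storage.version == 1           # the view observes the inplace write
```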

@eellison
Contributor Author

eellison commented May 23, 2022

Isn't a simple fix to create the Tensor inside your context?

This doesn't mimic the existing FakeTensor mode usage, and I think would be cumbersome.

In your example, any inplace op will be just wrong

Something like `add_` only changes the values of the tensors, which as far as fake tensors go is irrelevant, so I wouldn't consider it to be wrong. For `resize_` etc., we would probably want to throw if the inputs are not FakeTensors already. I think this would be an unlikely case, as you're unlikely to `resize_` your module parameters or inputs.
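The policy described here could be sketched like this (hypothetical helper names; a toy dispatcher, not the actual FakeTensorMode code): value-only inplace ops pass through, while metadata-mutating inplace ops on non-fake inputs raise.

```python
# Ops assumed here to mutate metadata (illustrative, not an exhaustive list)
METADATA_MUTATING_OPS = {"resize_", "set_"}

def fake_mode_dispatch(op_name, inputs_are_fake):
    """Toy policy: throw on metadata-mutating inplace ops with non-fake
    inputs; value-only inplace ops like add_ are harmless to fakes."""
    if op_name in METADATA_MUTATING_OPS and not inputs_are_fake:
        raise RuntimeError(
            f"{op_name} on a non-fake input is not supported in fake mode")
    return "dispatched"

assert fake_mode_dispatch("add_", inputs_are_fake=False) == "dispatched"
try:
    fake_mode_dispatch("resize_", inputs_are_fake=False)
    raise AssertionError("expected a RuntimeError")
except RuntimeError:
    pass
```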

Most likely you want to do something like functionalization and have a fake storage that you re-use to show that your Tensors share data

Link me? I don't want to duplicate the conversion to meta storage that already exists, so we should get alignment with @ezyang if this is the approach we want to go down.

@eellison eellison mentioned this pull request May 25, 2022
@eellison
Contributor Author

Okay, so I limited the scope of this PR to make FakeTensorMode only handle constructors and throw on non-fake inputs. This unblocks the TorchDynamo use case. However, I'm planning on adding that in subsequent PRs, without supporting inplace operators that mutate metadata on non-fake inputs.

Even for just caching here, some of the debug invariants will break, and we've seen other issues with them for other workstreams (sym ints, functionalization). I still think it makes sense to disable the invariants when a torch_dispatch mode is set. I'm happy to make this a queryable property of the current mode if that's what we want to do.

I tried switching to a modern-style mode but ran into issues, which I will document. Please take another look.
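The "queryable property" idea could be as simple as the following sketch (hypothetical attribute and class names, not an existing API): the debug path asks the active mode whether to run the invariant asserts.

```python
class TorchDispatchModeSketch:
    """Toy mode carrying an attribute the debug checks could query."""
    disable_debug_invariant_checks = False

class FakeModeSketch(TorchDispatchModeSketch):
    # A mode like FakeTensorMode could opt out explicitly.
    disable_debug_invariant_checks = True

def should_run_debug_asserts(active_mode):
    # Debug-only path: skip the invariant asserts when the active mode opts out
    return active_mode is None or not active_mode.disable_debug_invariant_checks

assert should_run_debug_asserts(None)                       # no mode: keep asserts
assert should_run_debug_asserts(TorchDispatchModeSketch())  # default: keep asserts
assert not should_run_debug_asserts(FakeModeSketch())       # opted out
```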

@eellison eellison requested a review from ezyang May 25, 2022 17:05
```cpp
namespace {

bool torch_dispatch_set() {
  return static_cast<bool>(at::impl::TorchDispatchModeTLS::get_state());
```
Contributor


This would be more general if we also tested for the Python key on the argument tensors, as the assert disabling here only works for modes.

Collaborator


Just to confirm we agree on this: these asserts are failing for a valid reason here, right? The view semantics are not properly implemented by this mode (a view of an outside Tensor will not be properly updated if an inplace op happens inside). So this disable is only intended to unblock experimentation while we work on a fix for that?

Contributor Author


See also #78519, and @bdhirsh says that he comments out the debug checks when developing.

The view semantic is not properly implemented by this mode (view of an outside Tensor will not be properly updated if an inplace happens inside)

If you memoize a Tensor's conversion from non-fake to Fake, and also check on any subsequent use of its storage/TensorImpl that the metadata of the original Tensor hasn't changed, the view semantics would be properly implemented.

As above with the other Tensor subclass issues, I think these checks are overly restrictive when trying to extend behavior. I think we should make this a queryable property on TensorModes/TensorSubclasses that specifies which checks to disable.
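That memoize-and-validate scheme could look roughly like this (all names hypothetical): reuse the cached conversion only while the original tensor's metadata is unchanged.

```python
memo = {}   # tensor key -> (fake tensor, shape, stride) recorded at conversion

def convert(key, shape, stride):
    """Return the memoized fake tensor unless the original tensor's
    metadata changed since it was cached, in which case re-convert."""
    hit = memo.get(key)
    if hit is not None and hit[1:] == (shape, stride):
        return hit[0]
    fake = object()                     # placeholder for a fresh fake tensor
    memo[key] = (fake, shape, stride)
    return fake

f1 = convert("t", (4, 4), (4, 1))
f2 = convert("t", (4, 4), (4, 1))
assert f1 is f2                          # metadata unchanged: cache hit
f3 = convert("t", (2, 8), (8, 1))        # original was resized: cache miss
assert f3 is not f1
```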

Collaborator


I definitely agree that they are very restrictive and can be annoying when trying things out. But that doesn't make them wrong, and it's for sure not a good reason to just disable them.
Maybe we want to add an env variable to disable them to ease local development?

@eellison eellison mentioned this pull request May 31, 2022
Elias Ellison added 2 commits May 31, 2022 09:16
eellison added 3 commits June 2, 2022 07:16
@eellison
Contributor Author

eellison commented Jun 3, 2022

@pytorchbot merge this please

@pytorch-bot

pytorch-bot bot commented Jun 3, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: this please

usage: @pytorchbot {merge,revert,rebase,help} ...

Try @pytorchbot help for more info.

@eellison
Contributor Author

eellison commented Jun 3, 2022

@pytorchbot merge

@github-actions
Contributor

github-actions bot commented Jun 3, 2022

Hey @eellison.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 3, 2022

Summary:
Pull Request resolved: #78090

Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/26d273959c197e59d9d3e4246d4c1ad63d690137

Reviewed By: b0noI

Differential Revision: D36897423

fbshipit-source-id: c220ec59bf02c455160bdf058b00ff5d2c667f27
@facebook-github-bot facebook-github-bot deleted the gh/eellison/296/head branch June 6, 2022 14:17

Labels

cla signed · Merged · `oncall: jit` (Add this issue/PR to JIT oncall triage queue)


5 participants