This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Add Fake Tensor Propagation#426

Merged
eellison merged 21 commits into pytorch:main from eellison:fake_tensor
Jun 27, 2022

Conversation

@eellison
Contributor

Run and propagate the fx graph with FakeTensors, a tensor subclass that defines __torch_dispatch__ and performs device propagation, built on top of meta tensors.

A few notes:

  • FakeTensors/MetaTensors do consistent alias/storage tracking. To support that, there is a cache for converting non-fake tensors to fake tensors, so that the fake tensors share the same alias/tensor id as their real counterparts. Every FakeTensor holds a reference to a FakeTensorMode, which stores the real->fake cache; an op with fake tensor inputs can only run if all of its inputs share the same mode.

  • Meta tensors don't have complete op coverage, but FakeTensors get pretty far by falling back to CPU and running the op there. However, a few major categories of ops don't support meta at all: sparse, complex, and quantized. Complex support is in the works and should land soon; the others are farther away.
    I added a fake_tensor_propagation config option for this case, although in the end state it would be great to get rid of it (cc @ezyang ).

  • Dynamic shape operators will now cause a graph break, because we are not able to infer the output shape with meta inputs. The only case in the code base where I observed this was repeat_interleave. There are various ways forward for supporting this (specializing on the data-dependent input / generating symbolic output shapes), but we can leave that to future dynamic shape work.
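As a PyTorch-free illustration of the data-dependent shape problem (not code from this PR): the output length of a repeat_interleave-style op is the sum of the repeat counts, which requires real values that meta/fake tensors do not carry.

```python
def repeat_interleave_out_len(repeats):
    """Output length of a repeat_interleave-style op.

    The result depends on the *values* in `repeats`, not just its shape,
    so a meta/fake tensor (shape + dtype only, no data) cannot predict it.
    """
    return sum(repeats)

# Two inputs with identical shape but different output lengths:
print(repeat_interleave_out_len([1, 1, 1]))  # 3
print(repeat_interleave_out_len([2, 0, 5]))  # 7
```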

Depends on pytorch/pytorch#79741 and pytorch/pytorch#79809 for tests to pass, but I think those are accepted and landing shortly.
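A minimal, PyTorch-free sketch of the real->fake conversion cache and same-mode check described above (class and method names like `from_real` are illustrative stand-ins, not the actual torchdynamo API):

```python
from types import SimpleNamespace


class FakeTensor:
    """Carries only metadata (shape/dtype/device), never real data."""

    def __init__(self, shape, dtype, device, mode):
        self.shape, self.dtype, self.device, self.mode = shape, dtype, device, mode


class FakeTensorMode:
    def __init__(self):
        # id(real) -> FakeTensor, so two views of one real tensor
        # map to the same fake tensor, preserving aliasing identity.
        self._cache = {}

    def from_real(self, real):
        key = id(real)
        if key not in self._cache:
            self._cache[key] = FakeTensor(real.shape, real.dtype, real.device, self)
        return self._cache[key]


def check_same_mode(*fakes):
    # An op may only run if all fake inputs share one FakeTensorMode.
    if len({f.mode for f in fakes}) != 1:
        raise RuntimeError("fake tensor inputs come from different modes")


# Demo: converting the same real tensor twice yields the *same* fake tensor.
real = SimpleNamespace(shape=(2, 3), dtype="float32", device="cpu")
mode = FakeTensorMode()
assert mode.from_real(real) is mode.from_real(real)
```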

Comment thread torchdynamo/variables/tensor.py Outdated
    try:
        return fn()
    except UnsupportedFakeTensorException as e:
        raise Exception(
Contributor


Perhaps add an exception type to torchdynamo.exc.*
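A hedged sketch of what the suggested torchdynamo.exc-style exception might look like (the names `FakeTensorPropagationError` and `run_with_fake_tensors` are hypothetical, chosen here only to illustrate replacing the bare `Exception` while chaining the cause):

```python
class UnsupportedFakeTensorException(Exception):
    """Stand-in for the exception caught in the snippet above."""


class FakeTensorPropagationError(Exception):
    """Hypothetical torchdynamo.exc-style error for failed fake propagation."""


def run_with_fake_tensors(fn):
    try:
        return fn()
    except UnsupportedFakeTensorException as e:
        # Chain the original cause so the failing op stays visible in tracebacks.
        raise FakeTensorPropagationError(
            f"fake tensor propagation failed: {e}"
        ) from e
```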

@jansel
Contributor

jansel commented Jun 21, 2022

If you want to change the pytorch version used in CI you can change it here:
https://github.com/jansel/torchdynamo/blob/3a0ef2eb9b2d22b44997cb79c5df4eebf6dccd46/.github/workflows/test-py38.yml#L18

@eellison eellison requested a review from jansel June 22, 2022 21:17
@jansel
Contributor

jansel commented Jun 23, 2022

Overall this looks good to me. One question:

Will this break PyTorch 1.12 support? Would it be easy to fall back to non-fake tensors for PyTorch 1.12?

I'm wondering if we should do a branch cut for 1.12 support before merging this.

@eellison
Contributor Author

Will this break PyTorch 1.12 support? Would it be easy to fall back to non-fake tensors for PyTorch 1.12?

It would break as of now, but I can add a check on the PyTorch version and fall back to non-fake tensors if it isn't recent enough. Should I add that?
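A minimal sketch of such a version gate (the function name and the exact cutoff are illustrative; the PR itself would check `torch.__version__` against whichever release first carries the needed changes):

```python
def supports_fake_tensors(torch_version: str) -> bool:
    """Illustrative gate: fall back to non-fake tensors on older PyTorch.

    Assumes fake tensor propagation needs changes that landed after 1.12,
    so 1.12 and earlier take the non-fake fallback path.
    """
    # Only the major/minor components matter; this also tolerates
    # suffixes like "1.13.0a0+gitabcdef" by splitting on dots first.
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    return (major, minor) > (1, 12)


print(supports_fake_tensors("1.12.1"))  # False -> use non-fake fallback
print(supports_fake_tensors("1.13.0"))  # True  -> fake tensors enabled
```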

@jansel
Contributor

jansel commented Jun 23, 2022

Yes, thanks!

@eellison eellison merged commit 10572cf into pytorch:main Jun 27, 2022
    ) from e

    @classmethod
    def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
Contributor


We should factor this out into an FX pass that can live in PyTorch's fx passes proper.



4 participants