Add Fake Tensor Propagation #426
Conversation
try:
    return fn()
except UnsupportedFakeTensorException as e:
    raise Exception(
Perhaps add an exception type to torchdynamo.exc.*
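A minimal sketch of what the suggested exception type could look like; the class names, the `wrap_fake_exception` helper shown here, and the import path for `UnsupportedFakeTensorException` (which follows current PyTorch) are illustrative assumptions, not necessarily what this PR landed.

```python
# torchdynamo/exc.py (illustrative names)
class TorchDynamoException(RuntimeError):
    """Base class for exceptions raised by torchdynamo."""


class FakeTensorError(TorchDynamoException):
    """Raised when fake tensor propagation hits an unsupported op or tensor type."""


# Call site, replacing the bare `raise Exception(...)` from the diff above.
from torch._subclasses.fake_tensor import UnsupportedFakeTensorException


def wrap_fake_exception(fn):
    try:
        return fn()
    except UnsupportedFakeTensorException as e:
        raise FakeTensorError(
            f"Unsupported fake tensor propagation: {e}. "
            "Consider running with config.fake_tensor_propagation=False"
        ) from e
```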
If you want to change the pytorch version used in CI you can change it here:
Overall this looks good to me. One question: Will this break PyTorch 1.12 support? Would it be easy to fall back to non-fake tensors for PyTorch 1.12? I'm wondering if we should do a branch cut for 1.12 support before merging this. |
It would break as of now, but I can add a check on the PyTorch version and fall back to non-fake tensors if it isn't recent enough. Should I add that? |
Yes, thanks! |
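A minimal sketch of the version gate agreed on above, assuming the module-level `fake_tensor_propagation` flag described in this PR; the 1.13 cutoff and the location of the check are assumptions for illustration, since the required FakeTensor support is not in the 1.12 release.

```python
# torchdynamo/config.py (sketch): only enable fake tensor propagation on
# PyTorch builds new enough to ship the required FakeTensor support.
import torch


def _torch_at_least(major: int, minor: int) -> bool:
    # torch.__version__ looks like "1.12.1" or "1.13.0a0+git<sha>"
    parts = torch.__version__.split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


# Fall back to non-fake tensors on PyTorch 1.12 and older.
fake_tensor_propagation = _torch_at_least(1, 13)
```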
) from e
@classmethod
def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
We should factor this out into an FX pass that can live in PyTorch fx passes proper
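A rough sketch of what such a standalone pass could look like, built on torch.fx.Interpreter and the torch._subclasses FakeTensorMode API; the names `FakeTensorProp` and `propagate_fake` are illustrative, not this PR's code, and the API details follow current PyTorch rather than the exact version this PR targeted.

```python
# Sketch of a reusable FX pass for fake tensor propagation. It runs a
# GraphModule with fake tensor inputs so every node's output shape, dtype,
# and device is recorded without doing any real compute.
import torch
import torch.fx
from torch._subclasses.fake_tensor import FakeTensorMode


class FakeTensorProp(torch.fx.Interpreter):
    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)
        # Stash the fake result so later passes/backends can read shapes off node.meta.
        n.meta["fake_result"] = result
        return result


def propagate_fake(gm: torch.fx.GraphModule, *real_inputs: torch.Tensor):
    mode = FakeTensorMode()
    # All inputs must be converted through the same mode so aliasing and
    # device tracking stay consistent (see the PR description below).
    fake_inputs = [mode.from_tensor(t) for t in real_inputs]
    with mode:
        return FakeTensorProp(gm).run(*fake_inputs)
```

Usage would be something like `propagate_fake(traced_module, example_input)`, after which each node's `meta` carries a fake tensor with the inferred shape, dtype, and device.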
Run and propagate the fx graph with FakeTensors, which are a tensor subclass with `__torch_dispatch__` defined that does device propagation built on top of meta tensors. A few notes:

- FakeTensors/MetaTensors do consistent alias/storage tracking. For that, there is a cache for converting non-fake tensors to fake tensors, so that the fake tensors will have the same alias/tensor id. Every FakeTensor contains a corresponding FakeTensorMode which stores the cache from real -> fake tensors, and an op with fake tensor inputs can only be run if the inputs all have the same mode (a small illustration follows these notes).
- Meta tensors don't have complete op coverage, but we get pretty far with FakeTensors by falling back to CPU and running the op there. However, there are a few major categories of ops that don't support meta: sparse, complex, and quantized. Complex should be fairly easy to add soon / is in the works, but the others are farther away. I added a `fake_tensor_propagation` config option for this case, although in the end state it would be great to get rid of it (cc @ezyang).
- Dynamic-shape operators will now cause a graph break, because we are not able to infer the output shape with meta inputs. The only case in the code base where I observed this was `repeat_interleave`. There are various ways forward for supporting this (specializing on the data-dependent input / generating symbolic output shapes), but we can leave that to future dynamic shape work.

Depends on pytorch/pytorch#79741 and pytorch/pytorch#79809 for tests to pass, but I think those are landing shortly / are already accepted.
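To make the first note concrete, here is a small illustration using the current torch._subclasses.FakeTensorMode API; details (in particular the `from_tensor` entry point and the identity caching behavior) may differ slightly from the version this PR was written against.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()
real = torch.randn(4, 8, device="cpu")

# The converter caches real -> fake, so converting the same real tensor twice
# yields the same fake tensor and alias/storage relationships are preserved.
fake_a = mode.from_tensor(real)
fake_b = mode.from_tensor(real)
assert fake_a is fake_b

# Ops on fake tensors propagate shape/dtype/device without touching real data;
# factory calls made under the mode also produce fake tensors.
with mode:
    weight = torch.randn(8, 2)
    out = fake_a @ weight
assert out.shape == (4, 2)
assert out.device == real.device
```

The `repeat_interleave` graph break mentioned in the last note comes from the same setup: with meta/fake inputs the values of the `repeats` tensor are unavailable, so the output length cannot be inferred.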
repeat_interleave. There are various ways forward of supporting this (specializing on the data-dependent input / generating symbolic output shapes) but we can leave that to future dynamic shape work.Depends on pytorch/pytorch#79741, pytorch/pytorch#79809 for tests to pass but I think those are landing shortly/accepted.