Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time [Merger of 89672 and 89773] #90039
Conversation
… aot_autograd compilation to lowering time

After all of the preparatory commits, this is a subset of the changes in #89392 that actually change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Squashed fixup commits: [WIP] Commit to run CI on vaguely sus C++ changes uncovered during fixing upstream prototype · Wip · Random stuff · Fix tests · Fix char · make shape prop fake tensor friendly · Fix · Test fixes · Feedback, test fixes, xla shenanigans · xla madness · xla madness · Feedback · undo · undo · Fix test · rm stupid stuff

[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90039
Note: links to docs will display an error until the docs builds have been completed.
❌ 1 Failure, 2 Pending as of commit fb21875. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 13f3e9b
Pull Request resolved: #90039
test/dynamo/test_dynamic_shapes.py
Outdated
```python
# Cannot call sizes() on tensor with symbolic sizes/strides
)

# torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi
# due to custom subclass (TensorProxy)
@unittest.expectedFailure
```
I have a fix for this in a local patchset; this xfail is OK.
```python
self.assertIsNotNone(r1)
self.assertTrue(same(r1, r2))
self.assertTrue(same(r1, r3))
```
Why don't you just assert same(r2, r3) and call it a day lol
we care about r1 as well
torch/fx/passes/shape_prop.py
Outdated
```python
self.module = self.fake_module

result = super().run_node(n)
self.module = self.real_module
```
If run_node raises an exception this won't reset the module.
There's no need to do a swap like this. The only thing you need to override is a few ops (get_attr, call_module I think?) to fetch from fake module rather than real module.
> If run_node raises an exception this won't reset the module.
good call
> There's no need to do a swap like this. The only thing you need to override is a few ops (`get_attr`, `call_module` I think?) to fetch from fake module rather than real module.
Yeah but we already override this, this feels simpler. No strong opinion tho.
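For illustration, a minimal sketch of the try/finally fix this thread is converging on. The class and attribute names follow the quoted diff, but the interpreter call is stubbed out, so this is a stand-in for the real torch.fx code, not the actual patch:

```python
class FakeShapeProp:
    """Minimal stand-in for the shape-prop interpreter in the quoted diff."""

    def __init__(self, real_module, fake_module):
        self.real_module = real_module
        self.fake_module = fake_module
        self.module = real_module

    def _interpret(self, n):
        # Stand-in for super().run_node(n) in the real torch.fx interpreter.
        return n

    def run_node(self, n):
        # Swap in the fake module only for the duration of this node;
        # try/finally restores the real module even if interpretation
        # raises, which is exactly the bug pointed out above.
        self.module = self.fake_module
        try:
            return self._interpret(n)
        finally:
            self.module = self.real_module
```

The alternative suggested above (overriding only `get_attr`/`call_module` to read from the fake module) avoids the swap entirely, at the cost of a couple more overridden methods.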
```python
# TODO: this is questionable
if isinstance(x, torch._subclasses.FakeTensor):
    # this func fails on fake tensors in __torch_dispatch__
    return x
```
I'll put up with this because I think my pending refactor will resolve the confusion here
torch/_dynamo/utils.py
Outdated
```python
if fake_mode is None:
    fake_mode = flat_input.fake_mode
else:
    assert fake_mode == flat_input.fake_mode

def fwd(*args):
    nonlocal compiled_graph
    model = subgraph.model
```
nit: this will keep subgraph permanently alive, whereas previously it could have been GC'ed after compilation. You should `del subgraph` once compilation is done.
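A sketch of the lifetime fix this nit is asking for: pull what the closure needs out of `subgraph` before defining it, then drop the reference. `make_fwd` and `compile_fn` are hypothetical names for illustration; the point is that the closure captures only `model`, so `subgraph` stays collectable:

```python
def make_fwd(subgraph, compile_fn):
    compiled_graph = None
    model = subgraph.model  # capture only what fwd actually needs

    def fwd(*args):
        nonlocal compiled_graph
        if compiled_graph is None:
            compiled_graph = compile_fn(model, args)
        return compiled_graph(*args)

    # fwd closes over `model`, not `subgraph`, so the subgraph can be
    # GC'ed after compilation; the del makes the intent explicit.
    del subgraph
    return fwd
```

If `fwd` instead did `model = subgraph.model` inside its body (as in the quoted diff), the closure would capture `subgraph` itself and pin it for the lifetime of the compiled function.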
```diff
     self.compiler = compiler

-    def compile_submod(self, submod, args, kwargs):
+    def compile_submod(self, input_mod, args, kwargs):
```
The variable renaming makes it harder for code reviewers
like, afaict, there's literally no change from lines 211 to 250 but I had to carefully audit to make sure there weren't typos lol
This is from your PR?
```python
def run_node(self, n: Node) -> Any:
    with fx_traceback.append_stack_trace(n.stack_trace):
        args, kwargs = self.fetch_args_kwargs_from_env(n)
```
It would be really nice to have a comment here explaining what's going on
```python
if isinstance(arg, torch.Tensor) and not isinstance(
    arg, torch._subclasses.FakeTensor
):
    new_args.append(fake_mode.from_tensor(arg))
```
It would be much better if we could assume all the args are already fakeified (and we just maintain the invariant that all the intermediates are fake tensors). In particular, fakeifying a tensor can trigger the allocation of new symints, but if they're not properly associated as inputs we may not be able to instantiate those variables on subsequent runs.
tl;dr I suspect arg is always a fake tensor here
I'm not sure we can; I'm pretty sure we saw non-fakes in here, but maybe that was in an intermediate state of this PR. I can check again.
You can check this by pushing a separate PR that tightens the invariant and seeing if it fails or not.
Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
@pytorchbot successfully started a revert job. Check the current status here.
@voznesenskym your PR has been successfully reverted.
…d compilation to lowering time [Merger of 89672 and 89773] (#90039)" This reverts commit ef0c7ec. Reverted #90039 on behalf of https://github.com/clee2000 due to broke xla tests https://hud.pytorch.org/pytorch/pytorch/commit/ef0c7ec958439caf44a98fb7b70d920c6c2264b9 https://github.com/pytorch/pytorch/actions/runs/3606308473/jobs/6077646142
I'd rather we had forward-fixed; it's a one-liner to fix. Alas.
The way to correctly land XLA changes is to update the hash in the PR to your branch on xla. Once landed, merge the xla branch into master. The automatic head-updating process will eventually reset xla's commit id to master.
Yes. But I did it in a way that should not need XLA changes.
@pytorchbot merge -f "weird unrelated py3.7 pip install bug in manywheel" |
Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
@pytorchmergebot / @pytorchbot merged the wrong commit >:(
❌ 🤖 pytorchbot command failed: Try …
…ation to lowering time [Merger of 89672 and 89773] (#90039) After all of the preparatory commits, this is a subset of the changes in #89392 that actually change us to propagating fake tensors to backends. Signed-off-by: Edward Z. Yang <ezyangfb.com> This is the merger of Ed's PR #89672, which is a rewrite of an older PR of mine (#89392), with CI Fixes on top of it (#89773) Pull Request resolved: #90039 Approved by: https://github.com/ezyang fix
…ation to lowering time [Merger of 89672 and 89773] (pytorch#90039) After all of the preparatory commits, this is a subset of the changes in pytorch#89392 that actually change us to propagating fake tensors to backends. Signed-off-by: Edward Z. Yang <ezyangfb.com> This is the merger of Ed's PR pytorch#89672, which is a rewrite of an older PR of mine (pytorch#89392), with CI Fixes on top of it (pytorch#89773) Pull Request resolved: pytorch#90039 Approved by: https://github.com/ezyang
…d compilation to lowering time [Merger of 89672 and 89773] (pytorch#90039)" This reverts commit ef0c7ec. Reverted pytorch#90039 on behalf of https://github.com/clee2000 due to broke xla tests https://hud.pytorch.org/pytorch/pytorch/commit/ef0c7ec958439caf44a98fb7b70d920c6c2264b9 https://github.com/pytorch/pytorch/actions/runs/3606308473/jobs/6077646142
Stack from ghstack (oldest at bottom):
After all of the preparatory commits, this is a subset of the
changes in #89392 that actually
change us to propagating fake tensors to backends.
Signed-off-by: Edward Z. Yang <ezyangfb.com>
This is the merger of Ed's PR #89672, which is a rewrite of an older PR of mine (#89392), with CI Fixes on top of it (#89773)
cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire