
Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time [Merger of 89672 and 89773]#90039

Closed
voznesenskym wants to merge 7 commits intogh/voznesenskym/22/basefrom
gh/voznesenskym/22/head

Conversation

@voznesenskym
Collaborator

@voznesenskym voznesenskym commented Dec 2, 2022

Stack from ghstack (oldest at bottom):

After all of the preparatory commits, this is a subset of the
changes in #89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

This is the merger of Ed's PR #89672, which is a rewrite of an older PR of mine (#89392), with CI Fixes on top of it (#89773)

cc @mlazos @soumith @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

… aot_autograd compilation to lowering time

After all of the preparatory commits, this is a subset of the
changes in #89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>


[WIP] Commit to run CI on vaguely sus C++ changes uncovered during fixing upstream prototype


Wip

Random stuff

Fix tests

Fix char

make shape prop fake tensor friendly

Fix

Test fixes

Feedback, test fixes, xla shenanigans

xla madness

xla madness

Feedback

undo

undo

Fix test

rm stupid stuff

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Dec 2, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90039

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures, 2 Pending

As of commit fb21875:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: fx label Dec 2, 2022
voznesenskym added a commit that referenced this pull request Dec 2, 2022
… aot_autograd compilation to lowering time


ghstack-source-id: 13f3e9b
Pull Request resolved: #90039
@voznesenskym voznesenskym changed the title [UPDATED PROTOTYPE] Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time [Merger of 89672 and 89773] Dec 2, 2022
@voznesenskym voznesenskym mentioned this pull request Dec 2, 2022
@voznesenskym voznesenskym added the ciflow/trunk and topic: not user facing labels Dec 2, 2022
# Cannot call sizes() on tensor with symbolic sizes/strides
)


Contributor

WHITESPAAAACE

Collaborator Author

k


# torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi
# due to custom subclass (TensorProxy)
@unittest.expectedFailure
Contributor

I have a fix for this in a local patchset, this xfail AOK

Collaborator Author

k


self.assertIsNotNone(r1)
self.assertTrue(same(r1, r2))
self.assertTrue(same(r1, r3))
Contributor

Why don't you just assert same(r2, r3) and call it a day lol

Collaborator Author

we care about r1 as well
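A side note on the exchange above: tolerance-based closeness checks (like a `same` helper built on `torch.allclose`-style comparison) are not transitive, so asserting `same(r1, r2)` and `same(r1, r3)` is genuinely not equivalent to asserting `same(r2, r3)`. A minimal sketch with a stand-in `same`:

```python
def same(a, b, tol=1.0):
    """Stand-in for a tolerance-based comparison like torch.allclose."""
    return abs(a - b) <= tol

r1, r2, r3 = 0.0, 0.9, -0.9
assert same(r1, r2) and same(r1, r3)  # both pass against r1...
assert not same(r2, r3)               # ...yet r2 and r3 are not "same"
```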

self.module = self.fake_module

result = super().run_node(n)
self.module = self.real_module
Contributor

If run_node raises an exception this won't reset the module.

Contributor

There's no need to do a swap like this. The only thing you need to override is a few ops (get_attr, call_module I think?) to fetch from fake module rather than real module.

Collaborator Author

If run_node raises an exception this won't reset the module.

good call

There's no need to do a swap like this. The only thing you need to override is a few ops (get_attr, call_module I think?) to fetch from fake module rather than real module.

Yeah but we already override this, this feels simpler. No strong opinion tho.
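For reference, the exception-safety fix discussed above is a `try`/`finally` around the swapped call. This is an illustrative sketch with stand-in names, not the actual `torch.fx.Interpreter` code:

```python
class FakeInterpreter:
    """Illustrative sketch (not the torch.fx API): swap in the fake
    module around run_node, restoring the real one even on error."""

    def __init__(self, real_module, fake_module):
        self.real_module = real_module
        self.fake_module = fake_module
        self.module = real_module

    def _run_node(self, n):
        # stand-in for super().run_node(n)
        if n == "boom":
            raise RuntimeError("node failed")
        return f"ran {n} on {self.module}"

    def run_node(self, n):
        self.module = self.fake_module
        try:
            return self._run_node(n)
        finally:
            # the fix requested above: reset even if run_node raises
            self.module = self.real_module
```

With this shape, a failing node no longer leaves the interpreter pointing at the fake module.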

# TODO: this is questionable
if isinstance(x, torch._subclasses.FakeTensor):
# this func fails on fake tensors in __torch_dispatch__
return x
Contributor

I'll put up with this because I think my pending refactor will resolve the confusion here

if fake_mode is None:
fake_mode = flat_input.fake_mode
else:
assert fake_mode == flat_input.fake_mode
Contributor

is, not equality, plz
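The distinction the reviewer is asking for: `==` tests value equality (which a mode class could define loosely), while `is` tests object identity, which is the actual invariant here (all flat inputs must carry the *same* fake mode object). A contrived stand-in class makes the difference visible:

```python
class Mode:
    """Contrived mode whose __eq__ treats all instances as equal."""
    def __eq__(self, other):
        return isinstance(other, Mode)
    def __hash__(self):
        return 0

a, b = Mode(), Mode()
assert a == b       # equality holds even for two distinct objects...
assert a is not b   # ...but identity correctly tells them apart
```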


def fwd(*args):
nonlocal compiled_graph
model = subgraph.model
Contributor

nit: this will keep subgraph permanently alive, whereas previously it could have been GC'ed after compilation. You should del subgraph once you're done with compilation.

Collaborator Author

good idea
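A sketch of the lifetime issue behind this nit, using stand-in names: if `fwd` closes over `subgraph`, the closure cell pins the whole object for as long as the compiled function lives; copying out just what `fwd` needs (or `del subgraph` after compilation) lets it be collected.

```python
import weakref

class Subgraph:
    """Stand-in for the compiled subgraph object."""
    def __init__(self):
        self.model = "compiled-model"

def compile_leaky(subgraph):
    def fwd(*args):
        # closing over `subgraph` pins the whole object for fwd's lifetime
        return subgraph.model
    return fwd

def compile_tight(subgraph):
    model = subgraph.model  # copy out only what fwd needs
    def fwd(*args):
        return model
    return fwd

sg = Subgraph()
alive = weakref.ref(sg)
leaky = compile_leaky(sg)
del sg
assert alive() is not None  # still pinned by the closure cell

sg = Subgraph()
alive = weakref.ref(sg)
tight = compile_tight(sg)
del sg
assert alive() is None      # collected as soon as compilation finished
assert tight() == "compiled-model"
```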

self.compiler = compiler

def compile_submod(self, submod, args, kwargs):
def compile_submod(self, input_mod, args, kwargs):
Contributor

The variable renaming makes it harder for code reviewers

Contributor

like, afaict, there's literally no change from lines 211 to 250 but I had to carefully audit to make sure there weren't typos lol

Collaborator Author

This is from your PR?


def run_node(self, n: Node) -> Any:
with fx_traceback.append_stack_trace(n.stack_trace):
args, kwargs = self.fetch_args_kwargs_from_env(n)
Contributor

It would be really nice to have a comment here explaining what's going on

Collaborator Author

Sure thing

if isinstance(arg, torch.Tensor) and not isinstance(
arg, torch._subclasses.FakeTensor
):
new_args.append(fake_mode.from_tensor(arg))
Contributor

It would be much better if we could assume all the args are already fakeified (and we just maintain the invariant that all the intermediates are fake tensors). In particular, fakeifying a tensor can trigger the allocation of new symints, but if they're not properly associated as inputs we may not be able to instantiate those variables on subsequent runs.

tl;dr I suspect arg is always a fake tensor here

Collaborator Author

I'm not sure we can; I'm pretty sure we saw non-fakes in here, but maybe that was in an intermediate state of this PR. I can check again.

Contributor

you can check this by pushing a separate PR that tightens the invariant and see if it fails or not
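The two policies being debated can be sketched side by side; `Tensor`, `FakeTensor`, and `FakeMode` below are stand-in stubs, not the real `torch._subclasses` API. The defensive version converts stray real tensors (possibly minting fresh symints), while the tightened version asserts the invariant that inputs are already fake:

```python
# Stand-in stubs for torch.Tensor / FakeTensor / FakeTensorMode -- this
# only sketches the control flow being discussed, not the real API.
class Tensor:
    pass

class FakeTensor(Tensor):
    pass

class FakeMode:
    def from_tensor(self, t):
        # In the real system this may allocate fresh symints, which is
        # exactly why converting lazily here is suspect.
        return FakeTensor()

def fakeify_defensive(args, fake_mode):
    """Current behavior: silently convert stray real tensors."""
    new_args = []
    for arg in args:
        if isinstance(arg, Tensor) and not isinstance(arg, FakeTensor):
            arg = fake_mode.from_tensor(arg)
        new_args.append(arg)
    return new_args

def fakeify_strict(args):
    """Tightened invariant: every tensor input must already be fake."""
    for arg in args:
        assert not isinstance(arg, Tensor) or isinstance(arg, FakeTensor), (
            "expected all tensor inputs to already be fake"
        )
    return list(args)
```

Pushing the strict version through CI, as suggested, would settle whether real tensors ever reach this point.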

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@clee2000
Contributor

clee2000 commented Dec 4, 2022

@pytorchbot revert -m "broke xla tests https://hud.pytorch.org/pytorch/pytorch/commit/ef0c7ec958439caf44a98fb7b70d920c6c2264b9 https://github.com/pytorch/pytorch/actions/runs/3606308473/jobs/6077646142" -c landrace

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@voznesenskym your PR has been successfully reverted.

@voznesenskym
Collaborator Author

@pytorchbot revert -m "broke xla tests https://hud.pytorch.org/pytorch/pytorch/commit/ef0c7ec958439caf44a98fb7b70d920c6c2264b9 https://github.com/pytorch/pytorch/actions/runs/3606308473/jobs/6077646142" -c landrace

I'd rather we had forward-fixed; it's a one-liner to fix. Alas.

@voznesenskym voznesenskym reopened this Dec 4, 2022
@github-actions github-actions bot requested a review from ezyang December 4, 2022 22:34
@ezyang
Contributor

ezyang commented Dec 5, 2022

the way to correctly land xla changes is to update the hash in the PR to your branch on xla. Once landed, merge the xla branch into master. The automatic head updating process will eventually reset xla's commit id to master.

@voznesenskym
Collaborator Author

the way to correctly land xla changes is to update the hash in the PR to your branch on xla. Once landed, merge the xla branch into master. The automatic head updating process will eventually reset xla's commit id to master.

Yes. But I did it in a way that should not need XLA changes.

@voznesenskym
Collaborator Author

@pytorchbot merge -f "weird unrelated py3.7 pip install bug in manywheel"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@voznesenskym
Collaborator Author

@pytorchmergebot / @pytorchbot merged the wrong commit >:(

@pytorch-bot

pytorch-bot bot commented Dec 5, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: '/' (choose from 'merge', 'revert', 'rebase', 'label', 'drci')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

voznesenskym added a commit that referenced this pull request Dec 5, 2022
…ation to lowering time [Merger of 89672 and 89773] (#90039)

After all of the preparatory commits, this is a subset of the
changes in #89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

This is the merger of Ed's PR #89672, which is a rewrite of an older PR of mine (#89392), with CI Fixes on top of it (#89773)

Pull Request resolved: #90039
Approved by: https://github.com/ezyang

fix
voznesenskym added a commit that referenced this pull request Dec 6, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…ation to lowering time [Merger of 89672 and 89773] (pytorch#90039)

After all of the preparatory commits, this is a subset of the
changes in pytorch#89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

This is the merger of Ed's PR pytorch#89672, which is a rewrite of an older PR of mine (pytorch#89392), with CI Fixes on top of it (pytorch#89773)

Pull Request resolved: pytorch#90039
Approved by: https://github.com/ezyang
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
@facebook-github-bot facebook-github-bot deleted the gh/voznesenskym/22/head branch June 8, 2023 19:01