
Refactor how AOTAutograd backends are defined#89736

Closed
ezyang wants to merge 5 commits into gh/ezyang/1602/base from gh/ezyang/1602/head

Conversation


@ezyang (Contributor) commented Nov 28, 2022

Stack from ghstack (oldest at bottom):

There was a lot of strangeness in how AOTAutograd backends were previously defined. This refactor replaces the strangeness with something simple and straightforward. The improvements:

  • There is no longer a footgun aot_autograd "backend" which doesn't actually work. No more mistyping torch._dynamo.optimize("aot_autograd") when you meant "aot_eager".
  • Deleted aot_print because it's annoying and anyway there are no uses of it.
  • Instead of having BOTH the backend Subgraph and AotAutogradStrategy, there is now only an aot_autograd function which takes the kwargs to configure AOTAutograd and then gives you a compiler function that does AOTAutograd with those kwargs (see the sketch after this list). Easy.
  • The primary downside is that we are now eagerly populating all of the kwargs, which can get us into import cycle shenanigans. Some cycles I resolved directly (e.g., we no longer manually disable the forward function before passing it to aot_autograd; aot_autograd does it for us), but for getting inductor decompositions I had to make it take a lambda so I could lazily populate the decomps later.

New code is 130 lines shorter!
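To make the new shape concrete, here is a minimal sketch of the pattern (illustrative only, not the code in this diff; the placeholder compiler body and the one-liner at the end are stand-ins):

```python
from typing import Callable

def aot_autograd(**kwargs) -> Callable:
    """Sketch of the factory: take AOTAutograd kwargs, return a dynamo-style
    compiler function that applies AOTAutograd with those kwargs."""
    def compiler_fn(gm, example_inputs):
        # decompositions may be supplied as a zero-arg lambda so that heavy
        # imports (e.g. inductor's decomp table) are resolved lazily at
        # compile time, sidestepping import cycles at module-import time.
        if callable(kwargs.get("decompositions")):
            kwargs["decompositions"] = kwargs["decompositions"]()
        # ... hand gm, example_inputs, and kwargs to AOTAutograd and return
        # the compiled callable; this placeholder just runs the graph as-is.
        return gm.forward
    return compiler_fn

# Backends then become one-liners, e.g. (names illustrative):
aot_eager = aot_autograd(fw_compiler=lambda gm, example_inputs: gm.forward)
```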

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire


pytorch-bot bot commented Nov 28, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89736

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit c8c363b:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

fw_metadata = _fw_metadata

@staticmethod
@disable_torchdynamo
Contributor

why were we disabling dynamo earlier and not anymore?

Contributor Author

I moved the disable. Previously, we manually disabled dynamo on the forwards function we return. Now, dynamo is responsible for disabling itself on the returned compiled function. See https://github.com/pytorch/pytorch/pull/89736/files#diff-0b094ea719a9acd16c316b7ec6975f6c9825398952e9b3a67f180d20ac592d47R82

I wasn't really planning to do this change but import cycles were annoying
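Roughly, the new arrangement looks like this (a sketch, assuming `torch._dynamo.disable` as the wrapper; `compile_and_wrap` is a made-up name, not a function from this diff):

```python
import torch._dynamo as dynamo

def compile_and_wrap(compiler_fn, gm, example_inputs):
    # The backend just returns a compiled callable; dynamo then marks that
    # callable as off-limits for re-tracing itself, instead of each backend
    # wrapping its own output in a disable decorator.
    compiled_fn = compiler_fn(gm, example_inputs)
    return dynamo.disable(compiled_fn)
```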


# Skip AOT compilation for tiny graphs (fewer than two calls in the FX graph)
# unless the caller explicitly forces it.
force_compile_tiny_graphs = kwargs.pop("force_compile_tiny_graphs", False)

if count_calls(gm.graph) < 2 and not force_compile_tiny_graphs:
Contributor

as per the comment below for decomps, should we always force compilation?

Contributor Author

Forcing compilation on single op graphs for aot eager truly is pointless. Maybe we should do it anyway for debug purposes? Not sure.

Contributor

hmm yeah. I like the idea to force recompilation in debug mode anyway.

Contributor Author

@Chillee you cool with this idea?

Collaborator

Yeah I'm fine.

Contributor

is it truly pointless? what if the single op decomposes into some ops that can be fused?

Contributor Author

We're gonna remove this conditional!

@albanD (Collaborator) left a comment

Sounds good!



aot_ts = AotTorchscript.compile_fn
DEBUG = False
Collaborator

Should this be in the config?

Contributor Author

Yes, probably. Need to negotiate a name for it.

# `decompositions` is supplied as a zero-arg callable so the decomp table can
# be populated lazily (this is how the inductor import cycle is avoided);
# call it here to materialize the actual table.
kwargs["decompositions"] = kwargs["decompositions"]()

# TODO: stop monkeypatching here (without even cleaning up, UGH!)
functorch.compile.config.use_functionalize = True
Collaborator

Shouldn't we just update the config's default at this point?

Contributor Author

Yes, see #89663

@voznesenskym mentioned this pull request Nov 28, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022

Pull Request resolved: pytorch#89736
Approved by: https://github.com/anjali411, https://github.com/albanD
@facebook-github-bot deleted the gh/ezyang/1602/head branch June 8, 2023 16:35