add the option to disable functionalization in AOTDispatcher #164577
bdhirsh wants to merge 13 commits into gh/bdhirsh/671/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164577
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit efde212 with merge base e787d53.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
return cast(JointTraceFn, inner_fn_with_anomaly)  # deal with 'handle' property
# TODO: only need to skip this when turning off functionalization
# inner_fn_with_anomaly.handle = joint_fn_handle  # type: ignore[attr-defined]
def joint_helper(primals, tangents):
this is a bit silly, but we need a dedicated joint_helper function that we don't functools.wrap, because make_fx tracing has some logic to "get the innermost function" to determine argument names, and we want it to "stop" at this joint function so the arguments are properly named primals and tangents: https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L2290
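The naming behavior in question can be demonstrated in plain Python (the function names here are illustrative, not PR code): `inspect.signature` follows the `__wrapped__` attribute that `functools.wraps` sets, so a wrapped helper reports the inner function's argument names, while an unwrapped `joint_helper(primals, tangents)` keeps its own.

```python
import functools
import inspect

def traced_fn(a, b):
    return a + b

# A wrapper decorated with functools.wraps advertises the inner function:
# inspect.signature follows __wrapped__ and reports the inner argument names.
@functools.wraps(traced_fn)
def wrapped(*args):
    return traced_fn(*args)

# A plain helper (no functools.wraps) keeps its own parameter names, which
# is why a bare joint_helper(primals, tangents) names the traced inputs
# "primals" and "tangents".
def joint_helper(primals, tangents):
    return traced_fn(primals, tangents)

print(list(inspect.signature(wrapped).parameters))       # ['a', 'b']
print(list(inspect.signature(joint_helper).parameters))  # ['primals', 'tangents']
```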
torch/_ops.py (outdated)

if 'CompositeImplicit' in str(k) or 'Autograd' in str(k):
    return fn
I didn't actually need this to get bitwise equivalence for llama3, but we may want to consider a way to stop running all Python-only decomps that run above autograd, in case they have different numerics than the eager impls.
You're worried about the Python dispatcher, is that right? What if we don't enable the Python dispatcher before running AOTAutograd, so this code is all dead?
yep, we can do that too. IIRC though, we will need the Python dispatcher if we want to truly care about dynamic shapes (a lot of the Python-dispatcher-only impls are reimplementations of our C++ meta functions that didn't work with SymInts)
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]
and should_decompose(func, flat_args_kwargs)
turns out that it wasn't only functionalization that was running CIA decomps: proxy tensor runs them too, even if the op has a direct backend registration
We can land this on its own ahead of time
The other change I needed was to tweak the test here: https://github.com/meta-pytorch/autoparallel/pull/176/files#diff-cbfe42e87df7867925a5e76ee407076c30f5ee141bf9a6785833679ef3c08533R59 so that it clones before the sum(), so we actually get contiguous tangents in the compiled backward (and don't need to clone).
FYI, this patch as-is cannot run llama3 simple_fsdp on torchtitan:

WAR for this from Brian
Not for land. The right way for me to land this would be to clean up the "kill functionalization" changes into a separate PR backed by a config in AOTAutograd, and move the other changes into separate PRs.
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization-related changes from the original version:
(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)
(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now; will likely do it in a followup.

cc ezyang EikanWang jgong5 wenzhe-nrv
# Check if we are under AOTAutograd tracing
# Checking that a functional mode is active should always do what we want
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None
added an AC test in test_activation_checkpointing.py for this
"test_aot_autograd_disable_functionalization_exhaustive",
aot_autograd_failures,
)
def test_aot_autograd_default_partition_exhaustive(self, device, dtype, op):
Is it appropriate to delete this test?
my plan in this PR is:
(1) have some tests that test "functionalization on, with min cut partitioner"
(2) have some tests that test "functionalization off, with default partitioner"
In the PR before this one, I wanted to test the default partitioner independently, so I added a test for "functionalization on, with default partitioner". My thought is that this test is unnecessary if we are testing the default partitioner with functionalization off, but happy to add it back
out_ref.sum().backward()
out.sum().backward()
self.assertEqual(inps_ref[0].grad, inps[0].grad)
Are we going to assert that the intermediate mutations didn't get moved?
yeah, you're right: technically, if we accidentally moved the mutation in my backward into the forward, we would still get the same output. I'll make it an expecttest.
dynamic=dynamic,
partition_fn=default_partition,
keep_inference_input_mutations=True
partition_fn=min_cut_rematerialization_partition,
Seems sus? Don't you still need to test use_min_cut?
So right now this PR will:
(1) instantiate a set of tests with functionalization on, and with functionalization off
(2) the "functionalization on" test runs the min cut partitioner, and the "functionalization off" test runs the default partitioner.
Right now the function exposes a use_min_cut flag separately from disable_functionalization, because of my previous PR to test the default partitioner. I think I'm just going to kill that and only add a flag for flipping disable_functionalization (which will implicitly decide which partitioner we test, since I don't think it's necessary to also test the default partitioner with functionalization on).
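As a rough sketch of that plan (illustrative names only, not the actual test harness), the single `disable_functionalization` flag would implicitly pick the partitioner each test variant exercises:

```python
# Illustrative sketch: maps the single disable_functionalization flag to the
# partitioner each instantiated test variant would exercise.
def select_partitioner(disable_functionalization: bool) -> str:
    # functionalization on  -> min-cut partitioner
    # functionalization off -> default partitioner
    if disable_functionalization:
        return "default_partition"
    return "min_cut_rematerialization_partition"

test_matrix = {
    f"disable_functionalization={flag}": select_partitioner(flag)
    for flag in (False, True)
}
print(test_matrix)
```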
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
bw_module.graph.eliminate_dead_code()
bw_module.recompile()
forgot to delete some changes when I was testing :/ fixing
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
# assert_functional_graph(fx_g.graph)
Gotta restore this under an if branch
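A minimal sketch of the gating being requested, with toy stand-ins for `aot_config` and `assert_functional_graph` (the real versions live in AOTDispatcher; this is illustrative only):

```python
# Toy stand-ins so the gating logic is runnable outside AOTDispatcher.
class AOTConfig:
    def __init__(self, disable_functionalization):
        self.disable_functionalization = disable_functionalization

def assert_functional_graph(graph):
    # Toy version: reject any op whose name marks it as in-place/mutating.
    assert all(not op.endswith("_") for op in graph), "found mutating op"

def check_graph(graph, aot_config):
    # Only assert functional-ness when functionalization actually ran.
    if not aot_config.disable_functionalization:
        assert_functional_graph(graph)

check_graph(["add", "mul"], AOTConfig(disable_functionalization=False))   # passes
check_graph(["add_", "mul"], AOTConfig(disable_functionalization=True))   # check skipped
```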
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
    fx_g.graph.eliminate_dead_code()
fwiw it OUGHT to be safe to DCE even when functionalization is off lol
I think it depends on how careful we want to be. If you write a custom op that advertises itself as functional but secretly mutates its inputs, we will be silently wrong in aot_eager vs. PyTorch eager if we always run DCE.
My vote would be that any graph passes you want to run should be opt-in by the user (DCE, and probably also CSE). What do you think?
We gotta have accurate schemas. I wouldn't overly focus on this case.
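A toy illustration of the hazard in this thread (not PR code; the schema flag and ops are made up): a DCE pass that trusts an op's advertised schema will silently drop a "functional" op that secretly mutates its input.

```python
# Toy graph: each node is (op_name, mutates_input, fn). A naive DCE that
# trusts the advertised schema drops any node whose result is unused and
# whose schema claims it is functional.
def run(nodes, buf):
    for _, _, fn in nodes:
        fn(buf)
    return buf

def dce(nodes, used_results=()):
    # Keep a node only if its result is used or its schema admits mutation.
    return [n for n in nodes if n[0] in used_results or n[1]]

honest = ("inc_", True,  lambda b: b.__setitem__(0, b[0] + 1))
# Custom op that *claims* to be functional but secretly mutates its input.
sneaky = ("inc",  False, lambda b: b.__setitem__(0, b[0] + 1))

buf1, buf2 = [0], [0]
run(dce([honest]), buf1)   # kept by DCE: schema admits mutation
run(dce([sneaky]), buf2)   # dropped by DCE: the mutation is silently lost
print(buf1, buf2)          # [1] [0]
```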
return True
elif u.target.is_view:
    tensor_arg_aliases.append(u)
return False
And so with a partitioner that exactly respects the original placement, I do not feel this function would be necessary.
Yeah, this is very reasonable. If you would rather I rewrite the default partitioner completely to do this, I'm happy to. Otherwise, I'll land this PR to get something shippable and we can do a broader refactor.
If we want to rewrite the default partitioner this way, we'll have to think about the AC interaction (if we trace AC, we get AC support "for free" with this strategy, but if we don't trace AC, the default partitioner will still need a graph pass to duplicate recomputation in the backward, similar to what Jeffrey's PR does).
I basically want to do what Jeffrey's PR does, but in this pass lol
ezyang left a comment:
I think it'll be good to merge this stuff so more people can kick the tires!
Oops, there are some new failures after my last round of changes; investigating.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
…#164577) I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization-related changes from the original version:
(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: pytorch#164939)
(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now; will likely do it in a followup.
Pull Request resolved: pytorch#164577
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#165372