add the option to disable functionalization in AOTDispatcher#164577

Closed
bdhirsh wants to merge 13 commits into gh/bdhirsh/671/base from gh/bdhirsh/671/head

Conversation

bdhirsh (Collaborator) commented Oct 3, 2025

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup

Stack from ghstack (oldest at bottom):

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

pytorch-bot bot commented Oct 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164577

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit efde212 with merge base e787d53:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Oct 3, 2025
… aot_eager on llama3

ghstack-source-id: 2f88fd4
Pull Request resolved: #164577
pytorch-bot added the ciflow/inductor and release notes: fx labels Oct 3, 2025
return cast(JointTraceFn, inner_fn_with_anomaly) # deal with 'handle' property
# TODO: only need to skip this when turning off functionalization
# inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
def joint_helper(primals, tangents):
Collaborator Author:

this is a bit silly, but we need a dedicated joint_helper function that we don't functools.wrap, because make_fx tracing has some logic to "get the inner most function" to determine argument names, and we want it to "stop" at this joint function to properly name the arguments as primals and tangents: https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L2290
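The behavior described here can be illustrated outside of PyTorch: `functools.wraps` sets `__wrapped__`, and `inspect.signature` follows it to the innermost function, so a plainly defined (un-wrapped) helper is what keeps the `primals`/`tangents` argument names visible to a tracer that inspects them. A minimal sketch (the function names here are illustrative, not the actual AOTDispatcher code):

```python
import functools
import inspect

def traced_fn(a, b):
    return a + b

# functools.wraps sets __wrapped__, and inspect.signature() follows
# __wrapped__ by default, so a tracer that inspects argument names
# sees the *inner* function's names, not the wrapper's.
@functools.wraps(traced_fn)
def wrapped(*args, **kwargs):
    return traced_fn(*args, **kwargs)

# A plainly defined helper (no wraps) is where inspection "stops",
# so the tracer would name the arguments primals/tangents.
def joint_helper(primals, tangents):
    return traced_fn(primals, tangents)

print(inspect.signature(wrapped))       # (a, b)
print(inspect.signature(joint_helper))  # (primals, tangents)
```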

Contributor:

Put it in a comment

torch/_ops.py Outdated
)

if 'CompositeImplicit' in str(k) or 'Autograd' in str(k):
return fn
Collaborator Author:

I didn't actually need this to get bitwise equivalence for llama3, but we may want to consider a way to stop running all python-only decomps that run above autograd, in case they have different numerics than the eager impls.

Contributor:

You're worrying about Python dispatcher, is that right? What if we don't enable Python dispatcher before running AOTAutograd so this code is all dead?

Collaborator Author:

yep, we can do that too. IIRC though, we will need the python dispatcher if we want to truly care about dynamic shapes (a lot of the python-dispatcher-only impls are reimplementations of our meta functions from C++ that didn't work with SymInts)

torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]
and should_decompose(func, flat_args_kwargs)
Collaborator Author:

turns out that it wasn't only functionalization that was running CIA decomps - proxy tensor runs them too, even if the op has a direct backend registration

Contributor:

We can land this on its own ahead of time


bdhirsh commented Oct 3, 2025

The other change I needed was to tweak the test here: https://github.com/meta-pytorch/autoparallel/pull/176/files#diff-cbfe42e87df7867925a5e76ee407076c30f5ee141bf9a6785833679ef3c08533R59 so that it clones before the sum(), so we actually get contiguous tangents in the compiled backward (and don't need to clone)


ezyang commented Oct 5, 2025

FYI this patch as is cannot run llama3 simple_fsdp on torchtitan:

  traceback : Traceback (most recent call last):
    File "/data/users/ezyang/b/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
      return f(*args, **kwargs)
    File "/data/users/ezyang/b/torchtitan/torchtitan/train.py", line 596, in train
      self.train_step(data_iterator)
    File "/data/users/ezyang/b/torchtitan/torchtitan/train.py", line 496, in train_step
      loss = self.forward_backward_step(input_dict, labels)
    File "/data/users/ezyang/b/torchtitan/torchtitan/train.py", line 472, in forward_backward_step
      pred = model_parts[0](inputs, **extra_inputs)
    File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
      return super().__call__(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
      return fn(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/ezyang/b/torchtitan/torchtitan/models/llama3/model/model.py", line 394, in forward
      def forward(
    File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 1098, in _fn
      return fn(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/_functorch/aot_autograd.py", line 1134, in forward
      return compiled_fn(full_args)
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 341, in runtime_wrapper
      all_outs = call_func_at_runtime_with_args(
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/utils.py", line 130, in call_func_at_runtime_with_args
      out = normalize_as_list(f(args))
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/utils.py", line 104, in g
      return f(*args)
    File "/data/users/ezyang/b/pytorch/torch/autograd/function.py", line 581, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2148, in forward
      fw_outs = call_func_at_runtime_with_args(
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/utils.py", line 130, in call_func_at_runtime_with_args
      out = normalize_as_list(f(args))
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 529, in wrapper
      return compiled_fn(runtime_args)
    File "/data/users/ezyang/b/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 727, in inner_fn
      outs = compiled_fn(args)
    File "/data/users/ezyang/b/pytorch/torch/_dynamo/backends/debugging.py", line 166, in run
      return forward_fn(args)
    File "/data/users/ezyang/b/pytorch/torch/fx/_lazy_graph_module.py", line 126, in _lazy_forward
      return self(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/fx/graph_module.py", line 843, in call_wrapped
      return self._wrapped_call(self, *args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/fx/graph_module.py", line 414, in __call__
      raise e
    File "/data/users/ezyang/b/pytorch/torch/fx/graph_module.py", line 401, in __call__
      return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
    File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
      return forward_call(*args, **kwargs)
    File "<eval_with_key>.6", line 182, in forward
      graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten._scaled_dot_product_flash_attention.default, transpose_4, transpose_5, transpose_6, 0.0, True, scale = 0.25, rng_state = fwd_rng_state_0);  transpose_4 = transpose_5 = transpose_6 = fwd_rng_state_0 = None
    File "/data/users/ezyang/b/pytorch/torch/_prims/rng_prims.py", line 322, in __call__
      return super().__call__(op, *args, rng_state=rng_state, **kwargs)
    File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 538, in __call__
      return wrapper()
    File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 534, in wrapper
      return self.dispatch(
    File "/data/users/ezyang/b/pytorch/torch/_ops.py", line 507, in dispatch
      raise NotImplementedError(
  NotImplementedError: could not find kernel for HigherOrderOperator graphsafe_run_with_rng_state at dispatch key DispatchKey.AutogradCUDA (resolved from DispatchKey.AutogradCUDA)

============================================================
(b) [ezyang@devvm006.dkl0 ~/local/b/torchtitan] pp NGPU=1 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --compile.enable --compile.backend aot_eager --training.steps 10 --model.name simple_fsdp.llama3


ezyang commented Oct 6, 2025

Workaround for this from Brian:

diff --git a/torch/_ops.py b/torch/_ops.py
index fe8dd4fee62..c67eb35870b 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -148,8 +148,9 @@ class OperatorBase:
                 "Please register a mode for the DispatchKey.Python key instead."
             )

-            if 'CompositeImplicit' in str(k) or 'Autograd' in str(k):
-                return fn
+            if k == DispatchKey.CompositeImplicitAutograd or k == DispatchKey.Autograd:
+                if torch._C._dispatch_has_kernel(self.name()) and torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k):
+                    return fn
             if k in self.py_kernels:
                 raise RuntimeError(
                     f"Trying to override a python impl for {k} on operator {self.name()}"
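The reason the original substring check misfires can be seen in isolation: every per-backend autograd key (e.g. AutogradCUDA) contains the substring "Autograd", so the early `return fn` also skipped registering python kernels for those keys, which is consistent with the missing-kernel error at DispatchKey.AutogradCUDA in the traceback above. A toy illustration, with plain strings standing in for DispatchKey values:

```python
# Plain strings standing in for DispatchKey values (toy illustration,
# not the real torch DispatchKey enum).
keys = ["CompositeImplicitAutograd", "Autograd", "AutogradCUDA", "AutogradXLA"]

# The original substring check: matches every key above, including the
# per-backend autograd keys, so those registrations were skipped too.
substring_match = [k for k in keys
                   if "CompositeImplicit" in k or "Autograd" in k]

# The workaround's exact comparison: only the two intended keys match,
# so AutogradCUDA registrations go through again.
exact_match = [k for k in keys
               if k in ("CompositeImplicitAutograd", "Autograd")]

print(substring_match)  # all four keys
print(exact_match)      # ['CompositeImplicitAutograd', 'Autograd']
```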

ezyang pushed a commit that referenced this pull request Oct 7, 2025
… aot_eager on llama3

ghstack-source-id: 2f88fd4
Pull Request resolved: #164577
…alence with aot_eager on llama3"

Not for land. The right way for me to land this would be to clean up the "kill functionalization" changes into a separate PR backed by a config in AOTAutograd, and move the other changes into separate PRs




bdhirsh added a commit that referenced this pull request Oct 8, 2025
… aot_eager on llama3

ghstack-source-id: d8b9744
Pull Request resolved: #164577
bdhirsh changed the title from "temp hacks to remove functionalization + get bitwise equivalence with aot_eager on llama3" to "add the option to disable functionalization in AOTDispatcher" Oct 8, 2025
bdhirsh added a commit that referenced this pull request Oct 8, 2025
… aot_eager on llama3

ghstack-source-id: f6340ca
Pull Request resolved: #164577
bdhirsh requested a review from Chillee as a code owner October 13, 2025 16:15
bdhirsh added a commit that referenced this pull request Oct 13, 2025
… aot_eager on llama3

ghstack-source-id: affe109
Pull Request resolved: #164577

fix
# Check if we are under AOTAutograd tracing
# Checking that a functional mode is active should always do what we want
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None
Collaborator Author:

added an AC test in test_activation_checkpointing.py for this

"test_aot_autograd_disable_functionalization_exhaustive",
aot_autograd_failures,
)
def test_aot_autograd_default_partition_exhaustive(self, device, dtype, op):
Contributor:

Is it appropriate to delete this test?

Collaborator Author:

my plan in this PR is:

(1) have some tests that test "functionalization on, with min cut partitioner"

(2) have some tests that test "functionalization off, with default partitioner"

In the PR before this one, I wanted to test the default partitioner independently, so I added a test for "functionalization on, with default partitioner". My thought is that this test is unnecessary if we are testing the default partitioner with functionalization off, but happy to add it back


out_ref.sum().backward()
out.sum().backward()
self.assertEqual(inps_ref[0].grad, inps[0].grad)
Contributor:

Are we going to assert that the intermediate mutations didn't get moved?

Collaborator Author:

Yeah, you're right: technically if we accidentally moved the mutation in my backward into the forward, we would still get the same output. I'll make it an expecttest.

dynamic=dynamic,
partition_fn=default_partition,
keep_inference_input_mutations=True
partition_fn=min_cut_rematerialization_partition,
Contributor:

Seems sus? Don't you still need to test use_min_cut?

Collaborator Author:

So right now this PR will:

(1) instantiate a set of tests with functionalization on, and with functionalization off

(2) the "functionalization on" test runs the min cut partitioner, and the "functionalization off" test runs the default partitioner.

Right now the function exposes a use_min_cut flag separately from disable_functionalization because of my previous PR to test the default partitioner. I think I'm just going to kill that, and only add a flag for flipping disable_functionalization (which will implicitly decide what partitioner we test, since I don't think it's necessary to also test the default partitioner with functionalization on).

fw_module.graph.eliminate_dead_code()
fw_module.recompile()
bw_module.graph.eliminate_dead_code()
bw_module.recompile()
Contributor:

What's going on here?

Collaborator Author:

forgot to delete some changes when I was testing :/ fixing


# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
# /assert_functional_graph(fx_g.graph)
Contributor:

Gotta restore this under if branch

torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
fx_g.graph.eliminate_dead_code()
Contributor:

fwiw it OUGHT to be safe to DCE even when functionalization is off lol

Collaborator Author:

I think it depends on how careful we want to be. If you write a custom op that advertises as functional but secretly mutates inputs, we will be silently wrong in aot_eager vs pytorch eager if we always run DCE.

My vote would be that any graph passes that you want to run with should be opt-in by the user (DCE and also probably CSE). what do you think?
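The hazard with always running DCE can be sketched with a toy model (this is not the real torch.fx implementation): a DCE pass trusts each node's declared schema, so an op that advertises itself as functional but secretly mutates its input gets dropped, and eager vs. compiled results silently diverge.

```python
# Toy model of schema-trusting dead-code elimination (hypothetical,
# not real torch.fx).
class Node:
    def __init__(self, name, fn, mutates_per_schema, result_used):
        self.name = name
        self.fn = fn
        self.mutates_per_schema = mutates_per_schema  # what the schema claims
        self.result_used = result_used

def dce(nodes):
    # Trust the schema: keep a node only if its result is used or the
    # schema says it mutates. A mis-declared mutating op is dropped.
    return [n for n in nodes if n.result_used or n.mutates_per_schema]

buf = [0]
nodes = [
    # Advertises as functional, but secretly mutates buf in place.
    Node("sneaky_inc", lambda: buf.__setitem__(0, buf[0] + 1),
         mutates_per_schema=False, result_used=False),
    Node("mul2", lambda: buf[0] * 2,
         mutates_per_schema=False, result_used=True),
]

for n in nodes:        # eager: the mutation runs, so mul2 sees buf[0] == 1
    last = n.fn()
eager_out = last       # 2

buf[0] = 0
for n in dce(nodes):   # after DCE: sneaky_inc is dropped, mul2 sees 0
    last = n.fn()
dce_out = last         # 0

print(eager_out, dce_out)  # 2 0
```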

Contributor:

We gotta have accurate schema. I wouldn't overly focus on this case

return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
Contributor:

And so with a partitioner that exactly respects the original placement, I do not feel this function would be necessary.

Collaborator Author:

Yeah this is very reasonable. If you would rather I rewrite the default partitioner completely to do this I'm happy to. Otherwise, I'll land this PR to get something shippable and we can do a broader refactor.

If we want to re-write the default partitioner this way, we'll have to think about the AC interaction (if we trace AC we get AC support "for free" if we do this strategy, but if we don't trace AC then the default partitioner will still need a graph pass to duplicate recompute in the backward, similar to what Jeffrey's PR does)

Contributor:

I basically want to do what Jeffrey's PR does, but in this pass lol

ezyang (Contributor) left a comment:

I think it'll be good to merge this stuff so more people can kick the tires!


bdhirsh commented Oct 15, 2025

Oops there are some new failures after my last round of changes, investigating

bdhirsh added a commit that referenced this pull request Oct 15, 2025
bdhirsh added a commit that referenced this pull request Oct 15, 2025
bdhirsh added a commit that referenced this pull request Oct 15, 2025
bdhirsh added a commit that referenced this pull request Oct 16, 2025

bdhirsh commented Oct 16, 2025

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Oct 16, 2025
pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…#164577)

Pull Request resolved: pytorch#164577
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#165372
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
…#164577)

Pull Request resolved: pytorch#164577
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#165372
github-actions bot deleted the gh/bdhirsh/671/head branch November 16, 2025 02:21
Labels: ciflow/inductor, ciflow/trunk, fx, Merged, module: dynamo, release notes: fx