
preserve node stacktraces from compiled autograd through AOTDispatcher, due to GmWrapper #133574

Closed

bdhirsh wants to merge 1 commit into gh/bdhirsh/605/base from gh/bdhirsh/605/head

Conversation

@bdhirsh (Collaborator) commented Aug 15, 2024

Fixes #133567

New log output from the repro:

(/home/hirsheybar/local/b/pytorch-env) [hirsheybar@devgpu001.lla3 ~/local/b/pytorch (compiled_autograd_stacktraces)]$ TORCH_LOGS="compiled_autograd_verbose,aot" python tmp5.py
INFO: TRACED GRAPH
 ===== Joint graph 0 =====
 /home/hirsheybar/local/b/pytorch/torch/fx/_lazy_graph_module.py class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[4, 4][4, 1]cpu"; tangents_1: "f32[4, 4][4, 1]cpu";

        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         # File: /home/hirsheybar/local/b/pytorch/tmp5.py:6 in f, code: return torch.matmul(x, x)
        mm: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(primals_1, primals_1)
        permute: "f32[4, 4][1, 4]cpu" = torch.ops.aten.permute.default(primals_1, [1, 0])
        mm_1: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(permute, tangents_1);  permute = None
        permute_1: "f32[4, 4][1, 4]cpu" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        mm_2: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(tangents_1, permute_1);  tangents_1 = permute_1 = None

         # File: /home/hirsheybar/local/b/pytorch/tmp5.py:6 in f, code: return torch.matmul(x, x)
        add: "f32[4, 4][4, 1]cpu" = torch.ops.aten.add.Tensor(mm_2, mm_1);  mm_2 = mm_1 = None
        return pytree.tree_unflatten([mm, add], self._out_spec)


INFO: aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=False)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=False, traced_tangents=[FakeTensor(..., size=(4, 4))], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[0], is_train=True, traced_tangent_metas=None, num_symints_saved_for_bw=0, grad_enabled_mutation=None, deterministic=False, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None), inner_meta=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=True, keep_input_mutations=False)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=True, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=False, traced_tangents=[FakeTensor(..., size=(4, 4))], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[0], is_train=True, traced_tangent_metas=None, num_symints_saved_for_bw=0, grad_enabled_mutation=None, deterministic=False, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None)
INFO: TRACED GRAPH
 ===== Forward graph 0 =====
 /home/hirsheybar/local/b/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cpu"):
         # File: /home/hirsheybar/local/b/pytorch/tmp5.py:6 in f, code: return torch.matmul(x, x)
        mm: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(primals_1, primals_1)
        permute: "f32[4, 4][1, 4]cpu" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        return (mm, permute)


INFO: TRACED GRAPH
 ===== Backward graph 0 =====
 <eval_with_key>.1 class GraphModule(torch.nn.Module):
    def forward(self, permute: "f32[4, 4][1, 4]cpu", tangents_1: "f32[4, 4][4, 1]cpu"):
         # File: /home/hirsheybar/local/b/pytorch/tmp5.py:6 in f, code: return torch.matmul(x, x)
        mm_1: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(permute, tangents_1)
        mm_2: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(tangents_1, permute);  tangents_1 = permute = None

         # File: /home/hirsheybar/local/b/pytorch/tmp5.py:6 in f, code: return torch.matmul(x, x)
        add: "f32[4, 4][4, 1]cpu" = torch.ops.aten.add.Tensor(mm_2, mm_1);  mm_2 = mm_1 = None
        return (add,)


DEBUG: Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG: TRACED GRAPH
 ===== Compiled autograd graph =====
 <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        # No stacktrace found for following nodes
        getitem: "f32[]cpu" = inputs[0]
        getitem_1: "f32[4, 4]cpu" = inputs[1]
        getitem_2: "f32[4, 4]cpu" = inputs[2];  inputs = None

         # File: /home/hirsheybar/local/b/pytorch/torch/_dynamo/compiled_autograd.py:379 in set_node_origin, code: SumBackward0 (NodeCall 1)
        expand: "f32[4, 4]cpu" = torch.ops.aten.expand.default(getitem, [4, 4]);  getitem = None

         # File: /home/hirsheybar/local/b/pytorch/torch/_dynamo/compiled_autograd.py:379 in set_node_origin, code: CompiledFunctionBackward (NodeCall 2)
        clone: "f32[4, 4]cpu" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        mm: "f32[4, 4]cpu" = torch.ops.aten.mm.default(getitem_1, clone)
        mm_1: "f32[4, 4]cpu" = torch.ops.aten.mm.default(clone, getitem_1);  clone = getitem_1 = None
        add: "f32[4, 4]cpu" = torch.ops.aten.add.Tensor(mm_1, mm);  mm_1 = mm = None

         # File: /home/hirsheybar/local/b/pytorch/torch/_dynamo/compiled_autograd.py:379 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, add);  getitem_2 = add = accumulate_grad_ = None
        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub();  _exec_final_callbacks_stub = None
        return []


INFO: TRACED GRAPH
 ===== Forward graph 1 =====
 /home/hirsheybar/local/b/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[][]cpu", arg1_1: "f32[4, 4][1, 4]cpu", arg2_1: "f32[4, 4][4, 1]cpu"):
         # File: /home/hirsheybar/local/b/pytorch/torch/_dynamo/compiled_autograd.py:379 in set_node_origin, code: SumBackward0 (NodeCall 1)
        expand: "f32[4, 4][0, 0]cpu" = torch.ops.aten.expand.default(arg0_1, [4, 4]);  arg0_1 = None

         # File: /home/hirsheybar/local/b/pytorch/torch/_dynamo/compiled_autograd.py:379 in set_node_origin, code: CompiledFunctionBackward (NodeCall 2)
        clone: "f32[4, 4][4, 1]cpu" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        mm: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(arg1_1, clone)
        mm_1: "f32[4, 4][4, 1]cpu" = torch.ops.aten.mm.default(clone, arg1_1);  clone = arg1_1 = None
        add: "f32[4, 4][4, 1]cpu" = torch.ops.aten.add.Tensor(mm_1, mm);  mm_1 = mm = None

         # File: /home/hirsheybar/local/b/pytorch/torch/_dynamo/polyfill.py:44 in accumulate_grad, code: new_grad = torch.clone(new_grad)
        clone_1: "f32[4, 4][4, 1]cpu" = torch.ops.aten.clone.default(add);  add = None
        return (clone_1,)

The problem was that AOTAutograd expects its input to be a GraphModule in order to do all of the fancy stacktrace-preservation logic, but we now need to handle compiled autograd passing in a GmWrapper instead (which it uses to try to preserve input boxing, so that inductor can properly free activations).
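The fix amounts to looking through the wrapper before deciding whether stacktrace preservation applies. A minimal pure-Python sketch of the idea, where `GraphModuleLike`, `GmWrapperLike`, and `get_inner_graph_module` are illustrative stand-ins, not the real `torch.fx.GraphModule` / `torch._dynamo.utils.GmWrapper` types:

```python
# Stand-in classes sketching the GraphModule / GmWrapper relationship.
# All names here are hypothetical, for illustration only.
class GraphModuleLike:
    def __init__(self, node_stack_traces):
        # maps node name -> original user stack trace
        self.node_stack_traces = node_stack_traces

class GmWrapperLike:
    """Wraps a graph module to change its calling convention (boxed inputs)."""
    def __init__(self, gm):
        self.gm = gm

def get_inner_graph_module(mod):
    # Look through the wrapper so stacktrace-preservation logic
    # sees the underlying graph module either way.
    if isinstance(mod, GmWrapperLike):
        return mod.gm
    if isinstance(mod, GraphModuleLike):
        return mod
    return None  # non-graph frontends: nothing to preserve

inner = get_inner_graph_module(
    GmWrapperLike(GraphModuleLike({"mm": "tmp5.py:6 in f"}))
)
assert inner.node_stack_traces["mm"] == "tmp5.py:6 in f"
```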

Stack from ghstack (oldest at bottom):

@pytorch-bot (bot) commented Aug 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133574

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures, 1 Unrelated Failure

As of commit 50d5811 with merge base 454713f:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bdhirsh (Collaborator, Author) commented Aug 15, 2024

CI is unhappy - from a quick look, I'm failing to ensure that the other args to the compiled backward (like hooks) are properly accounted for in the graph

    # to ensure args are boxed.
    assert params_len == 0
    assert len(kwargs) == 0
    out = PropagateUnbackedSymInts(mod_).run(args)
Member:

there's some logic in GmWrapper.forward that we'll need here:

pytorch/torch/_dynamo/utils.py, lines 2907 to 2909 at 90d2593:

    def forward(self, *args):
        args: List[Any] = list(args)
        return self.gm(*self.unflatten_fn(args))

Contributor:

You should probably just have a "middleware" wrapper that uniformly takes care of unwrapping GmWrapper and modifying the calling convention, should be cleaner.

@xmfan is there a reason we HAVE to have a GmWrapper? Shouldn't custom GraphModule prelude/postlude be enough here?

Member:

What is the prelude/postlude? We use GmWrapper to work around the dynamo GraphModule needing boxed inputs, while AOTDispatcher always traces the GraphModule with flat inputs.
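The boxed-vs-flat mismatch can be illustrated without torch: a boxed callee receives one mutable list it can drain (so the caller's references are dropped and activations can be freed eagerly), while a flat tracer wants positional args. A tiny adapter bridges the two; `boxed_backward` and `flat_to_boxed` are illustrative names, not the real APIs:

```python
def boxed_backward(inputs):
    # Boxed convention: one list argument; the callee pops entries so
    # the caller no longer holds references and memory can be freed early.
    a = inputs.pop()
    b = inputs.pop()
    return a + b

def flat_to_boxed(fn):
    # Adapter: expose a flat (positional) signature over a boxed callee,
    # analogous to how GmWrapper re-boxes the flat args that
    # AOTDispatcher traces with.
    def wrapper(*args):
        return fn(list(args))
    return wrapper

flat_fn = flat_to_boxed(boxed_backward)
print(flat_fn(1, 2))  # 3
```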

Contributor:

Code generation for fx.Graph can be overridden via the _codegen field. For example, this is used to generate a GraphModule that can take an arbitrary pytree as an argument and manages the flattening/unflattening in its body. You could potentially use a similar mechanism to implement GmWrapper. cc @suo @Chillee
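The suggestion is that, instead of a separate wrapper object, the generated forward itself could adapt the calling convention. A rough pure-Python analogue of what such a codegen override buys you (the real mechanism is fx.Graph's _codegen; `make_pytree_forward` and the toy unflattener below are stand-ins):

```python
def make_pytree_forward(inner_fn, unflatten_fn):
    # The generated forward accepts the flat argument list and
    # unflattens inside its own body, so no external wrapper is needed.
    def forward(*flat_args):
        structured = unflatten_fn(list(flat_args))
        return inner_fn(structured)
    return forward

# Toy "pytree" spec: first element is a scalar, the rest is a sub-list.
unflatten = lambda flat: (flat[0], flat[1:])
inner = lambda tree: tree[0] * sum(tree[1])

fwd = make_pytree_forward(inner, unflatten)
print(fwd(2, 3, 4))  # 2 * (3 + 4) = 14
```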

Member:

I can look into it, but we probably still need GmWrapper for non-dynamo frontends that are passing in non-overridden graphs.

    # https://github.com/pytorch/pytorch/issues/103569

    def functional_call(*args, **kwargs):
        nonlocal mod
Contributor:

Why nonlocal? Are you assigning over mod?

        mod_, pytree.tree_unflatten(args[:params_len], params_spec)
    ), maybe_disable_thunkify():
-       if isinstance(mod, torch.fx.GraphModule):
+       if isinstance(mod, (torch.fx.GraphModule, torch._dynamo.utils.GmWrapper)):
Contributor:

Why not do the test here on mod_?

@albanD albanD removed their request for review August 21, 2024 21:37
@yf225 (Contributor) commented Sep 10, 2024

@bdhirsh I believe this would be super useful for compiled autograd debugging in general!

@github-actions (bot):

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Nov 10, 2024
@github-actions github-actions bot closed this Dec 10, 2024
@github-actions github-actions bot deleted the gh/bdhirsh/605/head branch January 9, 2025 02:20
