
[Dynamo][autograd.Function] Use fake tensor prop to infer fwd output#136184

Closed
yanboliang wants to merge 13 commits into gh/yanboliang/39/base from
gh/yanboliang/39/head

Conversation

@yanboliang
Contributor

@yanboliang yanboliang commented Sep 17, 2024

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Sep 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136184

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f8ed313 with merge base 0a4bcbf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yanboliang added a commit that referenced this pull request Sep 17, 2024
@yanboliang
Contributor Author

yanboliang commented Sep 17, 2024

@zou3519 Let's discuss whether this is the proper approach here. I think we only need to update the example_value of the autograd.Function.apply output, which is inferred from speculating the fwd graph. I'm now switching to a separate fake tensor prop to get the example_value, and it works well in common cases. However, there are several cases that don't work well:

I think all of these issues are resolved by speculate_subgraph, and that's probably the reason we need it. So if we rely on fake tensor prop to get example_value, it seems we need to duplicate these features again.

Especially for the first issue above, it seems we still need to run the fake prop under no_grad mode, which puts us back at the starting point.
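For concreteness, the separate fake-prop idea looks roughly like the sketch below (a minimal illustration, not this PR's actual code; `Mul2` is a stand-in Function): run `apply` on fake inputs and use the result as the example_value, which carries the correct metadata.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

class Mul2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

real_x = torch.randn(3, requires_grad=True)
mode = FakeTensorMode()
fake_x = mode.from_tensor(real_x)  # fake tensor preserving shape/dtype/requires_grad
with mode:
    # Running apply() on fake inputs infers shapes AND the correct
    # requires_grad/grad_fn, unlike a forward trace done under no_grad.
    example_value = Mul2.apply(fake_x)
```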

@yanboliang yanboliang added the topic: not user facing topic category label Sep 17, 2024
[ghstack-poisoned]
Contributor Author

@yanboliang yanboliang left a comment

Updated. After switching to the new solution, the only failing tests are Triton-related, e.g.:

======================================================================
ERROR: test_triton_kernel_basic (__main__.AutogradFunctionTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ybliang/local/pytorch/torch/testing/_internal/common_utils.py", line 2979, in wrapper
    method(*args, **kwargs)
  File "/data/users/ybliang/pytorch/test/dynamo/test_autograd_function.py", line 1275, in test_triton_kernel_basic
    z = f(x, y)
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 970, in _compile
    raise InternalTorchDynamoError(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 569, in wrapper
    return inner_fn(self, inst)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1598, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 826, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/misc.py", line 1025, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/misc.py", line 775, in call_method
    return self.call_apply(tx, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/misc.py", line 700, in call_apply
    ).call_function(tx, args, kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 2265, in call_function
    example_value = self.fn_cls.apply(*fake_args, **fake_kwargs)
  File "/home/ybliang/local/pytorch/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/users/ybliang/pytorch/test/dynamo/test_autograd_function.py", line 1260, in forward
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/triton/runtime/jit.py", line 618, in run
    bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
  File "<string>", line 2, in dynamic_func
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/triton/runtime/jit.py", line 297, in compute_spec_key
    if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0):
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Sep 18, 2024
[ghstack-poisoned]
@yanboliang yanboliang added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 26, 2024
[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Sep 26, 2024
@yanboliang
Contributor Author

yanboliang commented Sep 26, 2024

@zou3519 I checked all of these failures and found they stem from two causes:

  • The added fake prop doesn't handle a pathological scenario well: a user defines an autograd.Function with a static forward method but also passes self as the first argument, e.g., the failing DebertaForQuestionAnswering. This is because the newly added fake prop goes through nn.Module._call_impl, which offsets self and causes a positional-argument mismatch.

  • Dynamic shapes related failure

torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u5, u4} not in returned outputs FakeTensor(..., size=(u0 + u1,), grad_fn=<ApplyTemplateBackward>) ((1,), 0).
Did you accidentally call new_dynamic_size() or item() more times than you needed to in your fake implementation?
For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit

I think this is because we create symbols (u0 + u1) during speculate_subgraph, but when running the fake prop I added, we create new symbols (u5, u4) rather than referring directly to the original ones. This seems to happen only when capture_scalar_outputs=True. cc @ezyang

  • The other flex attention related failures are trunk issues, which are not relevant to this PR.
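To illustrate the first failure mode, here's a hedged sketch of the Deberta-style pattern (PathologicalFn is a hypothetical name): a @staticmethod forward whose first positional argument is named self. Eager mode works fine, since apply() simply binds the ctx object to that slot, but any prop that routes the call through nn.Module._call_impl miscounts the positional arguments.

```python
import torch

class PathologicalFn(torch.autograd.Function):
    # The first argument is *named* self, but apply() passes ctx there,
    # so eager execution is unaffected; only positional bookkeeping
    # in wrappers that expect a bound `self` gets confused.
    @staticmethod
    def forward(self, x):
        return x * 2

    @staticmethod
    def backward(self, grad_out):
        return grad_out * 2

y = PathologicalFn.apply(torch.ones(3, requires_grad=True))
```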

@ezyang
Contributor

ezyang commented Sep 30, 2024

I'm not terribly familiar with the context behind this PR, but if you are doing a fake prop and then throwing away the results, you might want a pattern similar to what we did for aot autograd metadata analysis:

        fake_mode = detect_fake_mode()
        if fake_mode and (shape_env := fake_mode.shape_env):
            suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
        with disable_above, mode, suppress_pending:
            # precondition: The passed in function already handles unflattening inputs + flattening outputs
            flat_f_args = pytree.tree_map(_to_fun, flat_args)
            flat_f_outs = f(*flat_f_args)
            # We didn't do any tracing, so we don't need to process the
            # unbacked symbols, they will just disappear into the ether.
            # Also, prevent memoization from applying.
            if fake_mode:
                fake_mode.epoch += 1
                fake_mode.reset_nt_tensor_id_counter()

Ignore fresh unbacked symbols is the important part, but I'm not sure if you need to rev the epoch or not.

@ezyang
Contributor

ezyang commented Sep 30, 2024

If these new unbacked symbols are going into the FX graph, though, you need to do something like PropagateUnbackedSymInts instead

@zou3519
Contributor

zou3519 commented Sep 30, 2024

Context: the way Dynamo handles autograd.Function isn't faithful to eager-mode PyTorch when the inputs have requires_grad=True. What Dynamo does today is:

  • we trace out the autograd.Function forward and backward
  • we put these into a HOP
  • we put a call to the HOP in the graph.
  • we set the example values on the output of the HOP to be the FakeTensor outputs of tracing the forward

The FakeTensor outputs of tracing the forward are not actually faithful to eager-mode PyTorch. The problem is we trace the forward under torch.no_grad(), so the FakeTensor outputs have requires_grad=False. In eager-mode, the eager outputs of the autograd.Function have requires_grad=True.
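This eager-mode behavior can be checked directly (a small sketch, not from the PR): the forward body observes that grad mode is disabled, yet the output of apply still comes back with requires_grad=True and a grad_fn.

```python
import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # The autograd engine disables grad mode while running forward.
        assert not torch.is_grad_enabled()
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(3, requires_grad=True)
y = Double.apply(x)
# y.requires_grad is True and y.grad_fn points at the Function's backward --
# exactly the metadata a forward trace done under no_grad gets wrong.
```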

There are two approaches we can take to resolve this:

  1. Add a new step that does FakeTensor prop using the autograd.Function, so that the outputs have the correct requires_grad-ness, and use these as the FakeTensor outputs
  2. Somehow "fix up" the output of the HOP by emulating what eager-mode autograd.Function does

We tried to do (1) in this PR, because it seemed the easiest, but are running into the unbacked symint thing because we do FakeTensor prop twice: once when tracing out the autograd.Function forward, and once in a new step that does FakeTensor prop using the autograd.Function.

If these new unbacked symbols are going into the FX graph, though, you need to do something like PropagateUnbackedSymInts instead

They are going into the FX graph. If we keep going down this design path and there are symints in both traces, then we probably need a way to indicate that they are the same.

@ezyang
Contributor

ezyang commented Oct 12, 2024

Is help still needed on this? Might need to hop on a VC

@zou3519
Contributor

zou3519 commented Oct 14, 2024

@ezyang yes, we still need help. Yanbo's on recharge so we can wait until he comes back.

@xinyu-intel
Contributor

@yanboliang @zou3519 Hi, thanks for these findings. Any updates here?

@yanboliang
Contributor Author

@xinyu-intel I will update this PR and get this in soon.

[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Nov 20, 2024
[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Nov 21, 2024
@yanboliang yanboliang requested a review from zou3519 November 21, 2024 05:22
@yanboliang
Contributor Author

@zou3519 @ezyang This is ready for reviewing.

return torch.ops.mylib.foo(x0, x1)

f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
Contributor Author
After chatting with @ezyang, we decided to wrap this as custom op to unblock this PR. The previous failure actually is an unlikely situation in real use cases.

Contributor

Didn't you promise me you would add an xfail'ed version of the old test lol

Contributor Author

Updated!

@ezyang
Contributor

ezyang commented Nov 22, 2024

Endorsing test changes

[ghstack-poisoned]
[ghstack-poisoned]
# Store the invocation as a call
from torch._functorch.autograd_function import autograd_function_apply

# We use speculate_subgraph to get the fwd graph, but it's always run under no_grad mode, matching what eager mode does.
Contributor

@zou3519 zou3519 Nov 22, 2024

I reverted #134872 (revert PR #137891) when you were out so that we can have consistency with the release and main branch.

Because this PR is supposed to "fix" a bug introduced by #134872, we should merge the changes in #134872 in here and test them together.

Contributor Author

Nice catch! Updated!

[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Nov 22, 2024
@yanboliang yanboliang requested a review from zou3519 November 23, 2024 06:30
Comment on lines +9114 to +9117
# Compiling autograd.Function traces fwd function twice, but the same unbacked symints were not identified
# as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
# `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as expected failure
# until we have a proper fix.
Contributor

@zou3519 zou3519 Nov 25, 2024

This was the problem that we had before, right? Are we just saying we don't care about it? There are alternative designs we can do to support this.

Alternative design proposal:

  • run speculate subgraph
  • do not do a separate FakeTensor pass
  • instead, just do exactly what the autograd.Function logic in C++ does:
    • create a grad_fn and put it onto the FakeTensor. This grad_fn should probably be some dummy grad_fn. If this is too difficult, then I'm fine with a solution that just sets the output Tensor's requires_grad to be True (we can figure out how to actually build a grad_fn if someone runs into this problem).
    • Also, there is a special case for what happens if the autograd.Function returns the input directly. In that case we want to alias the input, detach, and then slam a grad_fn onto the Tensor.

Contributor

NVM Ed tells me you guys discussed this and we don't care about unbacked symints in autograd.Function. I'm OK with this then

import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
Contributor

what's this?

Contributor Author

typo, removed!

[ghstack-poisoned]
yanboliang added a commit that referenced this pull request Nov 25, 2024
@yanboliang
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@yanboliang yanboliang deleted the gh/yanboliang/39/head branch November 26, 2024 01:17
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024