[Dynamo][autograd.Function] Use fake tensor prop to infer fwd output #136184
yanboliang wants to merge 13 commits into gh/yanboliang/39/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136184
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit f8ed313 with merge base 0a4bcbf. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
@zou3519 Let's discuss if this is the proper way here. I think we only need to update the
I think all of these issues are resolved by
And especially for the first issue above, it seems we still need to run fake prop under
yanboliang left a comment
Updated. After switching to the new solution, the only failing tests are Triton-related, e.g.:
======================================================================
ERROR: test_triton_kernel_basic (__main__.AutogradFunctionTests)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/ybliang/local/pytorch/torch/testing/_internal/common_utils.py", line 2979, in wrapper
method(*args, **kwargs)
File "/data/users/ybliang/pytorch/test/dynamo/test_autograd_function.py", line 1275, in test_triton_kernel_basic
z = f(x, y)
File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1292, in __call__
return self._torchdynamo_orig_callable(
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 530, in __call__
return _compile(
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 970, in _compile
raise InternalTorchDynamoError(
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 933, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 675, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/ybliang/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 220, in _fn
return fn(*args, **kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 643, in transform
tracer.run()
File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2776, in run
super().run()
File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 979, in run
while self.step():
File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 891, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 569, in wrapper
return inner_fn(self, inst)
File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1598, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 826, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/home/ybliang/local/pytorch/torch/_dynamo/variables/misc.py", line 1025, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/variables/misc.py", line 775, in call_method
return self.call_apply(tx, args, kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/variables/misc.py", line 700, in call_apply
).call_function(tx, args, kwargs)
File "/home/ybliang/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 2265, in call_function
example_value = self.fn_cls.apply(*fake_args, **fake_kwargs)
File "/home/ybliang/local/pytorch/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/data/users/ybliang/pytorch/test/dynamo/test_autograd_function.py", line 1260, in forward
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/triton/runtime/jit.py", line 618, in run
bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
File "<string>", line 2, in dynamic_func
File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/triton/runtime/jit.py", line 297, in compute_spec_key
if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0):
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
@zou3519 I checked all of these failures and found they stem from two causes:
I think this is because we create symbols (
I'm not terribly familiar with the context behind this PR, but if you are doing a fake prop and then throwing away the results, you might want a pattern similar to what we did for aot autograd metadata analysis: Ignore fresh unbacked symbols is the important part, but I'm not sure if you need to rev the epoch or not.
If these new unbacked symbols are going into the FX graph, though, you need to do something like PropagateUnbackedSymInts instead
Context: the way Dynamo handles autograd.Function isn't faithful to eager-mode PyTorch when the inputs have requires_grad=True. What Dynamo does today is:
The FakeTensor outputs of tracing the forward are not actually faithful to eager-mode PyTorch. The problem is that we trace the forward under torch.no_grad(), so the FakeTensor outputs have requires_grad=False, whereas in eager mode the outputs of the autograd.Function have requires_grad=True. There are two approaches to resolve this:
We tried to do (1) in this PR, because it seemed the easiest, but are running into the unbacked symint thing because we do FakeTensor prop twice: once when tracing out the autograd.Function forward, and once in a new step that does FakeTensor prop using the autograd.Function.
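The discrepancy described above can be reproduced without Dynamo at all; this minimal sketch contrasts eager `apply` with running the raw forward under `torch.no_grad()`, which is what tracing sees:

```python
import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(3, requires_grad=True)

# Eager mode: apply() hooks the output into autograd.
eager_out = Double.apply(x)
print(eager_out.requires_grad)  # True

# What tracing the forward under no_grad sees instead.
with torch.no_grad():
    traced_out = Double.forward(None, x)
print(traced_out.requires_grad)  # False
```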
They are going into the FX graph. If we keep going down this design -- if there are symints in both traces, then we probably need to indicate that they are the same
Is help still needed on this? Might need to hop on a VC
@ezyang yes, we still need help. Yanbo's on recharge so we can wait until he comes back.
@yanboliang @zou3519 Hi, thanks for these findings; any updates here?
@xinyu-intel I will update this PR and get this in soon.
return torch.ops.mylib.foo(x0, x1)

f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
After chatting with @ezyang, we decided to wrap this as a custom op to unblock this PR. The previous failure is actually an unlikely situation in real use cases.
Didn't you promise me you would add an xfail'ed version of the old test lol
Endorsing test changes
# Store the invocation as a call
from torch._functorch.autograd_function import autograd_function_apply

# We use speculate_subgraph to get the fwd graph, but it's always under no-grad mode, like what eager mode does.
Nice catch! Updated!
# Compiling autograd.Function traces the fwd function twice, but the same unbacked symints were not identified
# as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
# `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as an expected failure
# until we have a proper fix.
This was the problem that we had before, right? Are we just saying we don't care about it? There are alternative designs we can do to support this.
Alternative design proposal:
- run speculate subgraph
- do not do a separate FakeTensor pass
- instead, just do exactly what the autograd.Function logic in C++ does:
- create a grad_fn and put it onto the FakeTensor. This grad_fn should probably be some dummy grad_fn. If this is too difficult, then I'm fine with a solution that just sets the output Tensor's requires_grad to be True (we can figure out how to actually build a grad_fn if someone runs into this problem).
- Also, there is a special case for what happens if the autograd.Function returns the input directly. In that case we want to alias the input, detach, and then slam a grad_fn onto the Tensor.
NVM Ed tells me you guys discussed this and we don't care about unbacked symints in autograd.Function. I'm OK with this then
import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[Dynamo][autograd.Function] Use fake tensor prop to infer fwd output (pytorch#136184)
Fixes pytorch#129963
Pull Request resolved: pytorch#136184
Approved by: https://github.com/zou3519
Stack from ghstack (oldest at bottom):
Fixes #129963
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec