[Functionalization] Manually redispatch convolution_backward to functionalize pass #4681
alanwaketan merged 8 commits into functionalization
Conversation
], test_fn)

def test_conv2d_backward(self):
    # Somehow eager cpu produces different results than us, and
It is not that uncommon for us to give slightly different results; is it off by a lot? It is possible that the result differs because we run some optimization pass that makes the code run faster but less accurately.
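(As an aside, a minimal sketch of a tolerance-based comparison in libtorch; torch::allclose is a real API, but the function and tensor names here are hypothetical:)

#include <torch/torch.h>

// Compare a backend result against the eager CPU reference with explicit
// tolerances instead of exact equality, so small numeric drift from
// faster-but-less-accurate optimization passes stays within bounds.
bool results_match(const torch::Tensor& backend_result,
                   const torch::Tensor& cpu_result) {
  return torch::allclose(backend_result.cpu(), cpu_result,
                         /*rtol=*/1e-4, /*atol=*/1e-5);
}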
Force-pushed from 3fba351 to 829d578.
::std::tuple<at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::convolution_backward(
    const at::Tensor& grad_output, const at::Tensor& input,
@bdhirsh I hit a problem here. It looks like we do need to redispatch convolution_backward and _convolution to the functionalize pass in order to make view ops in their decompositions get replaced with view_copy ops. However, somehow convolution_backward is not calling into at::native::convolution_backward, and thus our own kernel, convolution_backward_overrideable, is not called.
Here are the dispatcher calls:
convolution:
[call] op=[aten::conv_transpose3d.input], key=[AutogradXLA]
[call] op=[aten::convolution], key=[AutogradXLA]
[redispatch] op=[aten::convolution], key=[Functionalize]
[callBoxed] op=[aten::convolution], key=[XLA]
[call] op=[aten::_convolution], key=[XLA]
[redispatchBoxed] op=[aten::_convolution], key=[Meta]
[call] op=[aten::convolution_overrideable], key=[Functionalize]
[callBoxed] op=[aten::convolution_overrideable], key=[XLA]
[call] op=[aten::ones_like], key=[Functionalize]
convolution_backward:
[call] op=[aten::convolution_backward], key=[AutogradXLA]
[redispatch] op=[aten::convolution_backward], key=[Functionalize]
[callBoxed] op=[aten::convolution_backward], key=[XLA]
[redispatchBoxed] op=[aten::convolution_backward], key=[Meta]
[call] op=[aten::new_empty], key=[Functionalize]
Let me know if you need more information.
Is that trace coming from a run on this PR, with these ops removed from XLA, or from before this PR?
In the convolution_backward() trace above, I see aten::convolution_backward called with the Meta key. My guess is that we're calling the meta implementation of convolution_backward for shape inference, either directly in pytorch/xla or in the functionalization kernel (the fact that the call right before it used the XLA key makes me think it's coming from the XLA kernel). This will just run shape compute for convolution_backward, so it won't end up dispatching to XLA's implementation; it will run shape compute from core. Is the problem that this is erroring somehow?
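For illustration, a minimal sketch of Meta-key shape inference, assuming the operands are recreated as meta tensors (this is not the actual pytorch/xla code; the helper name is ours):

#include <ATen/ATen.h>

// Recreate the operands on the meta device, then call the op. Dispatching on
// meta tensors selects the Meta kernel, which computes output shapes and
// dtypes without touching any real data or any backend (XLA) kernel.
std::tuple<at::Tensor, at::Tensor, at::Tensor> shape_only_convolution_backward(
    const at::Tensor& grad_output, const at::Tensor& input,
    const at::Tensor& weight, at::OptionalIntArrayRef bias_sizes,
    at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
    bool transposed, at::IntArrayRef output_padding, int64_t groups,
    std::array<bool, 3> output_mask) {
  auto to_meta = [](const at::Tensor& t) {
    return at::empty(t.sizes(), t.options().device(at::kMeta));
  };
  return at::convolution_backward(
      to_meta(grad_output), to_meta(input), to_meta(weight), bias_sizes,
      stride, padding, dilation, transposed, output_padding, groups,
      output_mask);
}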
Before this PR. What the convolution_backward kernel in XLA does is call

at::functionalization::functionalize_aten_op<ATEN_OP(
    convolution_backward)>::call(grad_output, input, weight, bias_sizes,
                                 stride, padding, dilation, transposed,
                                 output_padding, groups, output_mask);

We did the same thing for _convolution. However, the Meta kernel of _convolution ends up calling convolution_overrideable, but we are not seeing the same behavior for convolution_backward.
@bdhirsh I applied a hack to make things work.
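For context, a minimal sketch of what such a manual redispatch to the Functionalize key could look like, using the at::redispatch API; the function name is ours and the PR's actual code may differ:

#include <ATen/ATen.h>
#include <ATen/RedispatchFunctions.h>

// Instead of going through at::functionalization::functionalize_aten_op,
// redispatch the op ourselves with the Functionalize key, so the
// decomposition runs under functionalization and its view ops are
// replaced with view_copy ops.
::std::tuple<at::Tensor, at::Tensor, at::Tensor>
manually_functionalized_conv_backward(
    const at::Tensor& grad_output, const at::Tensor& input,
    const at::Tensor& weight, at::OptionalIntArrayRef bias_sizes,
    at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
    bool transposed, at::IntArrayRef output_padding, int64_t groups,
    ::std::array<bool, 3> output_mask) {
  return at::redispatch::convolution_backward(
      c10::DispatchKeySet(c10::DispatchKey::Functionalize), grad_output,
      input, weight, bias_sizes, stride, padding, dilation, transposed,
      output_padding, groups, output_mask);
}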
Force-pushed from 5bd4743 to d527999.
Summary: I somehow wrongly lowered _convolution/convolution_backward at an early stage, which made conv.backward disappear from the graph. Therefore, this undoes that change and adds a test case for it. Test Plan: PJRT_DEVICE=TPU python test/test_operations.py -v -k test_conv2d_backward
Force-pushed from 829d578 to 9b30217.
I think this PR is ready for review.
Thanks Jack for approving the change.
Summary:
For any CompositeExplicitAutograd ops, we are supposed to explicitly re-enable functionalization so that any decomposed ops within those ops get functionalized as well.
However, when directly calling into at::functionalization::functionalize_aten_op, convolution_backward somehow omits convolution_backward_overrideable, which is our own kernel for computing the convolution gradients. Thus, no grads are produced.
To work around the issue, we manually redispatch convolution_backward to the functionalize pass.
Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_conv2d_backward