
[Functionalization] Manually redispatch convolution_backward to functionalize pass #4681

Merged
alanwaketan merged 8 commits into functionalization from alanwaketan/conv on Feb 25, 2023

Conversation

alanwaketan (Collaborator) commented Feb 23, 2023

Summary:
For any CompositeExplicitAutograd ops, we are supposed to explicitly re-enable functionalization so that any decomposed ops within them get functionalized as well.

However, when calling directly into at::functionalization::functionalize_aten_op, convolution_backward somehow skips convolution_backward_overrideable, which is our own kernel for computing the convolution gradients. As a result, no grads are produced.

To work around the issue, we manually redispatch convolution_backward to the functionalize pass.
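
For illustration, below is a minimal sketch of what such a manual redispatch can look like, assuming the generated at::redispatch API and a simplified (non-SymInt) signature. The function name RedispatchConvolutionBackward is hypothetical; this is a hedged example of the technique, not necessarily the exact code in this change.

#include <ATen/ATen.h>
#include <ATen/RedispatchFunctions.h>

// Illustrative sketch only: re-enter the dispatcher with the Functionalize key
// so the decomposition of convolution_backward is functionalized (view ops are
// replaced by view_copy ops) and eventually reaches our
// convolution_backward_overrideable kernel.
::std::tuple<at::Tensor, at::Tensor, at::Tensor> RedispatchConvolutionBackward(
    const at::Tensor& grad_output, const at::Tensor& input,
    const at::Tensor& weight, at::OptionalIntArrayRef bias_sizes,
    at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
    bool transposed, at::IntArrayRef output_padding, int64_t groups,
    ::std::array<bool, 3> output_mask) {
  // Build a key set containing the Functionalize key and redispatch; the
  // functionalize kernel then takes care of calling back into the backend.
  c10::DispatchKeySet ks(c10::DispatchKey::Functionalize);
  return at::redispatch::convolution_backward(
      ks, grad_output, input, weight, bias_sizes, stride, padding, dilation,
      transposed, output_padding, groups, output_mask);
}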

Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_conv2d_backward

Comment thread on test/test_operations.py:
], test_fn)

def test_conv2d_backward(self):
# Somehow eager cpu produces different results than us, and
alanwaketan (Collaborator, Author) commented Feb 23, 2023:

@JackCaoG do you know why? cc @miladm

Collaborator replied:

It is not that uncommon for us to give slightly different results; is it off by a lot? It is possible that the result is different because we run some optimization passes that make the code run faster but not as accurately.
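
As a hedged aside on quantifying "off by a lot": the actual test lives in Python in test/test_operations.py, but the comparison idea can be sketched in ATen C++ with an explicit tolerance instead of exact equality. The function and tensor names below are hypothetical.

#include <ATen/ATen.h>
#include <iostream>

// Illustrative only: compare an XLA-produced result (moved to CPU) against the
// eager CPU reference using relative/absolute tolerances.
void CompareWithTolerance(const at::Tensor& xla_result,
                          const at::Tensor& cpu_reference) {
  at::Tensor xla_on_cpu = xla_result.to(at::kCPU);
  // Maximum absolute difference gives a quick sense of the error magnitude.
  double max_abs_diff =
      (xla_on_cpu - cpu_reference).abs().max().item<double>();
  bool close = at::allclose(xla_on_cpu, cpu_reference,
                            /*rtol=*/1e-4, /*atol=*/1e-5);
  std::cout << "max_abs_diff=" << max_abs_diff
            << " allclose=" << close << std::endl;
}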


Comment thread on the convolution_backward lowering:

::std::tuple<at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::convolution_backward(
    const at::Tensor& grad_output, const at::Tensor& input,
alanwaketan (Collaborator, Author) commented:

@bdhirsh I hit a problem here. It looks like we do need to redispatch convolution_backward and _convolution to the functionalize pass in order to make the view ops in their decompositions get replaced with view_copy ops. However, somehow convolution_backward is not calling into at::native::convolution_backward, and thus our own kernel convolution_backward_overrideable is not called.

Here are the dispatcher calls:
convolution:

[call] op=[aten::conv_transpose3d.input], key=[AutogradXLA]
  [call] op=[aten::convolution], key=[AutogradXLA]
   [redispatch] op=[aten::convolution], key=[Functionalize]
    [callBoxed] op=[aten::convolution], key=[XLA]
     [call] op=[aten::_convolution], key=[XLA]
      [redispatchBoxed] op=[aten::_convolution], key=[Meta]
       [call] op=[aten::convolution_overrideable], key=[Functionalize]
        [callBoxed] op=[aten::convolution_overrideable], key=[XLA]
[call] op=[aten::ones_like], key=[Functionalize]

convolution_backward:

[call] op=[aten::convolution_backward], key=[AutogradXLA]
  [redispatch] op=[aten::convolution_backward], key=[Functionalize]
   [callBoxed] op=[aten::convolution_backward], key=[XLA]
    [redispatchBoxed] op=[aten::convolution_backward], key=[Meta]
     [call] op=[aten::new_empty], key=[Functionalize]

Let me know if you need more information.

Contributor replied:

Is that trace coming from a run on this PR, with these ops removed from XLA, or from before this PR?

In the convolution_backward() trace above, I see aten::convolution_backward called with the Meta key. My guess is that we're calling the meta implementation of convolution_backward for shape inference, either directly in pytorch/xla or in the functionalization kernel (the fact that the call right before it was with the XLA key makes me think it's coming from the XLA kernel). That will just run the shape computation for convolution_backward, so it won't end up dispatching to XLA's implementation; it will run the shape computation from core. Is the problem that this is erroring somehow?

alanwaketan (Collaborator, Author) replied:

Before this PR. What the convolution_backward kernel in XLA does is call:

at::functionalization::functionalize_aten_op<ATEN_OP(
      convolution_backward)>::call(grad_output, input, weight, bias_sizes,
                                   stride, padding, dilation, transposed,
                                   output_padding, groups, output_mask);

We did the same thing for _convolution, and in the Meta kernel of _convolution it ends up calling convolution_overrideable. However, we are not seeing the same behavior for convolution_backward.

alanwaketan (Collaborator, Author) commented:

@bdhirsh I applied a hack to make things work.

Summary:
I somehow wrongly lowered _convolution/convolution_backward at an early stage, which made conv.backward disappear from the graph. Therefore, this undoes that change and adds a test case for it.

Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_conv2d_backward
alanwaketan changed the title from "[Functionalization] Undo lowering _convolution/convolution_backward" to "[Functionalization] Manually redispatch convolution_backward to functionalize pass" on Feb 25, 2023
alanwaketan (Collaborator, Author) commented:

I think this PR is ready for review.

alanwaketan (Collaborator, Author) commented:

Thanks Jack for approving the change.

alanwaketan merged commit 1f50e32 into functionalization on Feb 25, 2023
alanwaketan added a commit that referenced this pull request Mar 1, 2023
…ionalize pass (#4681)

