
make block_diag composite compliant #77716

Closed
bdhirsh wants to merge 17 commits into gh/bdhirsh/238/base from gh/bdhirsh/238/head

Conversation

bdhirsh (Collaborator) commented May 18, 2022

The code for block_diag isn't composite compliant today - functorch deals with that by registering a special "conditional functionalization" kernel, but I want to kill that here: pytorch/functorch#814.

We can't efficiently make block_diag composite compliant though, because it performs O(num_inputs) mutations, and converting them all into out-of-place calls would be very inefficient.

Instead, I ended up making block_diag CompositeExplicitAutograd and writing a derivative formula for it. That also ended up fixing some OpInfo tests.

Stack from ghstack:

facebook-github-bot (Contributor) commented May 18, 2022

🔗 Helpful links

✅ No Failures (1 Pending)

As of commit 3cebbac (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

return grad_inputs;
}

std::vector<Tensor> block_diag_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes) {
bdhirsh (Collaborator, Author):
I mostly got this logic from our existing implementation of cat_backward (which knows how to deal with complex)

@bdhirsh bdhirsh requested a review from zou3519 May 18, 2022 03:17
Comment on lines +727 to +731
Tensor grad_;
bool grad_is_complex = grad.is_complex();
if (grad_is_complex) {
grad_ = at::real(grad);
}
zou3519 (Contributor):

What is going on with complex numbers? Is the problem that there is a weird case when we have mixed dtype inputs to block_diag, some of which can be complex, and we need to handle that?

bdhirsh (Collaborator, Author):

yeah this is pretty interesting - I got this from the logic for cat backward, and it looks like it handles exactly that: mixed complex/non-complex inputs. I tried a mixed complex/non-complex input locally, and it only started passing when I added this.

bdhirsh (Collaborator, Author):

(there also aren't OpInfo tests for the mixed case, so I'll add some)

zou3519 (Contributor):

That makes sense, I will read the formula more closely. I am not sure OpInfos are supposed to test the mixed dtype case (@mruberry ?) because the framework requests a test with a specific dtype. Maybe we just want a manual test somewhere for this.

bdhirsh (Collaborator, Author):

Oh, I have this implemented by having the sample-inputs function check whether the passed-in dtype is complex; if so, it emits extra sample inputs for the all-complex and mixed complex/non-complex cases. The OpInfo suite seemed to pass locally for me - I'll see what happens in CI.

bdhirsh (Collaborator, Author) commented May 18, 2022

@zou3519 I also see some gradcheck tests failing with RuntimeError: Batching rule not implemented for aten::block_diag. We could not generate a fallback. (https://github.com/pytorch/pytorch/runs/6496894229?check_suite_focus=true)

How does vmap get used in core testing? (I'm guessing this is the vmap in core, through the Batched dispatch key?)

Maybe I just need to opt block_diag out of a list somewhere, since it can't use the boxed fallback (it looks like it doesn't support tensorlist inputs)

zou3519 (Contributor) commented May 18, 2022

> @zou3519 I also see some gradcheck tests failing with RuntimeError: Batching rule not implemented for aten::block_diag. We could not generate a fallback. (https://github.com/pytorch/pytorch/runs/6496894229?check_suite_focus=true)
>
> How does vmap get used in core testing? (I'm guessing this is the vmap in core, through the Batched dispatch key?)
>
> Maybe I just need to opt block_diag out of a list somewhere, since it can't use the boxed fallback (it looks like it doesn't support tensorlist inputs)

Yes, that's the vmap in core, through the Batched dispatch key. You just need to opt out of it (don't worry about adding a batching rule for block_diag in core, the vmap in core is very behind and has a different batching rule registration API).

The most likely place where it is being tested is batched gradient computation. You probably want to set at least one of these flags to False in the OpInfo:
[screenshot of the relevant OpInfo flags]

bdhirsh added 8 commits May 19, 2022 06:34
@zou3519 zou3519 self-requested a review May 24, 2022 15:35
zou3519 (Contributor) left a comment:
LGTM, but tbh I am not sure what is going on with the Lazy shape inference. Is there a test we can add for that?

Also, some minor C++ efficiency nits.

if (any_defined) {
std::vector<Tensor> fw_grads;

for (auto& t: tensors) {
zou3519 (Contributor):
nit: const auto& t

}

if (any_defined) {
std::vector<Tensor> fw_grads;
zou3519 (Contributor):
nit: fw_grads.reserve(...); it's good practice to reserve space so the vector doesn't get resized

}
auto& shape = sizes[i];
// If input was empty tensor, gradInput should be empty tensor.
if (shape == std::vector<int64_t>({0})) {
zou3519 (Contributor):
Does the RHS construct a std::vector and throw it away? Because if so, that's not very efficient due to the dynamic allocation

bdhirsh (Collaborator, Author):
good point - I'll clean this up (blindly copied from the cat backward kernel)

if (shape.size() == 1) {
slice = slice.squeeze(-1);
} else if (shape.size() == 0) {
slice = slice.squeeze(-1).squeeze(-1);
zou3519 (Contributor):
nit: slice.view({}) might be faster, it does one fewer tensor operation

Tensor grad_val;
if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
// R -> C
grad_val = grad_;
zou3519 (Contributor):
nit: a better name for grad_ is "real_view_of_grad" or something

int64_t cur_dim1 = 0;

for (const auto i : c10::irange(sizes.size())) {
Tensor grad_val;
zou3519 (Contributor):
Nit: this causes an extra refcount bump I think? If we really care we can do
auto grad_val = cond ? grad_ : grad;

bdhirsh (Collaborator, Author) commented May 24, 2022

@zou3519 I'll clean up the efficiency nits, and I also think I can remove the LTC changes (I forget why I even got the compiler errors that caused them in the first place, but we shouldn't need them, since LTC can still rely on the decomposition).

I'll merge this - and then once nightlies update and functorch CI starts complaining, I'll merge pytorch/functorch#814. Does that sound ok?

zou3519 (Contributor) commented May 24, 2022

> @zou3519 I'll clean up the efficiency nits, and I also think I can remove the LTC changes (I forget why I even got the compiler errors that caused them in the first place, but we shouldn't need them, since LTC can still rely on the decomposition).
>
> I'll merge this - and then once nightlies update and functorch CI starts complaining, I'll merge pytorch/functorch#814. Does that sound ok?

SGTM!

bdhirsh added 4 commits May 25, 2022 07:13
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/238/head branch May 30, 2022 14:17
facebook-github-bot pushed a commit that referenced this pull request May 31, 2022
Summary:
Pull Request resolved: #77716

Approved by: https://github.com/zou3519

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5cc258ec9ea15951384fc35e62006691b2935915

Reviewed By: seemethere

Differential Revision: D36783092

Pulled By: bdhirsh

fbshipit-source-id: f8069095d5755136ee589f3cde1198abac57c0c7