make block_diag composite compliant #77716

bdhirsh wants to merge 17 commits into gh/bdhirsh/238/base from
Conversation
🔗 Helpful links

✅ No Failures (1 Pending) as of commit 3cebbac (more details on the Dr. CI page):
💚 Looks good so far! There are no failures yet. 💚
This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
```cpp
  return grad_inputs;
}

std::vector<Tensor> block_diag_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes) {
```
I mostly got this logic from our existing implementation of cat_backward (which knows how to deal with complex)
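To make the slicing idea concrete, here is a hedged sketch of the backward pass in plain C++, with nested `std::vector`s standing in for tensors (this is an illustration, not the actual ATen kernel): the forward op places each input block along the diagonal of the output, so the gradient of each input is just the corresponding rectangular slice of the output gradient.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Sketch of block_diag's backward: walk the diagonal, copying out the
// rectangular slice of `grad` that corresponds to each input's shape.
std::vector<Matrix> block_diag_backward_sketch(
    const Matrix& grad, const std::vector<std::pair<int, int>>& sizes) {
  std::vector<Matrix> grad_inputs;
  grad_inputs.reserve(sizes.size());
  int cur_dim0 = 0;  // running row offset into grad
  int cur_dim1 = 0;  // running column offset into grad
  for (const auto& hw : sizes) {
    const int h = hw.first;
    const int w = hw.second;
    Matrix slice(h, std::vector<double>(w));
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < w; ++j) {
        slice[i][j] = grad[cur_dim0 + i][cur_dim1 + j];
      }
    }
    grad_inputs.push_back(std::move(slice));
    cur_dim0 += h;
    cur_dim1 += w;
  }
  return grad_inputs;
}
```

For example, with a 3x3 output gradient built from a 1x1 and a 2x2 block, the first input's gradient is the top-left 1x1 slice and the second's is the bottom-right 2x2 slice.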
```cpp
Tensor grad_;
bool grad_is_complex = grad.is_complex();
if (grad_is_complex) {
  grad_ = at::real(grad);
}
```
What is going on with complex numbers? Is the problem that there is a weird case when we have mixed dtype inputs to block_diag, some of which can be complex, and we need to handle that?
yeah this is pretty interesting - I got this from the logic for cat backward, and it looks like this handles exactly that - mixed complex/non-complex inputs. I tried a mixed complex/non-complex input locally and it only started passing when I added this.
(there also aren't OpInfo tests for the mixed case, so I'll add some)
That makes sense, I will read the formula more closely. I am not sure OpInfos are supposed to test the mixed dtype case (@mruberry ?) because the framework requests a test with a specific dtype. Maybe we just want a manual test somewhere for this.
Oh, I have this implemented by having the sample inputs function check if the passed-in dtype is complex, and if so it spits out an extra sample input for the only-complex case and the mixed complex/non-complex case. The OpInfo suite seemed to pass locally for me; I'll see what happens in CI.
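The mixed real/complex rule under discussion can be sketched in plain C++ (names here are illustrative, not the ATen API): when the output gradient is complex but a given input was real (the "R -> C" case), only the real part of that input's gradient slice flows back, while complex inputs receive the full complex gradient.

```cpp
#include <complex>

// Hedged sketch of the R -> C gradient rule: real inputs only get the real
// part of the (possibly complex) output gradient, mirroring the
// at::real(grad) view taken in the diff above.
std::complex<double> grad_for_input_sketch(std::complex<double> grad_slice,
                                           bool input_is_complex) {
  if (!input_is_complex) {
    return {grad_slice.real(), 0.0};  // drop the imaginary part
  }
  return grad_slice;  // complex inputs get the full complex gradient
}
```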
The code for `block_diag` isn't composite compliant today - functorch deals with that by registering a special "conditional functionalization" kernel, but I want to kill that here: pytorch/functorch#814.

We can't efficiently make `block_diag` composite compliant though, because it performs `O(num_inputs)` mutations, and converting them all into out-of-place calls would be very inefficient.

Instead, I ended up making `block_diag` `CompositeExplicitAutograd` and writing a derivative formula for it. That also ended up fixing some OpInfo tests.
@zou3519 I also see some gradcheck tests failing. How does vmap get used in core testing? (I'm guessing this is the vmap in core, through the `Batched` dispatch key?) Maybe I just need to opt out?
Yes, that's the vmap in core, through the Batched dispatch key. You just need to opt out of it (don't worry about adding a batching rule for block_diag in core, the vmap in core is very behind and has a different batching rule registration API). The most likely place where it is being tested is batched gradient computation. You probably want to set at least one of these flags to False in the OpInfo:
```cpp
}

if (any_defined) {
  std::vector<Tensor> fw_grads;
  for (auto& t: tensors) {
```
nit: `fw_grads.reserve(...)`; it's good practice to reserve space so the vector doesn't get resized
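The reserve nit can be illustrated in isolation: reserving up front does a single allocation, and the standard guarantees that subsequent `push_back`s don't reallocate while the size stays within the reserved capacity.

```cpp
#include <vector>

// Reserve once, then push: no reallocation can occur while size() stays
// at or below the reserved count, so element pointers remain stable.
std::vector<int> build_with_reserve(int n) {
  std::vector<int> fw_grads;
  fw_grads.reserve(n);  // single allocation up front
  for (int i = 0; i < n; ++i) {
    fw_grads.push_back(i);  // guaranteed not to reallocate
  }
  return fw_grads;
}
```

Without the `reserve`, a growing vector reallocates and copies its contents several times as it doubles capacity.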
```cpp
}
auto& shape = sizes[i];
// If input was empty tensor, gradInput should be empty tensor.
if (shape == std::vector<int64_t>({0})) {
```
Does the RHS construct a std::vector and throw it away? Because if so, that's not very efficient due to the dynamic allocation
good point - I'll clean this up (blindly copied from the cat backward kernel)
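The flagged comparison `shape == std::vector<int64_t>({0})` does heap-allocate a temporary vector on every loop iteration. Checking the size and the single element directly expresses the same "empty 1-D input" test with no allocation; a small sketch (the helper name is illustrative):

```cpp
#include <cstdint>
#include <vector>

// Equivalent to `shape == std::vector<int64_t>({0})` but without
// constructing (and immediately destroying) a temporary vector.
bool is_empty_1d_shape(const std::vector<int64_t>& shape) {
  return shape.size() == 1 && shape[0] == 0;
}
```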
```cpp
if (shape.size() == 1) {
  slice = slice.squeeze(-1);
} else if (shape.size() == 0) {
  slice = slice.squeeze(-1).squeeze(-1);
```
nit: slice.view({}) might be faster, it does one fewer tensor operation
```cpp
Tensor grad_val;
if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
  // R -> C
  grad_val = grad_;
```
nit: a better name for grad_ is "real_view_of_grad" or something
```cpp
int64_t cur_dim1 = 0;

for (const auto i : c10::irange(sizes.size())) {
  Tensor grad_val;
```
Nit: this causes an extra refcount bump I think? If we really care we can do `auto grad_val = cond ? grad_ : grad;`
@zou3519 I'll clean up the efficiency nits, and I also think I can remove the LTC changes (I forget why I even got the compiler errors that caused them in the first place... but we shouldn't need them, since LTC can still rely on the decomposition). I'll merge this - and then once nightlies update and functorch CI starts complaining, I'll merge pytorch/functorch#814. Does that sound ok?
SGTM!
Summary: Pull Request resolved: #77716
Approved by: https://github.com/zou3519
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5cc258ec9ea15951384fc35e62006691b2935915
Reviewed By: seemethere
Differential Revision: D36783092
Pulled By: bdhirsh
fbshipit-source-id: f8069095d5755136ee589f3cde1198abac57c0c7

The code for
block_diagisn't composite compliant today - functorch deals with that by registering a special "conditional functionalization" kernel, but I want to kill that here: pytorch/functorch#814.We can't efficiently make
block_diagcomposite compliant though, because it performsO(num_inputs)mutations, and converting them all into out-of-place calls would be very inefficient.Instead, I ended up making
block_diagCompositeExplicitAutograd, and writing a derivative formula for it. That also ended up fixing some OpInfos tests.Stack from ghstack: