add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL#814
Conversation
```cpp
  }
  auto result = at::cat(batched_outputs);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}
```
I'm not convinced that I actually implemented this correctly (particularly in the case where you have multiple layers of vmap), but I wasn't sure what the right APIs to use were. I got OpInfo tests to pass, though.
I also wasn't sure how to actually make this batching rule fast, so I implemented it as a dummy for loop. It's probably still more efficient than the DECOMPOSE_FUNCTIONAL version, though, since that functionalizes every intermediate copy_(), which probably resulted in a bunch of large temporary tensors.
> I'm not convinced that I actually implemented this correctly (particularly in the case where you have multiple layers of vmap), but I wasn't sure what the right APIs to use were. I got OpInfo tests to pass, though.
This works with multiple layers of vmap. Each Interpreter in the DynamicLayer stack handles one vmap, so as long as we're calling pytorch composite operations it all works out :)
> I also wasn't sure how to actually make this batching rule fast, so I implemented it as a dummy for loop. It's probably still more efficient than the DECOMPOSE_FUNCTIONAL version, though, since that functionalizes every intermediate copy_(), which probably resulted in a bunch of large temporary tensors.
Yeah, there isn't a more efficient way to implement this. If we want this to go faster, we would need a "batched block_diag" operator in pytorch/pytorch that implements the behavior and is easier to write a batching rule for. We should leave a comment here about that for whoever comes along in the future.
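The "dummy for loop" strategy discussed above can be sketched in plain Python. This is a toy model on list-of-lists matrices, not the real ATen/vmap machinery; `block_diag` and `block_diag_batching_rule` here are illustrative stand-ins. The idea: with every input's batch dim at the front, run `block_diag` once per batch slice, then stack the per-slice results back along the batch dim.

```python
def block_diag(mats):
    """Plain block_diag on 2-D matrices given as lists of lists."""
    rows = sum(len(m) for m in mats)
    cols = sum(len(m[0]) for m in mats)
    out = [[0] * cols for _ in range(rows)]
    r0 = c0 = 0
    for m in mats:
        # Copy each input onto the diagonal block starting at (r0, c0).
        for i, row in enumerate(m):
            for j, v in enumerate(row):
                out[r0 + i][c0 + j] = v
        r0 += len(m)
        c0 += len(m[0])
    return out

def block_diag_batching_rule(batched_mats):
    """Loop over the shared leading batch dim; one block_diag per slice."""
    batch_size = len(batched_mats[0])
    return [block_diag([m[b] for m in batched_mats]) for b in range(batch_size)]
```

The loop over `batch_size` is exactly the per-slice iteration the comment calls a "dummy for loop"; a faster version would need a fused batched kernel.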
The code for `block_diag` isn't composite compliant today - functorch deals with that by registering a special "conditional functionalization" kernel, but I want to kill that here: pytorch/functorch#814. We can't efficiently make `block_diag` composite compliant, though, because it performs `O(num_inputs)` mutations, and converting them all into out-of-place calls would be very inefficient. Instead, I ended up making `block_diag` `CompositeExplicitAutograd` and writing a derivative formula for it. That also ended up fixing some OpInfo tests. [ghstack-poisoned]
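The derivative formula mentioned above is cheap because the backward of a block-diagonal concatenation just carves the output gradient back into one block per input. A hypothetical pure-Python sketch (list-of-lists matrices, not the actual autograd formula; `block_diag_backward` is an illustrative name):

```python
def block_diag_backward(grad, shapes):
    """Slice the gradient of block_diag's output back into per-input blocks.

    grad:   the (sum of rows) x (sum of cols) gradient w.r.t. the output.
    shapes: the (rows, cols) of each original input, in order.
    """
    grads = []
    r0 = c0 = 0
    for r, c in shapes:
        # The gradient of input k is the k-th diagonal block of `grad`.
        grads.append([row[c0:c0 + c] for row in grad[r0:r0 + r]])
        r0 += r
        c0 += c
    return grads
```

Each input's gradient is a single slice, so the backward needs no mutation and no `O(num_inputs)` temporaries.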
```cpp
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor block_diag_batching_rule(TensorList tensors) {
```
(no action required) For some more context... BatchingRegistrations.cpp is for all the legacy batching rules and is "deprecated", but since we haven't actually gotten TensorList inputs to work with the new batching rule API, this is indeed our only option for now. The downside is that the legacy API is a bit difficult to work with and all the documentation for it is misleading, as you've probably discovered here
zou3519 left a comment:
LGTM. We should probably wait for the pytorch-side change to get merged.
```cpp
// Implementing this as a dummy for loop for now, since I'm not sure how to do it any better.
// I'm probably not accounting for potentially multiple batched dimensions?
```
When writing a batching rule it's safe to assume there is only a single layer of vmap. DynamicLayerStack handles the case where there are multiple layers of vmap. This is contrary to all the documentation in VmapTransforms.h, which was written back when DynamicLayerStack didn't exist.
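The point above can be illustrated with a toy model (this is not the real DynamicLayer machinery, just a sketch of why nesting composes): if each `vmap` layer simply loops over the leading dim, then nested `vmap`s peel one batch dim at a time, so any individual batching rule only ever sees a single layer.

```python
def vmap(fn):
    """Toy vmap: map fn over the leading (batch) dim of each input."""
    def batched(*args):
        return [fn(*slices) for slices in zip(*args)]
    return batched

def add(x, y):
    return x + y

# One layer of vmap handles one batch dim; nesting handles the rest.
single = vmap(add)([1, 2], [10, 20])                # one batch dim
nested = vmap(vmap(add))([[1], [2]], [[10], [20]])  # two batch dims
```

In functorch, each Interpreter in the DynamicLayer stack plays the role of one of these `vmap` wrappers.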
…backends" Need this to get functionalize to work with backends (LTC/XLA). Now that we can kill the `DECOMPOSE_FUNCTIONAL` code in functorch (see pytorch/functorch#814), this should be ok to land once that PR merges. [ghstack-poisoned]
Block diag tests seem to be failing: https://app.circleci.com/pipelines/github/pytorch/functorch/2886/workflows/bfc11336-cb2d-4206-a6ca-7132a4e2f204/jobs/19966/tests. But also, it looks like this PR has unrelated commits in it.
Hmm, I may have mucked up a rebase the last time. Taking a look.
Force-pushed from a5e5800 to 53836fb
Had to make a quick fix locally (I had tested
Failures so far look unrelated (looks like they're coming from the
Wait sorry, I'm still seeing a block_diag failure: https://app.circleci.com/pipelines/github/pytorch/functorch/2895/workflows/d67e1bce-1557-410c-8a10-b696f109e0e9/jobs/20105. With that being said, sorry about the addr failures. There should be 3 of them on cpu (6 on cuda), and you're definitely right that they're unrelated. Currently trying to decide if it's worth it to xfail those or update the testing infra for nan inputs.
Welp, thanks. Weird, those tests all pass for me locally... I removed a bunch of existing xfails this morning - I'll try adding them back and re-running the CI.
Might be from this 😭 😭 I eagerly look forward to dropping this into pytorch/pytorch so we aren't making people deal with this CI junk.. |
Staring at the log output, I think I see 3 remaining failing block_diag tests. What's weird is that when I pull this PR locally and build it against a fresh copy of master, I don't see the same errors. The only failures are unexpected successes (which I could remove the xfails for, but they don't seem to be passing on CI). I'm not really sure what's causing the discrepancy in CI :( Would it be reasonable to merge the PR in a way that passes locally, and see what the final version of CI looks like on main afterwards?
Hmm, I can reproduce this locally. Will debug a bit and get back to you. How urgent is it to get this merged? IIRC this is blocking your move of the functionalize dispatch key?
Thanks!
Not super urgent - it's blocking my moving the functionalize key, which blocks me landing the LTC <> functionalize integration. But there are still a bunch of other CI failures in that integration that I'm working through (although having less stuff in the stack makes the failures easier to reason about).
@bdhirsh after applying your PR to my local build with PyTorch master, I am seeing the same thing (there are a bunch of unexpected successes, but that is expected). After applying your PR to my local build with PyTorch nightly, I am seeing the same failures reported in our CI (the block_diag failures). functorch CI runs on the nightlies, so some of the changes in master that were necessary to get everything working (or into the unexpected-success state) haven't made it to the nightlies yet.

If you're in a rush, I think we can merge your dispatch key move PR in pytorch/pytorch (I believe that won't break the functorch build in fbcode; if it does someone will revert it). It will break these tests (but hey, these tests are already broken, and no one has complained yet :D)

If you're not in a rush, we should wait until the next business day (for the nightlies to update), rebase this PR & fix the unexpected successes, and the CI should be green.
Oof, somehow forgot about the [local master] vs [CI nightly] discrepancy again. Thanks for checking. Waiting until tomorrow and rebasing sounds fine to me, I'll go ahead and do that |
Build+test passed locally for me (minus some xfails I will be adding soon), and the code LGTM as well, so let's merge.
…AL (pytorch/functorch#814)
* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL
* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL
Companion core PR: pytorch/pytorch#77716
The above PR makes `block_diag` composite compliant, and this PR adds a batching rule for it. Those two changes together should let us fully remove the `DECOMPOSE_FUNCTIONAL` macro, which was preventing me from moving the `Functionalize` dispatch key below `FuncTorchBatched` (which I want to do as part of XX, in order to properly get functionalization working with LTC/XLA).