This repository was archived by the owner on Aug 21, 2025. It is now read-only.

add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL #814

Merged: zou3519 merged 2 commits into main from block_diag_fix, May 31, 2022

Conversation

@bdhirsh (Contributor) commented May 18, 2022

Companion core PR: pytorch/pytorch#77716

The above PR makes block_diag composite compliant, and this PR adds a batching rule for it.

Those two changes together should let us fully remove the DECOMPOSE_FUNCTIONAL macro, which was preventing me from moving the Functionalize dispatch key below FuncTorchBatched (which I want to do as part of XX, in order to properly get functionalization working with LTC/XLA).

}
auto result = at::cat(batched_outputs);
return physical_views[0].getPhysicalToLogicalMap().apply(result);
}
@bdhirsh (Contributor, Author) commented May 18, 2022

I'm not convinced that I actually implemented this correctly (particularly in the case where you have multiple layers of vmap), but I wasn't sure what the right APIs to use were. I got the OpInfo tests to pass, though.

I also wasn't sure how to actually make this batching rule fast, so I implemented it as a dummy for loop. It's probably still more efficient than the DECOMPOSE_FUNCTIONAL version, though, since that functionalizes every intermediate copy_(), which probably resulted in a bunch of large temporary tensors.
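The loop-based approach described above can be sketched in pure Python (hypothetical names and list-of-lists "tensors", standing in for the actual C++ batching rule, not the ATen code):

```python
def block_diag(mats):
    # Unbatched op: place each 2-D matrix along the diagonal of a zero matrix.
    rows = sum(len(m) for m in mats)
    cols = sum(len(m[0]) for m in mats)
    out = [[0] * cols for _ in range(rows)]
    r = c = 0
    for m in mats:
        for i, row in enumerate(m):
            for j, v in enumerate(row):
                out[r + i][c + j] = v
        r += len(m)
        c += len(m[0])
    return out

def batched_block_diag(batched_mats):
    # "Dummy for loop" batching rule: peel off the shared leading batch
    # dimension, call the unbatched op once per slice, then stack the
    # per-slice results back into a batched output.
    batch_size = len(batched_mats[0])
    return [block_diag([m[b] for m in batched_mats]) for b in range(batch_size)]
```

The cost is one unbatched block_diag call per batch element, which matches the O(batch) loop the comment describes.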

@zou3519 (Contributor) replied:

> I'm not convinced that I actually implemented this correctly (particularly in the case where you have multiple layers of vmap), but I wasn't sure what the right APIs to use were. I got the OpInfo tests to pass, though.

This works with multiple layers of vmap. Each Interpreter in the DynamicLayer stack handles one vmap, so as long as we're calling PyTorch composite operations it all works out :)
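The "one Interpreter per vmap layer" point can be illustrated with a toy vmap (a hypothetical pure-Python stand-in, not the functorch implementation): each application peels off exactly one batch dimension, so nesting composes naturally without the inner rule ever knowing about outer layers.

```python
def vmap(fn):
    # Toy vmap: maps fn over the leading dimension of each argument.
    # Each call handles exactly one batch dimension, mirroring how each
    # Interpreter in the DynamicLayer stack handles one level of vmap.
    def batched(*args):
        return [fn(*slices) for slices in zip(*args)]
    return batched

def add(x, y):
    return x + y

once = vmap(add)         # operates over one batch dimension
twice = vmap(vmap(add))  # two nested layers, one batch dimension each
```

Because each layer only ever sees one batch dimension, a batching rule written for the single-layer case composes for free.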

> I also wasn't sure how to actually make this batching rule fast, so I implemented it as a dummy for loop. It's probably still more efficient than the DECOMPOSE_FUNCTIONAL version, though, since that functionalizes every intermediate copy_(), which probably resulted in a bunch of large temporary tensors.

Yeah, there isn't a more efficient way to implement this. If we want this to go faster, we would need a "batched block_diag" operator in pytorch/pytorch that implements the behavior and is easier to write a batching rule for. We should leave a comment here about that for whoever comes along in the future.

@bdhirsh bdhirsh requested a review from zou3519 May 18, 2022 03:18
bdhirsh added a commit to pytorch/pytorch that referenced this pull request May 18, 2022
The code for `block_diag` isn't composite compliant today - functorch deals with that by registering a special "conditional functionalization" kernel, but I want to kill that here: pytorch/functorch#814.

We can't efficiently make `block_diag` composite compliant though, because it performs `O(num_inputs)` mutations, and converting them all into out-of-place calls would be very inefficient.

Instead, I ended up making `block_diag` `CompositeExplicitAutograd`, and writing a derivative formula for it. That also ended up fixing some OpInfos tests.
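The O(num_inputs) mutation pattern the commit message refers to can be sketched as follows (a hypothetical Python analogue of the kind of kernel involved, not the actual ATen code): the output is allocated once and each input is written into it in place, and functionalization would have to turn each of those writes into a full out-of-place copy.

```python
def block_diag_inplace(mats):
    # Mutation-heavy block_diag: allocate zeros once, then do one in-place
    # block write per input. Functionalizing this would replace each write
    # (the copy_() calls) with an out-of-place copy of the whole output.
    rows = sum(len(m) for m in mats)
    cols = sum(len(m[0]) for m in mats)
    out = [[0] * cols for _ in range(rows)]
    r = c = 0
    for m in mats:
        for i in range(len(m)):
            out[r + i][c:c + len(m[0])] = m[i]  # the in-place copy_()
        r += len(m)
        c += len(m[0])
    return out
```

With N inputs there are N in-place block writes; functionalized naively, that becomes N full-size temporary outputs, which is the inefficiency the message describes.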




[ghstack-poisoned]
return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor block_diag_batching_rule(TensorList tensors) {
@zou3519 (Contributor) commented May 20, 2022

(no action required) For some more context... BatchingRegistrations.cpp is for all the legacy batching rules and is "deprecated", but since we haven't actually gotten TensorList inputs to work with the new batching rule API, this is indeed our only option for now. The downside is that the legacy API is a bit difficult to work with and all the documentation for it is misleading, as you've probably discovered here

@zou3519 (Contributor) left a comment
LGTM. We should probably wait for the pytorch-side change to get merged

Comment on lines +570 to +571
// Implementing this as a dummy for loop for now, since I'm not sure how to do it any better.
// I'm probably not accounting for potentially multiple batched dimensions?
@zou3519 (Contributor) commented May 20, 2022
When writing a batching rule, it's safe to assume there is only a single layer of vmap; DynamicLayerStack handles the case where there are multiple layers. This is contrary to all the documentation in VmapTransforms.h, which was written back when DynamicLayerStack didn't exist.


bdhirsh added a commit to pytorch/pytorch that referenced this pull request May 25, 2022
…backends"

Need this to get functionalize to work with backends (LTC/XLA). Now that we can kill the `DECOMPOSE_FUNCTIONAL` code in functorch (see pytorch/functorch#814), this should be ok to land once that PR merges.




[ghstack-poisoned]
@zou3519 (Contributor) commented May 26, 2022

Block diag tests seem to be failing: https://app.circleci.com/pipelines/github/pytorch/functorch/2886/workflows/bfc11336-cb2d-4206-a6ca-7132a4e2f204/jobs/19966/tests

But also, it looks like this PR has unrelated commits in it

@bdhirsh (Contributor, Author) commented May 26, 2022

hmm I may have mucked up a rebase the last time. Taking a look

@bdhirsh bdhirsh force-pushed the block_diag_fix branch 2 times, most recently from a5e5800 to 53836fb on May 26, 2022 16:10
@bdhirsh (Contributor, Author) commented May 26, 2022

Had to make a quick fix locally (I had tested test_ops.py but forgot to test test_vmap.py locally) - tests should be passing now but I'll let CI run to be safe.

@bdhirsh (Contributor, Author) commented May 26, 2022

Failures so far look unrelated (looks like they're coming from the addr PR here: pytorch/pytorch@a1765f0)

@samdow (Contributor) commented May 26, 2022

Wait, sorry, I'm still seeing a block_diag failure:

https://app.circleci.com/pipelines/github/pytorch/functorch/2895/workflows/d67e1bce-1557-410c-8a10-b696f109e0e9/jobs/20105
specifically for
test_jvpvjp_block_diag_cpu_float32
test_vmapjvp_block_diag_cpu_float32
test_vmapjvpall_has_batch_rule_block_diag_cpu_float32
test_vmapvjp_block_diag_cpu_float32
test_vmapvjp_has_batch_rule_block_diag_cpu_float32

With that being said, sorry about the addr failures. There should be 3 of them on CPU (6 on CUDA), and you're definitely right that they're unrelated. Currently trying to decide whether it's worth it to xfail those or update the testing infra for nan inputs.

@bdhirsh (Contributor, Author) commented May 26, 2022

Welp, thanks. Weird, those tests all pass for me locally...

I removed a bunch of existing xfails this morning - I'll try adding them back and re-running the CI.

@samdow (Contributor) commented May 26, 2022

> Welp, thanks. Weird, those tests all pass for me locally...

Might be from this 😭 😭 I eagerly look forward to dropping this into pytorch/pytorch so we aren't making people deal with this CI junk...

@bdhirsh (Contributor, Author) commented May 26, 2022

Staring at the log output, I think I see 3 remaining failing block diag tests:

TestOperatorsCPU.test_vmapjvpall_addr_cpu_float32
TestOperatorsCPU.test_jvpvjp_block_diag_cpu_float32
TestOperatorsCUDA.test_jvpvjp_block_diag_cuda_float32

What's weird is that when I pull this PR locally and build it against a fresh copy of master I don't see the same errors. The only failures are unexpected successes (which I could remove the xfails for, but they don't seem to be passing on CI).

I'm not really sure what's causing the discrepancy in CI :( Would it be reasonable to merge the PR in a way that passes locally, and see what the final version of CI looks like on main afterwards?

pytest test/test_ops.py test/test_vmap.py -k "block_diag"
...
================================================================== FAILURES ===================================================================
____________________________________________ TestOperatorsCPU.test_vmapjvp_block_diag_cpu_float32 _____________________________________________
Unexpected success
___________________________________________ TestOperatorsCPU.test_vmapjvpall_block_diag_cpu_float32 ___________________________________________
Unexpected success
___________________________________ TestOperatorsCPU.test_vmapjvpall_has_batch_rule_block_diag_cpu_float32 ____________________________________
Unexpected success
____________________________________________ TestOperatorsCPU.test_vmapvjp_block_diag_cpu_float32 _____________________________________________
Unexpected success
_____________________________________ TestOperatorsCPU.test_vmapvjp_has_batch_rule_block_diag_cpu_float32 _____________________________________
Unexpected success
___________________________________________ TestOperatorsCUDA.test_vmapjvp_block_diag_cuda_float32 ____________________________________________
Unexpected success
__________________________________________ TestOperatorsCUDA.test_vmapjvpall_block_diag_cuda_float32 __________________________________________
Unexpected success
__________________________________ TestOperatorsCUDA.test_vmapjvpall_has_batch_rule_block_diag_cuda_float32 ___________________________________
Unexpected success
___________________________________________ TestOperatorsCUDA.test_vmapvjp_block_diag_cuda_float32 ____________________________________________
Unexpected success
____________________________________ TestOperatorsCUDA.test_vmapvjp_has_batch_rule_block_diag_cuda_float32 ____________________________________
Unexpected success

...
FAILED test/test_ops.py::TestOperatorsCPU::test_vmapjvp_block_diag_cpu_float32
FAILED test/test_ops.py::TestOperatorsCPU::test_vmapjvpall_block_diag_cpu_float32
FAILED test/test_ops.py::TestOperatorsCPU::test_vmapjvpall_has_batch_rule_block_diag_cpu_float32
FAILED test/test_ops.py::TestOperatorsCPU::test_vmapvjp_block_diag_cpu_float32
FAILED test/test_ops.py::TestOperatorsCPU::test_vmapvjp_has_batch_rule_block_diag_cpu_float32
FAILED test/test_ops.py::TestOperatorsCUDA::test_vmapjvp_block_diag_cuda_float32
FAILED test/test_ops.py::TestOperatorsCUDA::test_vmapjvpall_block_diag_cuda_float32
FAILED test/test_ops.py::TestOperatorsCUDA::test_vmapjvpall_has_batch_rule_block_diag_cuda_float32
FAILED test/test_ops.py::TestOperatorsCUDA::test_vmapvjp_block_diag_cuda_float32
FAILED test/test_ops.py::TestOperatorsCUDA::test_vmapvjp_has_batch_rule_block_diag_cuda_float32

@zou3519 (Contributor) commented May 26, 2022

Hmm I can reproduce this locally. Will debug a bit and get back to you. How urgent is it to get this merged? IIRC this is blocking your move of the functionalize dispatch key?

@bdhirsh (Contributor, Author) commented May 26, 2022

> Will debug a bit and get back to you.

Thanks!

> How urgent is it to get this merged? IIRC this is blocking your move of the functionalize dispatch key?

Not super urgent - it's blocking my moving the functionalize key, which blocks me landing the LTC <> functionalize integration. But there are still a bunch of other CI failures in that integration that I'm working through (although having less stuff in the stack makes the failures easier to reason about).

@zou3519 (Contributor) commented May 26, 2022

> Will debug a bit and get back to you.

@bdhirsh after applying your PR to my local build with PyTorch master, I am seeing the same thing (there are a bunch of unexpected successes, but that is expected).

After applying your PR to my local build with PyTorch nightly, I am seeing the same failures reported in our CI (the block_diag failures)

functorch CI runs on the nightlies, so some of the changes in master that were necessary to get everything working (or into the unexpected-success state) haven't made it to the nightlies yet.

If you're in a rush, I think we can merge your dispatch key move PR in pytorch/pytorch (I believe that won't break the functorch build in fbcode; if it does someone will revert it). It will break these tests (but hey, these tests are already broken, and no one has complained yet :D)

If you're not in a rush, we should wait until the next business day (for the nightlies to update), rebase this PR & fix the unexpected successes, and the CI should be green.

@bdhirsh (Contributor, Author) commented May 26, 2022

> If you're not in a rush, we should wait until the next business day (for the nightlies to update), rebase this PR & fix the unexpected successes, and the CI should be green.

Oof, somehow forgot about the [local master] vs [CI nightly] discrepancy again. Thanks for checking. Waiting until tomorrow and rebasing sounds fine to me, I'll go ahead and do that

@zou3519 (Contributor) commented May 31, 2022

Build + test passed locally for me (minus some xfails I will be adding soon), and the code LGTM as well, so let's merge.

@zou3519 zou3519 merged commit abe4c4d into main May 31, 2022
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
…AL (pytorch/functorch#814)

* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL

* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
…AL (pytorch/functorch#814)

* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL

* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL

4 participants