fix maxpool2d for XLA dynamo tracing #4276
Conversation
Today, when XLA registers an autograd.Function to the `AutogradXLA` key, the "add to the autograd graph" step and the "run the forward kernel" step happen all in one go. That's wrong, and it prevents other dispatcher code from executing in between. While trying to fix this, I noticed a bug in the codegen: we register kernels for both the `XLA` and `AutogradXLA` dispatch keys to the same class. This prevents XLA from registering separate kernels to the `XLA` and `AutogradXLA` keys, which is what this PR attempts to address.

Companion patch that fixes XLA's max_pool2d registration (which was blocking the dynamo integration): pytorch/xla#4276

After this PR, XLA should generate two separate header files: `XLANativeFunctions.h` and `AutogradXLANativeFunctions.h`. Before, all of the kernels (including autograd kernels) would be thrown into `XLANativeFunctions.h`.

cc JackCaoG [ghstack-poisoned]
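To illustrate the split described above, here is a minimal sketch (not the actual generated code) of what registering separate kernels to the `XLA` and `AutogradXLA` keys looks like with `TORCH_LIBRARY_IMPL`. The kernel names are placeholders and their definitions are elided; in practice the real entry points would be declared in the generated `XLANativeFunctions.h` / `AutogradXLANativeFunctions.h`.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Placeholder: XLA lowering of max_pool2d (definition elided in this sketch).
at::Tensor xla_max_pool2d(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode);

// Placeholder: autograd wrapper that records the backward node and then
// redispatches, so the XLA kernel above runs as its own dispatcher step.
at::Tensor autograd_xla_max_pool2d(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode);

// With the codegen fix, these two registrations can point at different
// classes instead of being forced into the same one.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("max_pool2d", TORCH_FN(xla_max_pool2d));
}

TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
  m.impl("max_pool2d", TORCH_FN(autograd_xla_max_pool2d));
}
```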
|
@JackCaoG do you know why the option to re-run is greyed out on the CI failure? Not sure if it's a permissions thing or something else. It looks like CI didn't pick up my torch pin or something; trying to kick it off again: |
|
Hmm, I was able to restart the CI. I thought you were an admin and had all the permissions; let me double check. |
|
Hmm, I think the torch pin has taken effect, based on the full log. |
Force-pushed from 69f54bb to c4d9828
Seems like it has something to do with the fallthrough kernel? The test is about pooling, so I think this is a real failure. |
|
@JackCaoG I'm trying to rebuild XLA locally on my new devserver (I'm anticipating issues) - but I just stared at the code for a while and I think I know what the problem is. Just pushed it, so I'll see what the latest round of CI yields. @shunting314 I'll give you a shout when this PR looks ready to test - when it is, can you try re-running your dynamo-XLA integration with max_pool2d (both fw and bw) and confirm if there are issues? |
|
@bdhirsh sure, I'd be glad to do the tests |
|
Hey @shunting314, it looks like the max_pool2d unit tests are passing. I do see a failure in the XLA-dynamo tests, but it doesn't seem related to this change (?). Can you try running E2E tests again? |
|
rebase should fix the issue |
Force-pushed from 443f718 to 7bf44ba
|
@bdhirsh you might also need to rebase your PyTorch PR. |
…ops" Today when XLA registers an autograd.Function to the `AutogradXLA` key, the "add to the autograd graph" step and the "run the forward kernel" step happen all in one go. That's wrong, and prevents other dispatcher code from executing in the middle. When trying to fix this, I noticed a bug in the codegen: we register kernels for both the XLA and AutogradXLA dispatch keys to the same class. This prevents XLA from registering a separate kernel to the XLA and Autograd XLA, which is what this PR attempts to address. Companion patch to fix XLA's max_pool2d registration here, which was blocking the dynamo integration: pytorch/xla#4276 After this PR, XLA should generate two separate header files: `XLANativeFunctions.h`, and `AutogradXLANativeFunctions.h` Before, all of the kernels (including autograd kernels) would be thrown in `XLANativeFunctions.h`. cc JackCaoG [ghstack-poisoned]
Today when XLA registers an autograd.Function to the `AutogradXLA` key, the "add to the autograd graph" step and the "run the forward kernel" step happen all in one go. That's wrong, and prevents other dispatcher code from executing in the middle. When trying to fix this, I noticed a bug in the codegen: we register kernels for both the XLA and AutogradXLA dispatch keys to the same class. This prevents XLA from registering a separate kernel to the XLA and Autograd XLA, which is what this PR attempts to address. Companion patch to fix XLA's max_pool2d registration here, which was blocking the dynamo integration: pytorch/xla#4276 After this PR, XLA should generate two separate header files: `XLANativeFunctions.h`, and `AutogradXLANativeFunctions.h` Before, all of the kernels (including autograd kernels) would be thrown in `XLANativeFunctions.h`. cc JackCaoG [ghstack-poisoned]
|
yep - done |
|
@bdhirsh is this the only PR I need to patch? (I previously saw you had 2 related PRs?) |
|
Do I need to patch pytorch/pytorch#90226 as well? |
|
yes sorry - you'll need both :) |
|
@bdhirsh I see the following errors when building torchxla |
|
Oh, nvm, let me patch the pytorch side PR as well. |
|
@bdhirsh I still see the issue after patching this PR and the corresponding PR on the pytorch side. I've created a standalone test that doesn't require patching my PR. You can repro using this simple script in your environment: pytorch/torchdynamo#1837 (comment)
Force-pushed from 8278bc4 to 176b704
c10::DispatchKey::Conjugate,
c10::DispatchKey::Negative,
c10::DispatchKey::ZeroTensor,
c10::DispatchKey::ADInplaceOrView,
Oh we should take out ADInplaceOrView from here
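For context on the "fallthrough kernel" remark earlier in the thread, this is roughly how a fallthrough gets registered in the dispatcher. A minimal sketch using the `Conjugate` key from the list above, not the code this PR actually generates:

```cpp
#include <torch/library.h>

// Register a fallthrough for the Conjugate key: calls dispatched to this key
// skip straight through to the next dispatch key in line.
TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
```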
Force-pushed from 98d1059 to 83876d5
Force-pushed from 83876d5 to 47f6b54
|
Hey @JackCaoG - do you mind finishing up the landing? @shunting314 confirmed that this fixes the E2E tests for max_pool2d.

waiting for CI before getting review