
fix inference mode / PyDispatcher / Functionalize interaction #103275

Closed
bdhirsh wants to merge 4 commits into gh/bdhirsh/425/base from gh/bdhirsh/425/head

Conversation

@bdhirsh
Collaborator

@bdhirsh bdhirsh commented Jun 8, 2023

Fixes #103132

This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: We normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to CompositeImplicitAutograd, and a C++ decomp registered directly to the Functionalize key, so the C++ decomp gets precedence over the python decomp.

The way this showed up was that a model was running matmul() under inference mode, so we never hit the autograd dispatch key, and go straight to the functionalize dispatch key. Matmul has both a python decomp and a c++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call .sizes() on a tensor with dynamic shapes" error.

For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the CompositeImplicitAutograd key, this PR just automatically registers that decomp to the Functionalize key at the same time.

I'm trying to remember now why we didn't just add Functionalize (and all of the other functorch transform keys) directly to the CompositeImplicitAutograd alias keyset, but I couldn't remember (@zou3519 any chance you remember?).
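The precedence problem described above can be sketched with a toy dispatcher. This is a simplified model, not the real torch/_ops.py API: the dispatch keys are plain strings and the kernel tables are hypothetical, but the lookup order mirrors the behavior the PR fixes.

```python
# Toy model of the PyDispatcher precedence bug and its fix (all names hypothetical).
class ToyOp:
    def __init__(self):
        self.py_kernels = {}   # python-side registrations (PyDispatcher)
        self.cpp_kernels = {}  # C++-side registrations (regular dispatcher)

    def py_impl(self, key, fn):
        self.py_kernels[key] = fn
        # The fix sketched in this PR: a python decomp registered to
        # CompositeImplicitAutograd is mirrored onto Functionalize, because
        # Functionalize is not part of the CompositeImplicitAutograd alias set
        # and functionalization codegen registers C++ decomps directly to it.
        if key == "CompositeImplicitAutograd" and "Functionalize" not in self.py_kernels:
            self.py_kernels["Functionalize"] = fn

    def dispatch(self, key):
        # Exact-key python kernel wins, then exact-key C++ kernel,
        # then the python alias-key (CompositeImplicitAutograd) kernel.
        if key in self.py_kernels:
            return self.py_kernels[key]
        if key in self.cpp_kernels:
            return self.cpp_kernels[key]
        return self.py_kernels.get("CompositeImplicitAutograd")


op = ToyOp()
# Functionalization codegen registered a C++ decomp directly to Functionalize:
op.cpp_kernels["Functionalize"] = lambda: "cpp decomp"
# The python decomp is registered to the alias key; py_impl mirrors it:
op.py_impl("CompositeImplicitAutograd", lambda: "python decomp")
# Under inference mode we skip Autograd and dispatch straight to Functionalize.
print(op.dispatch("Functionalize")())  # -> python decomp
```

Without the mirroring in `py_impl`, the `Functionalize` lookup would fall through to the directly-registered C++ decomp, which is the dynamic-shapes failure seen with `matmul()`.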

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov

@pytorch-bot

pytorch-bot bot commented Jun 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103275

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8871c77:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@ezyang ezyang left a comment


OHHH this is a good catch

torch/_ops.py Outdated
# to use it, instead of the C++ decomp. We can't though, because Functionalize
# isn't part of the CompositeImplicitAutograd alias set.
# (open question: will we eventually need to do this for functorch transform keys too?)
self.py_kernels[torch._C.DispatchKey.Functionalize] = fn
Contributor


A less hacky version of this would be to modify our py_impl CompositeImplicitAutograd sites to call some higher level function which takes care of doing both registrations. This would be good because the hack as written violates the invariant that one py_impl does one registration.

Collaborator Author

@bdhirsh bdhirsh Jun 9, 2023


Hmm, the only thing I'm worried about is that people might forget to use that new API and continue to register future ops with py_impl(CompositeImplicitAutograd) (which will now always be subtly wrong). If you're not worried about that though, then I can change it (or... maybe just have py_impl() give you a nice error message if you pass in that key?)

torch/_ops.py Outdated
f"Trying to override a python impl for {k} on operator {self.name()}"
)
self.py_kernels[k] = fn
if k == torch._C.DispatchKey.CompositeImplicitAutograd and torch._C.DispatchKey.Functionalize not in self.py_kernels:
Contributor

@zou3519 zou3519 Jun 9, 2023


> I'm trying to remember now why we didn't just add Functionalize (and all of the other functorch transform keys) directly to the CompositeImplicitAutograd alias keyset, but I couldn't remember (@zou3519 any chance you remember?).

This PR is good as-is, but if we want to be less hacky we should just make Functionalize a part of the CompositeImplicitAutograd alias keyset (and maybe we should rename it since it is now more than just autograd?)

Not all CompositeImplicitAutograd operations work with vmap (they may not preserve Tensor subclass-ness), which is why vmap isn't there. Although we've fixed most of these cases, I am wary of actually adding vmap to the CompositeImplicitAutograd set because it is unclear what % of aten ops our OpInfos actually cover.

Contributor


This suggestion, if it works, is better!

Collaborator Author

@bdhirsh bdhirsh Jun 9, 2023


Agreed - I'll try it and see if there's fallout.

There was actually a comment about it (hooray!) here. At the time, at::ones() and friends were all CompositeImplicitAutograd, and decomposed into empty() + fill_(). I think we didn't want functionalization to decompose ones() since it would result in a bunch of unnecessary functionalization logic running. But this shouldn't really matter anymore, since those factory functions all got changed to be CompositeExplicitAutograd.

bdhirsh added a commit that referenced this pull request Jun 9, 2023
bdhirsh added a commit that referenced this pull request Jun 9, 2023
@bdhirsh bdhirsh added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 20, 2023
@bdhirsh
Collaborator Author

bdhirsh commented Jun 20, 2023

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/bdhirsh/425/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/103275)

pytorchmergebot pushed a commit that referenced this pull request Jun 20, 2023
@bdhirsh
Collaborator Author

bdhirsh commented Jun 21, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@bdhirsh
Collaborator Author

bdhirsh commented Jun 21, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@bdhirsh bdhirsh added the release notes: composability release notes category label Jun 21, 2023
@bdhirsh
Collaborator Author

bdhirsh commented Jun 21, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/425/head branch June 25, 2023 14:16
pytorchmergebot pushed a commit that referenced this pull request Oct 9, 2025
…'t have decomposed (#164939)

This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now preserve the fused call by default.

This largely reverts #103275 for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
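The policy change in that commit can be condensed into a toy sketch. The names and structure here are hypothetical, not the real AOTAutograd lowering code: the point is just the decision rule of "explicit table entry forces the decomp; otherwise mirror what eager did."

```python
# Toy sketch of the #164939 decomposition policy (hypothetical names).
def lower(op, eager_would_decompose, user_decomp_table):
    # An explicit entry in the user's decomposition table forces the decomp.
    if op in user_decomp_table:
        return f"decomposed({op})"
    # Otherwise mirror eager: if eager dispatched to a fused kernel,
    # preserve the fused call so compiled output stays bitwise-identical.
    return f"decomposed({op})" if eager_would_decompose else f"fused({op})"


# rms_norm: eager goes to a fused kernel, so the fused call is preserved.
print(lower("rms_norm", eager_would_decompose=False, user_decomp_table=set()))
```

A side effect of preserving the fused call is the one noted in the commit: the generated kernels keep the op's name (e.g. rms_norm) instead of the names of its decomposed pieces.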
pytorchmergebot pushed a commit that referenced this pull request Oct 10, 2025
pytorchmergebot pushed a commit that referenced this pull request Oct 11, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
pianpwk pushed a commit that referenced this pull request Nov 4, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: dynamo, release notes: composability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants