Skip to content

Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed#164939

Closed
ezyang wants to merge 11 commits intogh/ezyang/3171/basefrom
gh/ezyang/3171/head
Closed

Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed#164939
ezyang wants to merge 11 commits intogh/ezyang/3171/basefrom
gh/ezyang/3171/head

Conversation

@ezyang
Copy link
Contributor

@ezyang ezyang commented Oct 8, 2025

Stack from ghstack (oldest at bottom):

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition. You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

This largely reverts #103275 for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang ezyang@meta.com

cc @EikanWang @jgong5 @wenzhe-nrv

[ghstack-poisoned]
@ezyang ezyang requested a review from Chillee as a code owner October 8, 2025 15:14
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164939

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 4 Pending, 2 Unrelated Failures

As of commit 7c9ddc1 with merge base 4f8a986 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Oct 8, 2025
…'t have decomposed

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: b35ae76
Pull-Request: #164939
@pytorch-bot pytorch-bot bot added ciflow/inductor release notes: fx release notes category labels Oct 8, 2025
@ezyang ezyang requested a review from ngimel October 8, 2025 15:14
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
DispatchKeySet{DispatchKey::NestedTensor};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i could imagine this wobbling some tests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

specifically, with the full set of PR changes we can rely on python functionalization decomposing CIA ops. but if you are only running C++ functionalization, we will no longer decompose CIA ops. This might wobble tests?

it might also be a problem if you are running C++ only functionalization and you have a CIA decomp that desugars into mutations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You were indeed right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You were indeed right.

Copy link
Collaborator

@bdhirsh bdhirsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp to unblock - sounds good to me if tests pass

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Oct 8, 2025
…'t have decomposed

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 5765a31
Pull-Request: #164939
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Oct 8, 2025
…'t have decomposed

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 5326bc5
Pull-Request: #164939
@albanD albanD removed their request for review October 8, 2025 16:57
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Oct 8, 2025
…'t have decomposed

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 9c16fe8
Pull-Request: #164939
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Oct 8, 2025
…'t have decomposed

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: e82980e
Pull-Request: #164939
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / inductor-test / test (inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@ezyang
Copy link
Contributor Author

ezyang commented Oct 11, 2025

@pytorchbot merge -f "unrelated problems"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

bdhirsh added a commit that referenced this pull request Oct 13, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 13, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 13, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 13, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 13, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 13, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 14, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 14, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 14, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 14, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 14, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 14, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 15, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 15, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 15, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 15, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 15, 2025
…n in AOTDispatcher"

I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 15, 2025
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: #164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup




cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
@facebook-github-bot
Copy link
Contributor

This pull request has been reverted by e6ba4d0. To re-land this change, please open another pull request, assignthe same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

rqrst_input = rqrst_input.contiguous()

# Cast normalized result back to original input type
result = upcasted_result.type_as(input)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is sus

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well the original does this so ok

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fx Merged release notes: fx release notes category Reverted

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants