Move pre-grad passes after AOTAutograd cache lookup#176340

Closed
frgossen wants to merge 11 commits into gh/frgossen/2/base from gh/frgossen/2/head

Conversation

@frgossen
Contributor

@frgossen frgossen commented Mar 3, 2026

Stack from ghstack (oldest at bottom):

Previously, `run_pre_grad_passes` was called unconditionally at the top
of `_compile_fx_main`. This meant pre-grad transformations were not
included in cached artifacts and ran unnecessarily on cache hits.

Move pre-grad passes into `aot_module_simplified` (Path B) via a
callback so they run after the cache lookup, on cache miss only.

`_compile_fx_main` has two compilation paths that diverge at the
`V.aot_compilation` check: Path A uses `aot_export_module` (AOTInductor,
no cache) and Path B uses `aot_autograd` → `aot_module_simplified` (with
`AOTAutogradCache`). Since Path A has no cache, run pre-grad passes
explicitly before `aot_export_module`.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela

Previously, `run_pre_grad_passes` was called in `compile_fx.py` before
`aot_module_simplified`, which meant pre-grad transformations were not
included in the cached artifacts. On cache hits, the passes would still
run unnecessarily.

This change:
- Adds an optional `pre_grad_passes` callback parameter to
  `aot_module_simplified` to avoid circular imports
- Calls the callback after the cache lookup (on cache miss only)
- Ensures pre-grad transformations are included in cached artifacts
- Adds a test to verify pre_grad_passes is skipped on cache hits
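The intended control flow can be sketched in a few lines of plain Python (hypothetical names; `compile_with_cache` and its signature are illustrative, not the actual `aot_module_simplified` API):

```python
from typing import Callable, Optional

_cache: dict = {}  # stand-in for AOTAutogradCache

def compile_with_cache(
    key: str,
    graph: list,
    pre_grad_passes: Optional[Callable[[list], list]] = None,
):
    """Run pre-grad passes only on a cache miss, so their effect is
    baked into the cached artifact and skipped entirely on a hit."""
    if key in _cache:
        return _cache[key]                 # cache hit: no passes run
    if pre_grad_passes is not None:
        graph = pre_grad_passes(graph)     # cache miss: transform first
    artifact = ("compiled", tuple(graph))  # stand-in for real compilation
    _cache[key] = artifact
    return artifact
```

Calling this twice with the same key runs the passes once; the second call returns the cached artifact, passes untouched.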

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Mar 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176340

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 150f7fe with merge base e45dfba:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

frgossen added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: ebc2e9b
Pull Request resolved: #176340
@pytorch-bot

pytorch-bot bot commented Mar 3, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@frgossen frgossen requested a review from zou3519 March 3, 2026 21:49
frgossen added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: e659fd5
Pull Request resolved: #176340
frgossen added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: 2db9aae
Pull Request resolved: #176340
@frgossen frgossen added the topic: not user facing topic category label Mar 4, 2026
frgossen added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: 3d81951
Pull Request resolved: #176340
Contributor

@aorenste aorenste left a comment


With the additional changes this LGTM

frgossen added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: 47c36fa
Pull Request resolved: #176340
frgossen added a commit that referenced this pull request Mar 5, 2026
@frgossen frgossen added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 5, 2026
@frgossen
Contributor Author

frgossen commented Mar 5, 2026

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: inductor / unit-test / inductor-test / test (inductor, 2, 2, linux.g5.4xlarge.nvidia.gpu), inductor / inductor-test / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorch-auto-revert

@pytorchbot revert -m "Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable" -c autorevert

This PR is attributed to have caused regression in:

Please investigate and fix the issues.
@claude Can you please read this revert comment, follow the links and read the errors, to then give a brief diagnostics on the cause of the error? If you judge the error to be legitimate reason for a revert, please provide brief guidance on how the author could fix it.

@claude

claude bot commented Mar 9, 2026

Claude finished @pytorch-auto-revert[bot]'s task in 5m 41s.


Diagnosis of Revert for PR #176340

  • Read revert comment and identify failing test
  • Fetch error logs from CI
  • Read relevant source code changes
  • Diagnose root cause
  • Provide fix guidance

Failing Test

inductor/test_efficient_conv_bn_eval.py::EfficientConvBNEvalCpuTests::test_basic_cpu

The test fails consistently across all CI configurations (ASAN, clang15, CUDA 12.8, CUDA 13.0) — both CPU and GPU variants. The error is:

AssertionError: Scalars are not equal!
Expected 1 but got 0.

This means counters["inductor"]["efficient_conv_bn_eval"] never increments — the efficient_conv_bn_eval pre-grad pattern matcher pass is not running (or not matching).

Verdict: Legitimate Revert

This is a real regression, not a flake. The failure is deterministic and reproduced across every CI job that ran this test.

Root Cause

The PR moved run_pre_grad_passes from _compile_fx_main into aot_module_simplified (via a callback), calling it after the cache lookup. Inside aot_module_simplified, the call at torch/_functorch/aot_autograd.py:1156-1157 is:

if pre_grad_passes is not None and isinstance(mod, torch.fx.GraphModule):
    mod = pre_grad_passes(mod, fake_flat_args)

The problem is the second argument: fake_flat_args vs the original example_inputs_.

  • Old code: run_pre_grad_passes(model_, example_inputs_)example_inputs_ matches the graph's placeholder count (only user inputs).
  • New code: pre_grad_passes(mod, fake_flat_args)fake_flat_args is constructed by prepare_aot_module_simplified as [params..., buffers..., *user_inputs], which has more elements than the graph's placeholders.

The pre_grad_passes flow calls fuse_fx, which calls ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) (torch/_inductor/fx_passes/pre_grad.py:399). With the mismatched argument count (fake_flat_args being longer than the graph expects), ShapeProp will either error out or produce incorrect results, preventing the efficient_conv_bn_eval pattern from matching.
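A torch-free stand-in for that mismatch (`shape_prop` here only models ShapeProp's one-input-per-placeholder requirement; names are illustrative):

```python
# Pure-Python model of the arity mismatch described above; real code
# uses torch.fx ShapeProp, which binds one example input per placeholder.
def shape_prop(placeholder_names, example_inputs):
    if len(example_inputs) != len(placeholder_names):
        raise TypeError(
            f"graph expects {len(placeholder_names)} inputs, "
            f"got {len(example_inputs)}"
        )
    return dict(zip(placeholder_names, example_inputs))

placeholders = ["x"]                     # the graph's user inputs only
fake_flat_args = ["w", "b", "x_fake"]    # params + buffers prepended

shape_prop(placeholders, ["x_fake"])     # old call shape: matches
try:
    shape_prop(placeholders, fake_flat_args)  # new call shape: mismatched
    raise AssertionError("expected a mismatch")
except TypeError:
    pass
```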

How to Fix

The pre-grad passes callback should be called with mod's original inputs (matching the graph's placeholders), not fake_flat_args. Two approaches:

  1. Pass original example inputs to the callback: In aot_module_simplified, call pre_grad_passes(mod, args) instead of pre_grad_passes(mod, fake_flat_args), where args are the original args passed to aot_module_simplified (before params/buffers are prepended).

  2. Extract user inputs from fake_flat_args: Slice fake_flat_args to exclude the prepended params/buffers:

    user_inputs = fake_flat_args[params_len + buffers_len:]
    mod = pre_grad_passes(mod, user_inputs)

Option 2 is more precise since user_inputs would be properly fakified. The variables params_len (from params_flat) and buffers_len (from buffers_flat) are computed earlier in prepare_aot_module_simplified and would need to be made available at the call site (e.g., returned from prepare_aot_module_simplified or computed locally).
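Option 2 sketched with illustrative placeholder values (how `params_len` and `buffers_len` reach the call site is an assumption):

```python
# Strip the prepended params/buffers from fake_flat_args so the
# pre-grad callback sees only the user inputs (illustrative values).
params_flat = ["w1", "w2"]
buffers_flat = ["running_mean"]
fake_user_inputs = ["x", "y"]
fake_flat_args = params_flat + buffers_flat + fake_user_inputs

params_len = len(params_flat)
buffers_len = len(buffers_flat)
user_inputs = fake_flat_args[params_len + buffers_len:]
assert user_inputs == ["x", "y"]  # matches the graph's placeholders
```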

Additionally, for GmWrapper cases (when flatten_graph_inputs wraps the module), you may need to unwrap to the inner gm or handle GmWrapper explicitly in the isinstance check, though this path is less common for Dynamo-produced graphs.


@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 9, 2026
This reverts commit ff91f31.

Reverted #176340 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](#176340 (comment)))
@pytorchmergebot
Collaborator

@frgossen your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Mar 9, 2026
frgossen added a commit that referenced this pull request Mar 9, 2026
frgossen added a commit that referenced this pull request Mar 9, 2026
This reverts commit 3087999.

ghstack-source-id: b406a66
Pull Request resolved: #176891
frgossen added a commit that referenced this pull request Mar 9, 2026
@frgossen
Contributor Author

frgossen commented Mar 9, 2026

Fixed the test case, which relied on cache misses.

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Mar 11, 2026
…_by_example (#176938)

When running pre-grad passes outside the `fake_mode` wrapper in
`compile_fx`, the fake tensors attached to graph nodes can belong to a
different fake mode than the one active during pattern replacement
tracing. This mismatch triggers the "Mixing fake modes NYI" assertion.

Fix this by detecting the fake mode from the example values in
`Match.replace_by_example` and converting any fake tensors into that
mode before calling `trace_fn`. This removes the need for the
`wrapped_run_pre_grad_passes` shim in `compile_fx` that previously
papered over the problem by restoring the original fake mode.

Pull Request resolved: #176938
Approved by: https://github.com/aorenste, https://github.com/zou3519
ghstack dependencies: #176340

@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch("enable_autograd_cache", True)
def test_pre_grad_passes_called_on_cache_miss_only(self):
Contributor

@zou3519 zou3519 Mar 13, 2026


@frgossen how exactly is this being treated by the cache? is the cache just using the id of the pre_grad_pass function? If so, that's not robust (we'll always cache miss in a new process) and we should find a better way to deal with it.

Inductor post-grad passes offer a way for the user to specify a uuid() that gets included with the cache key. We should check if the same thing works with pre_grad_passes

I guess the way to convince me is to have a test launch another process and see what happens

Contributor


Rather than starting another process could we just use a separate-but-equal additional function? Or do we bake in too much source information?

Contributor

@zou3519 zou3519 Mar 13, 2026


Assuming the uuid mechanism exists for pre-grad passes, then yeah we could use a separate function that has the same uuid

Contributor Author


> I guess the way to convince me is to have a test launch another process and see what happens

That passes. #177397

I suspect though that pre-grad passes are not explicitly in the cache key; however, what is included in the cache key through torch_key() -> build_code_hash(_TORCH_PATH) is a hash of the entire source code, IIUC.

Contributor Author


Claude just came to the same conclusion. "Pre-grad passes? Not explicitly as a field. The pre_grad_custom_pass is part of the inductor config snapshot (hashed inside inductor_config). The built-in pre-grad passes are captured implicitly via the PyTorch source code hash (torch_version)."

That still leaves pre-grad passes living outside of the torch code base uncovered, though.
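One way an out-of-tree pass could contribute a stable cache-key component, modeled loosely on Inductor's custom-pass `uuid()` convention (the hashing scheme here is an illustration, not the actual PyTorch mechanism):

```python
import hashlib

class MyPreGradPass:
    """Illustrative out-of-tree pre-grad pass with a stable cache id."""

    def __call__(self, graph):
        # ... transform `graph` in place ...
        return graph

    def uuid(self) -> str:
        # Hash the pass's bytecode: stable across processes (unlike
        # id(self)) and changes whenever the pass logic changes.
        code = type(self).__call__.__code__.co_code
        return hashlib.sha256(code).hexdigest()
```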

Contributor Author


Adding this in https://github.com/pytorch/pytorch/pull/177403/changes
It does change the config API though, requiring a CustomGraphPassType now

if compiled_fn is None:
    # Run pre-grad passes after cache lookup to cache pre-grad transforms.
    if pre_grad_passes is not None and isinstance(mod, torch.fx.GraphModule):
        mod = pre_grad_passes(mod, fake_flat_args)
Collaborator


Hi @frgossen, it looks like `mod` here is never used afterwards, so the `pre_grad_passes` does not seem to take effect at all. Some of our internal workflows are broken by this commit, and the cause is that `pre_grad_passes` is not successfully applied. If this is a bug, I am wondering why it didn't break any UT. Did I miss anything here? Thanks
CC @Yuxingwang-intel @Valentine233

Copy link
Contributor Author


The pre_grad_passes will change the mod in place. I will still fix the unused variable so it's not confusing.

I suspect your issue here is that the pre_grad_passes are not part of the cache key, which is a problem we have run into too. I'm fixing that in #177403, but your pre-grad passes will have to implement `uuid()`. You can also change the pre-grad pass timing when https://github.com/pytorch/pytorch/pull/177664/changes lands. The DEFAULT will be backwards compatible. Please let me know if that breaks you.
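The in-place point can be demonstrated with a plain-Python stand-in (not a real `GraphModule`):

```python
class FakeModule:
    """Stand-in for a GraphModule with a mutable node list."""
    def __init__(self):
        self.nodes = ["conv", "bn"]

def fuse_pass(mod):
    mod.nodes = ["conv_bn_fused"]  # mutates the module in place
    return mod

m = FakeModule()
result = fuse_pass(m)                 # even if `result` were discarded...
assert m.nodes == ["conv_bn_fused"]   # ...the caller's reference sees the change
assert result is m
```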

Contributor Author


nit: just realizing that mod is used below
