Move pre-grad passes after AOTAutograd cache lookup#176340
frgossen wants to merge 11 commits into `gh/frgossen/2/base`
Conversation
Previously, `run_pre_grad_passes` was called in `compile_fx.py` before `aot_module_simplified`, which meant pre-grad transformations were not included in the cached artifacts. On cache hits, the passes would still run unnecessarily. This change:

- Adds an optional `pre_grad_passes` callback parameter to `aot_module_simplified` to avoid circular imports
- Calls the callback after the cache lookup (on cache miss only)
- Ensures pre-grad transformations are included in cached artifacts
- Adds a test to verify `pre_grad_passes` is skipped on cache hits

[ghstack-poisoned]
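The new call order can be sketched in plain Python. All names below are illustrative stand-ins, not the real `aot_module_simplified` signature:

```python
# Minimal sketch of the change: the pre-grad callback runs only on a cache
# miss, so its effect is baked into the cached artifact and the passes are
# skipped entirely on cache hits. Names are placeholders.
cache = {}
calls = {"pre_grad": 0}

def pre_grad_passes(graph):
    # Stand-in for Inductor's pre-grad transformations.
    calls["pre_grad"] += 1
    return graph + ["pre_grad_applied"]

def aot_module_simplified(graph, key, pre_grad_passes=None):
    compiled = cache.get(key)
    if compiled is None:  # cache miss: transform, then compile and cache
        if pre_grad_passes is not None:
            graph = pre_grad_passes(graph)
        compiled = tuple(graph)  # stand-in for the compiled artifact
        cache[key] = compiled
    return compiled

first = aot_module_simplified(["node"], key="k", pre_grad_passes=pre_grad_passes)
second = aot_module_simplified(["node"], key="k", pre_grad_passes=pre_grad_passes)
assert first == second == ("node", "pre_grad_applied")
assert calls["pre_grad"] == 1  # skipped on the cache hit
```

Passing the callback as a parameter (rather than importing it) is what sidesteps the circular-import issue mentioned above.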
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176340. Note: links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 unrelated failures) As of commit 150f7fe with merge base e45dfba:

- FLAKY: the following job failed, but was likely due to flakiness present on trunk.
- BROKEN TRUNK: the following job failed, but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
- UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
aorenste
left a comment
With the additional changes this LGTM
Previously, `run_pre_grad_passes` was called unconditionally at the top of `_compile_fx_main`. This meant pre-grad transformations were not included in cached artifacts and ran unnecessarily on cache hits.

Move pre-grad passes into `aot_module_simplified` (Path B) via a callback so they run after the cache lookup, on cache miss only.

`_compile_fx_main` has two compilation paths that diverge at the `V.aot_compilation` check: Path A uses `aot_export_module` (AOTInductor, no cache) and Path B uses `aot_autograd` → `aot_module_simplified` (with `AOTAutogradCache`). Since Path A has no cache, run pre-grad passes explicitly before `aot_export_module`.

ghstack-source-id: 47c36fa
Pull Request resolved: #176340
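The two-path split described above can be sketched as follows. These are hypothetical stand-ins for the `compile_fx` internals, not the real functions:

```python
# Sketch of the Path A / Path B split: Path A (AOTInductor export) has no
# cache, so pre-grad passes run up front; Path B defers them into the
# cached compilation via a callback. All names are illustrative.
def run_pre_grad_passes(graph):
    return graph + ["pre_grad_applied"]

def aot_export_module(graph):
    return ("exported", tuple(graph))

def aot_module_simplified(graph, pre_grad_passes=None):
    # Cache lookup would happen here; on a miss the callback runs.
    if pre_grad_passes is not None:
        graph = pre_grad_passes(graph)
    return ("compiled", tuple(graph))

def compile_fx_main(graph, aot_compilation):
    if aot_compilation:  # Path A: no cache, run passes explicitly
        return aot_export_module(run_pre_grad_passes(graph))
    # Path B: passes run inside, after the cache lookup
    return aot_module_simplified(graph, pre_grad_passes=run_pre_grad_passes)

assert compile_fx_main(["n"], True) == ("exported", ("n", "pre_grad_applied"))
assert compile_fx_main(["n"], False) == ("compiled", ("n", "pre_grad_applied"))
```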
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: inductor / unit-test / inductor-test / test (inductor, 2, 2, linux.g5.4xlarge.nvidia.gpu) and inductor / inductor-test / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few of them are: trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable). Details for Dev Infra team: raised by workflow job.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -m "Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable" -c autorevert

This PR is attributed to have caused regression in: … Please investigate and fix the issues.
Claude finished @pytorch-auto-revert[bot]'s task in 5m 41s. Diagnosis of Revert for PR #176340
Failing Test
The test fails consistently across all CI configurations (ASAN, clang15, CUDA 12.8, CUDA 13.0), both CPU and GPU variants. The error is: …

Verdict: Legitimate Revert. This is a real regression, not a flake; the failure is deterministic and reproduced across every CI job that ran this test.

Root Cause: The PR moved the pre-grad passes into `aot_module_simplified`:

```python
if pre_grad_passes is not None and isinstance(mod, torch.fx.GraphModule):
    mod = pre_grad_passes(mod, fake_flat_args)
```

The problem is the second argument: …

How to Fix: The pre-grad passes callback should be called with … Option 2 is more precise since … Additionally, for …
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit ff91f31. Reverted #176340 on behalf of https://github.com/pytorch-auto-revert due to: Reverted automatically by pytorch's autorevert; to avoid this behaviour add the tag autorevert: disable.
@frgossen your PR has been successfully reverted.
This reverts commit 3087999. [ghstack-poisoned]
Fixed test case, which relied on cache misses.

@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…_by_example (#176938)

When running pre-grad passes outside the `fake_mode` wrapper in `compile_fx`, the fake tensors attached to graph nodes can belong to a different fake mode than the one active during pattern replacement tracing. This mismatch triggers the "Mixing fake modes NYI" assertion. Fix this by detecting the fake mode from the example values in `Match.replace_by_example` and converting any fake tensors into that mode before calling `trace_fn`. This removes the need for the `wrapped_run_pre_grad_passes` shim in `compile_fx` that previously papered over the problem by restoring the original fake mode.

Pull Request resolved: #176938
Approved by: https://github.com/aorenste, https://github.com/zou3519
ghstack dependencies: #176340
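The mode-unification idea can be illustrated with a plain-Python toy model. None of this is the real `FakeTensorMode` API; classes and names are illustrative only:

```python
# Toy model of the fix: detect the mode that owns the example values and
# re-associate any value from a different mode before tracing, so no
# mixed-mode values reach trace_fn.
class Mode:
    def __init__(self, name):
        self.name = name

    def adopt(self, value):
        # Stand-in for re-fakeifying a tensor under this mode.
        return (self, value[1])

def unify_modes(example_values):
    # Use the mode of the first example value as the target.
    target = example_values[0][0]
    return [v if v[0] is target else target.adopt(v) for v in example_values]

outer, inner = Mode("outer"), Mode("inner")
unified = unify_modes([(outer, "a"), (inner, "b")])
assert all(v[0] is outer for v in unified)  # single mode, no mixing
```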
```python
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch("enable_autograd_cache", True)
def test_pre_grad_passes_called_on_cache_miss_only(self):
```
@frgossen how exactly is this being treated by the cache? Is the cache just using the id of the `pre_grad_pass` function? If so, that's not robust (we'll always cache-miss in a new process) and we should find a better way to deal with it.

Inductor post-grad passes offer a way for the user to specify a `uuid()` that gets included with the cache key. We should check if the same thing works with pre-grad passes.
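If the same mechanism carries over, a pass contributing a stable key might look roughly like this. These are illustrative stand-ins, not the actual Inductor custom-pass API:

```python
import hashlib

# Toy version of a uuid-keyed custom pass: the cache key folds in a stable
# identifier instead of the function's id(), which differs per process.
class MyPreGradPass:
    def __call__(self, graph):
        return graph  # transformation stand-in

    def uuid(self):
        # Stable across processes; bump when the pass's behavior changes.
        return "my-pre-grad-pass-v1"

def cache_key(graph_hash, custom_pass):
    payload = graph_hash + ":" + custom_pass.uuid()
    return hashlib.sha256(payload.encode()).hexdigest()

# Two "processes" constructing the pass independently get the same key,
# unlike id(pre_grad_pass), which would not survive a process restart.
k1 = cache_key("graph-abc", MyPreGradPass())
k2 = cache_key("graph-abc", MyPreGradPass())
assert k1 == k2
```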
I guess the way to convince me is to have a test launch another process and see what happens
Rather than starting another process could we just use a separate-but-equal additional function? Or do we bake in too much source information?
Assuming the uuid mechanism exists for pre-grad passes, then yeah we could use a separate function that has the same uuid
> I guess the way to convince me is to have a test launch another process and see what happens

That passes: #177397

I suspect, though, that pre-grad passes are not explicitly in the cache key; however, what is included in the cache key through `torch_key()` -> `build_code_hash(_TORCH_PATH)` is a hash of the entire source code, IIUC.
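The source-hash idea can be sketched as follows. This is a simplified stand-in for a `build_code_hash`-style function, not the real implementation:

```python
import hashlib
import os

# Simplified stand-in: every .py file under the root contributes to one
# digest, so any in-tree change to a pass changes the cache key and forces
# a miss. Paths are hashed relative to the root for determinism.
def build_code_hash(root):
    digest = hashlib.sha256()
    for dirpath, _, filenames in sorted(os.walk(root)):
        for name in sorted(filenames):
            if name.endswith(".py"):
                path = os.path.join(dirpath, name)
                digest.update(os.path.relpath(path, root).encode())
                with open(path, "rb") as f:
                    digest.update(f.read())
    return digest.hexdigest()
```

Note this only covers in-tree passes, which is exactly the gap discussed below for passes living outside the torch code base.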
Claude just came to the same conclusion. "Pre-grad passes? Not explicitly as a field. The pre_grad_custom_pass is part of the inductor config snapshot (hashed inside inductor_config). The built-in pre-grad passes are captured implicitly via the PyTorch source code hash (torch_version)."
That still leaves pre-grad passes living outside of the torch code base uncovered, though.
Adding this in https://github.com/pytorch/pytorch/pull/177403/changes

It does change the config API though, requiring a `CustomGraphPassType` now.
```python
if compiled_fn is None:
    # Run pre-grad passes after cache lookup to cache pre-grad transforms.
    if pre_grad_passes is not None and isinstance(mod, torch.fx.GraphModule):
        mod = pre_grad_passes(mod, fake_flat_args)
```
Hi @frgossen, it looks like `mod` here is never used afterwards, so the `pre_grad_passes` does not seem to take effect at all. Some of our internal workflows are broken by this commit, and the cause is that `pre_grad_passes` is not successfully applied. If this is a bug, I am wondering why it didn't break any UT. Did I miss anything here? Thanks
CC @Yuxingwang-intel @Valentine233
The `pre_grad_passes` will change the `mod` in place. I will still fix the unused variable so it's not confusing.

I suspect your issue here is that the pre-grad passes are not part of the cache key, which is a problem we have run into too. I'm fixing that in #177403, but your pre-grad passes will have to implement `uuid()`. You can also change the pre-grad pass timing when https://github.com/pytorch/pytorch/pull/177664/changes lands. The `DEFAULT` will be backwards compatible. Please let me know if that breaks you.
nit: just realizing that mod is used below
Stack from ghstack (oldest at bottom):

Previously, `run_pre_grad_passes` was called unconditionally at the top of `_compile_fx_main`. This meant pre-grad transformations were not included in cached artifacts and ran unnecessarily on cache hits.

Move pre-grad passes into `aot_module_simplified` (Path B) via a callback so they run after the cache lookup, on cache miss only.

`_compile_fx_main` has two compilation paths that diverge at the `V.aot_compilation` check: Path A uses `aot_export_module` (AOTInductor, no cache) and Path B uses `aot_autograd` → `aot_module_simplified` (with `AOTAutogradCache`). Since Path A has no cache, run pre-grad passes explicitly before `aot_export_module`.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela