Add pre_grad_pass_timing config for early vs late pre-grad passes#177429
Add pre_grad_pass_timing config for early vs late pre-grad passes#177429frgossen wants to merge 12 commits intogh/frgossen/12/basefrom
Conversation
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177429
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 PendingAs of commit 67ad41f with merge base 6a461fe ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: 8690d1e
Pull Request resolved: #177429
aorenste
left a comment
There was a problem hiding this comment.
Is it missing a test for caching "early" + UUID?
Maybe minor/not worth changing: After this change, won't an "early" pass's UUID will be part of the cache key? That means that if a pass ends up being a no-op it can't share a cache entry with a run without that pass. But not including the UUID is actually safe because the cache key is based on the pre-cache graph structure.
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: 0b0e8fc
Pull Request resolved: #177429
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: f257cef
Pull Request resolved: #177429
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: 21013bf
Pull Request resolved: #177429
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: bea0a46
Pull Request resolved: #177429
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: 3705010
Pull Request resolved: #177429
Good point. Put the pre-grad cache key contribution behind the timing config. |
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
ghstack-source-id: 8519814
Pull Request resolved: #177429
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
|
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
This PR needs a
|
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
zou3519
left a comment
There was a problem hiding this comment.
lgtm but please read comments
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
… passes"
Allow pre-grad passes to run before the AOT autograd cache lookup
("early") instead of only after it on cache miss ("late", the default).
With "early" timing the cache key reflects the already-transformed graph
and passes always execute, even on cache hits.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo Lucaskabela
[ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…torch#177429) Add a pre_grad_pass_timing config ("early", "late", or "default") that controls when pre-grad passes run relative to the AOT autograd cache lookup. - "early": passes run before cache lookup, so they execute on every compile (including cache hits) and the cache key reflects the already-transformed graph. - "late": passes run after cache lookup (only on cache miss); requires custom passes to provide a UUID for the cache key. - "default": automatically resolves to "late" when possible (no custom pass, or a custom pass with a UUID), and falls back to "early" when the custom pass has no UUID. Explicitly setting "late" with a UUID-less custom pass now raises a RuntimeError instead of silently bypassing the cache. The existing test_pre_grad_passes_called_on_cache_miss_only test is renamed and pinned to "late" timing, and new tests cover early timing, both default timing branches, and the error case. Pull Request resolved: pytorch#177429 Approved by: https://github.com/aorenste, https://github.com/zou3519 ghstack dependencies: pytorch#177397, pytorch#177403, pytorch#177428
Stack from ghstack (oldest at bottom):
Add a pre_grad_pass_timing config ("early", "late", or "default") that
controls when pre-grad passes run relative to the AOT autograd cache lookup.
(including cache hits) and the cache key reflects the already-transformed
graph.
custom passes to provide a UUID for the cache key.
or a custom pass with a UUID), and falls back to "early" when the custom
pass has no UUID.
Explicitly setting "late" with a UUID-less custom pass now raises a
RuntimeError instead of silently bypassing the cache. The existing
test_pre_grad_passes_called_on_cache_miss_only test is renamed and
pinned to "late" timing, and new tests cover early timing, both default
timing branches, and the error case.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela