[Inductor] Lower fallback nodes annotated with "should_fallback" #166339
chenmillie wants to merge 1 commit into pytorch:main
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/166339
✅ No failures as of commit 76cfd5d with merge base 5e7272b. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
@chenmillie has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85347587.
Maybe compose with regional inductor compile (#164776)? cc @anijain2305
Could you try `make_fallback`? (Example suggested by @angelayi.)
Force-pushed from 4cedc02 to 0aa2296 — Compare
[Inductor] Lower fallback nodes annotated with "should_fallback" (#166339)

Summary: This PR introduces an inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. It has a similar motivation to #164776 in providing flexibility to selectively disable Inductor lowering for specific nodes.

The implementation simply adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If this is set to `True`, the lowerer falls back before attempting the normal lowering path. Note that since these are user-directed fallbacks dependent on specific, customized conditions, use `add_to_fallback_set=False` to avoid permanently overwriting Inductor's lowering/fallback rules.

Simple example marking nodes for fallback based on custom predicates:

```python
def should_fallback_predicate(node: torch.fx.Node, predicate: Callable[[torch.fx.Node], bool]):
    # Apply the predicate and mark the node for fallback if needed
    if predicate(node):
        node.meta["should_fallback"] = True
```

Test Plan: added a CI test

Differential Revision: D85347587
@BoyuanFeng @angelayi Thanks for taking a look. I did consider …
Force-pushed from 0aa2296 to 9789e9f — Compare
Force-pushed from 9789e9f to 7d12a00 — Compare
LGTM! This is a very important feature for MTIA, where we want to customize the fallback logic based on node arguments and not only the node target. Past experience has shown FX passes to be a very convenient way to control fallbacks. Though I'll defer to @BoyuanFeng @eellison or @ezyang to comment on the Inductor implementation.
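The point about customizing fallback based on node arguments rather than only the node target can be illustrated with a small predicate. This is a hypothetical sketch (the names and threshold are illustrative, not MTIA's actual logic); shapes are passed as plain tuples here, whereas on a real `torch.fx.Node` they would typically be read from the node's metadata:

```python
# Hypothetical argument-based fallback predicate: fall back only for
# matmuls whose operands are large, not for every occurrence of the target.
def should_fall_back(target: str, arg_shapes: list) -> bool:
    is_large = any(dim > 4096 for shape in arg_shapes for dim in shape)
    return target == "aten.mm" and is_large

print(should_fall_back("aten.mm", [(8192, 128), (128, 64)]))  # large mm  -> True
print(should_fall_back("aten.mm", [(32, 32), (32, 32)]))      # small mm  -> False
print(should_fall_back("aten.add", [(8192, 128)]))            # not a mm  -> False
```

A target-only rule (e.g. `make_fallback`) cannot express this distinction, which is why a per-node annotation set by an FX pass is convenient here.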
Reviewed By: eellison, blaine-rister
Force-pushed from 7d12a00 to 76cfd5d — Compare
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #166339. Approved by: https://github.com/blaine-rister, https://github.com/eellison
Summary:
This PR introduces an inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. It has a similar motivation to #164776 in providing flexibility to selectively disable Inductor lowering for specific nodes.
The implementation simply adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If this is set to `True`, the lowerer falls back before attempting the normal lowering path. Note that since these are user-directed fallbacks dependent on specific, customized conditions, use `add_to_fallback_set=False` to avoid permanently overwriting Inductor's lowering/fallback rules. A simple example marking nodes for fallback based on custom predicates appears in the commit summary above.
Test Plan: added a CI test
Differential Revision: D85347587
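The annotation flow can be sketched end to end. This is a minimal, hypothetical sketch that substitutes a stand-in `Node` class for `torch.fx.Node` so it runs without torch; with a real FX graph, the same loop would iterate `graph.nodes` and the lowerer would consult `node.meta["should_fallback"]` as described above:

```python
from dataclasses import dataclass, field
from typing import Any, Callable

@dataclass
class Node:
    """Minimal stand-in for torch.fx.Node (only target and meta are modeled)."""
    target: Any
    meta: dict = field(default_factory=dict)

def mark_fallbacks(nodes: list, predicate: Callable) -> list:
    # Annotate every node the predicate selects; per this PR, Inductor's
    # lowerer falls back on annotated nodes before attempting normal lowering.
    for node in nodes:
        if predicate(node):
            node.meta["should_fallback"] = True
    return nodes

nodes = [Node("aten.mm"), Node("aten.relu"), Node("aten.mm")]
mark_fallbacks(nodes, lambda n: n.target == "aten.mm")
print([n.meta.get("should_fallback", False) for n in nodes])  # [True, False, True]
```

Because the annotation is ordinary node metadata, any FX pass can set it from arbitrary conditions (target, argument shapes, dtypes), which is the flexibility the PR aims for.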
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben