
[Inductor] Lower fallback nodes annotated with "should_fallback" #166339

Closed
chenmillie wants to merge 1 commit into pytorch:main from chenmillie:export-D85347587

Conversation

@chenmillie (Contributor) commented Oct 27, 2025

Summary:
This PR introduces an Inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. It has a similar motivation to #164776: providing the flexibility to selectively disable Inductor lowering for specific nodes.

The implementation adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If the annotation is set to `True`, the lowerer falls back before attempting the normal lowering path. Because these are user-directed fallbacks that depend on specific, customized conditions, the fallback is registered with `add_to_fallback_set=False` so that Inductor's global lowering/fallback rules are not permanently overwritten.

A simple example that marks a node for fallback based on a custom predicate:

```
from typing import Callable

import torch.fx

def should_fallback_predicate(node: torch.fx.Node, predicate: Callable[[torch.fx.Node], bool]) -> None:
    # Apply the predicate and mark the node for fallback if it matches
    if predicate(node):
        node.meta["should_fallback"] = True
```
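As an illustration of how such an annotation pass could be wired up over a captured graph, here is a minimal sketch; the `mark_fallback_nodes` helper and the `aten.mm` predicate are hypothetical, not part of this PR:

```
from typing import Callable

import torch
import torch.fx

def mark_fallback_nodes(gm: torch.fx.GraphModule, predicate: Callable[[torch.fx.Node], bool]) -> torch.fx.GraphModule:
    # Walk the captured graph and annotate matching call_function nodes so
    # Inductor falls back to the preexisting kernel instead of lowering them.
    for node in gm.graph.nodes:
        if node.op == "call_function" and predicate(node):
            node.meta["should_fallback"] = True
    return gm

# Example usage: fall back for every aten.mm node in the graph.
# mark_fallback_nodes(gm, lambda n: n.target is torch.ops.aten.mm.default)
```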

Test Plan: added a CI test

Differential Revision: D85347587

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot bot commented Oct 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166339

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 76cfd5d with merge base 5e7272b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-codesync bot commented Oct 27, 2025

@chenmillie has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85347587.

@BoyuanFeng (Contributor)

maybe compose with regional inductor compile #164776

cc @anijain2305

@BoyuanFeng (Contributor) commented Oct 27, 2025

could you try make_fallback?

Example: make_fallback(aten.randint)

Suggested by @angelayi
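For reference, a rough sketch of the op-level registration being suggested, assuming `make_fallback` is imported from `torch._inductor.lowering`:

```
import torch
from torch._inductor.lowering import make_fallback

# Register an op-level fallback: every aten.randint in every compiled graph
# uses the preexisting kernel instead of an Inductor lowering.
make_fallback(torch.ops.aten.randint)
```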

pytorch-bot bot pushed a commit that referenced this pull request Oct 28, 2025
@chenmillie (Contributor, Author) commented Oct 28, 2025

@BoyuanFeng @angelayi Thanks for taking a look. I did consider make_fallback at first, but it registers fallbacks at the op level by default, whereas we want fallback behavior at the node level, where the decision can depend on node-specific attributes. For that reason make_fallback may not be suitable for this feature.
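For example, a node-level decision can look at a node's arguments or traced values, which an op-level registration cannot express. A hypothetical predicate along these lines; the size threshold and the reliance on `node.meta["val"]` holding a traced FakeTensor are assumptions for illustration:

```
import torch
import torch.fx

def fallback_small_mm(node: torch.fx.Node) -> bool:
    # Fall back only for aten.mm nodes whose traced output is small;
    # larger matmuls of the same op are still lowered by Inductor.
    if node.target is not torch.ops.aten.mm.default:
        return False
    val = node.meta.get("val")
    return isinstance(val, torch.Tensor) and val.numel() < 4096
```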

chenmillie added a commit to chenmillie/pytorch that referenced this pull request Oct 28, 2025
chenmillie added a commit to chenmillie/pytorch that referenced this pull request Oct 28, 2025
@blaine-rister (Contributor) left a comment


LGTM! This is a very important feature for MTIA, where we want to customize the fallback logic based on node arguments and not only the node target. Past experience has shown FX passes to be a very convenient way to control fallbacks. Though I'll defer to @BoyuanFeng @eellison or @ezyang to comment on the Inductor implementation.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 29, 2025
@facebook-github-bot (Contributor)

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025

Pull Request resolved: #166339
Approved by: https://github.com/blaine-rister, https://github.com/eellison