Fix nested stableHLO composite regions#9385
Merged
qihqi merged 5 commits intopytorch:masterfrom Jul 20, 2025
Merged
Conversation
Renamed _impl in test, so that first impl is 'impl', first is 'impl_0' and so on
Collaborator
|
Thank you for the fix! Pending on CI |
lsy323
approved these changes
Jul 1, 2025
Contributor
Author
|
Hi, happy this was approved! Apparently the last check is failing since "Secret TORCH_XLA_BOT_TOKEN is required, but not provided while calling." . I'm guessing that isn't something that is wrong from my end? If I need to change anything let me know! |
Contributor
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fix export failure when StableHLO regions are nested (e.g. SubModule inside Model), follow up on issue 6978
This PR aims to add support for nested stable HLO regions. Currently, trying to export something along these lines:
raises an error. This PR makes it so that this generates correct MLIR:
Cause
Solution