[precompile] Integrate AOTI as a backend. #167338

Closed
zhxchen17 wants to merge 1 commit into main from zhxchen17/precompile/aoti

@zhxchen17 zhxchen17 commented Nov 7, 2025

@zhxchen17 zhxchen17 requested a review from bdhirsh as a code owner November 7, 2025 18:06

pytorch-bot bot commented Nov 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167338

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 62bea25 with merge base d8384e2:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

with torch.device("cuda"):
from torch._dynamo.hooks import Hooks

mod = SimpleLinearModule()
Contributor

Instead of testing it this way, let's test it with the actual aot_compile_module, passing in the AOTI backend with ModelInput.

Contributor

@jamesjwu jamesjwu left a comment

Let's get tests passing and add some tests around aot_compile_module. Otherwise, this looks good; hopefully we can land it quickly.


fake_mode = detect_fake_mode(inputs_)
ctx = (
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
Contributor

Why do we need to do this? Also, isn't it equivalent to fake_mode.allow_non_fake_inputs?

Contributor Author

I think AOTI normally assumes this flag has already been set by an upper layer of the call stack, but in our case we are a new top-level function, so we need to set it ourselves.
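
To illustrate the exchange above: a minimal sketch of the difference between assigning the flag directly and patching it with mock.patch.object, which restores the previous value when the block exits. FakeModeStandIn and flag_context are hypothetical names used only for illustration; they are not the real FakeTensorMode or the PR's code, and the None fallback is an assumption.

```python
import contextlib
from unittest import mock


class FakeModeStandIn:
    """Hypothetical stand-in for a fake mode object (not the real FakeTensorMode)."""

    def __init__(self) -> None:
        self.allow_non_fake_inputs = False


def flag_context(fake_mode):
    """Return a context manager that enables the flag only inside the block.

    Assumption: when no fake mode is detected, nothing needs patching, so a
    no-op context manager is returned instead.
    """
    if fake_mode is None:
        return contextlib.nullcontext()
    return mock.patch.object(fake_mode, "allow_non_fake_inputs", True)


fake_mode = FakeModeStandIn()

with flag_context(fake_mode):
    # Inside the block the flag is temporarily True.
    inside = fake_mode.allow_non_fake_inputs

# After the block the original value (False) is restored automatically,
# which a plain assignment would not do.
after = fake_mode.allow_non_fake_inputs
```

So the two are not quite equivalent: a direct assignment would leave the flag set for every later caller, while the patch scopes it to the new top-level entry point.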

reset_cudagraph_trees()


class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
Contributor

Let's say I'm a random torch.compile user: how do I use this? Do I have to pass this specific backend to fullgraph_compile? Should we make it lower-friction somehow?

I.e. backend="inductor" + some config = AOTI
backend="inductor" without config = Python Wrapper

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/aoti branch from 13b69e4 to cc82255 November 10, 2025 21:03
@zhxchen17
Contributor Author

Updated with:

  • unittest fixes
  • aot_compile_module() test
  • A top-level option to use AOTI: torch.compile(options={'use_aoti': True})
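
The last item addresses the earlier friction question (backend="inductor" plus some config selects AOTI, without config selects the Python wrapper). A rough sketch of that dispatch idea follows. The stand-in classes and select_wrapper are hypothetical and only mirror the shape of the discussion; they are not the PR's implementation, and popping use_aoti before forwarding the remaining options is an assumption.

```python
from typing import Any, Optional


class InductorWrapperStandIn:
    """Hypothetical stand-in for the default wrapper (Python wrapper path)."""

    kind = "python_wrapper"

    def __init__(self, options: dict) -> None:
        self.options = options


class AOTInductorWrapperStandIn(InductorWrapperStandIn):
    """Hypothetical stand-in for the AOTI wrapper, subclassing the default one."""

    kind = "aoti"


def select_wrapper(backend: str, options: Optional[dict] = None) -> Any:
    """Pick a wrapper from the backend name plus an optional 'use_aoti' flag."""
    if backend != "inductor":
        raise ValueError(f"this sketch only handles backend='inductor', got {backend!r}")
    remaining = dict(options or {})  # copy so the caller's dict is not mutated
    use_aoti = bool(remaining.pop("use_aoti", False))
    wrapper_cls = AOTInductorWrapperStandIn if use_aoti else InductorWrapperStandIn
    return wrapper_cls(remaining)


default_wrapper = select_wrapper("inductor")
aoti_wrapper = select_wrapper("inductor", {"use_aoti": True, "other": 1})
```

With this shape a user keeps backend="inductor" and opts in through a single option, rather than naming a separate backend.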

@zhxchen17 zhxchen17 requested a review from jamesjwu November 10, 2025 21:36
@mlazos mlazos self-requested a review November 11, 2025 10:00
@zhxchen17
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Nov 12, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15)

Details for Dev Infra team Raised by workflow job

if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable, 1, self.device_type
Contributor

One perf trick here is to set run_single_threaded to True; otherwise it won't compose with CUDA graphs. See #148601 for more background.

@zhxchen17
Contributor Author

@pytorchbot merge

@zhxchen17
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

@jeanschmidt
Contributor

@pytorchbot revert -m "seems to be breaking internal tests and builds, see D86919103" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Nov 13, 2025
This reverts commit 273babe.

Reverted #167338 on behalf of https://github.com/jeanschmidt due to "seems to be breaking internal tests and builds, see D86919103" (comment)
@pytorchmergebot
Collaborator

@zhxchen17 your PR has been successfully reverted.

@zhxchen17 zhxchen17 force-pushed the zhxchen17/precompile/aoti branch from e9cc1a8 to 62bea25 November 13, 2025 20:46
@zhxchen17
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
This reverts commit 273babe.

Reverted pytorch#167338 on behalf of https://github.com/jeanschmidt due to "seems to be breaking internal tests and builds, see D86919103" (comment)
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
@github-actions github-actions bot deleted the zhxchen17/precompile/aoti branch December 15, 2025 02:21
5 participants