[aotd] Support mutations in reordering_to_mimic_autograd_engine #155353

Closed
IvanKobzarev wants to merge 2 commits into gh/IvanKobzarev/113/base from gh/IvanKobzarev/113/head
@IvanKobzarev IvanKobzarev commented Jun 6, 2025

Stack from ghstack (oldest at bottom):

Original issue: #154820

Dedicated sub-issue: #155242

The backward graph is reordered by reordering_to_mimic_autograd_engine in partitioners.py, which only records in the reordered backward graph the compute that starts from tangents.

A mutation of primals (inputs) in the backward can be disconnected from that tangent-driven compute.

The copy_ that implements this mutation is handled specifically, because the framework itself adds it and it is the only mutation that exists.
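The situation described above can be illustrated with a toy model. This is only a sketch in plain Python, not the code in partitioners.py: the Node tuple, the reorder_from_tangents function, and the keep_primal_copies flag are all hypothetical names, and the graph is a made-up example in which a copy_ into primals_1 has no dependency on any tangent.

```python
from collections import namedtuple

# Hypothetical stand-in for a graph node: a name, an op kind, and input names.
Node = namedtuple("Node", ["name", "op", "inputs"])

# Toy backward graph in topological order. The copy_ into primals_1 depends
# only on primals_1, so it is disconnected from tangents_1.
GRAPH = [
    Node("primals_1", "placeholder", ()),
    Node("tangents_1", "placeholder", ()),
    Node("mul", "mul", ("tangents_1",)),
    Node("add", "add", ("primals_1",)),
    Node("copy_", "copy_", ("primals_1", "add")),
    Node("output", "output", ("mul",)),
]


def reorder_from_tangents(nodes, keep_primal_copies=False):
    """Return node names kept by a toy tangent-driven reordering pass."""
    by_name = {n.name: n for n in nodes}
    kept, kept_set = [], set()

    def emit(name):
        # Emit a node after all of its dependencies, each at most once.
        if name in kept_set:
            return
        for dep in by_name[name].inputs:
            emit(dep)
        kept_set.add(name)
        kept.append(name)

    # Placeholders are always carried over.
    for n in nodes:
        if n.op == "placeholder":
            emit(n.name)

    # Only compute reachable from tangents is recorded by default.
    reached = {n.name for n in nodes
               if n.op == "placeholder" and n.name.startswith("tangents")}
    for n in nodes:
        if n.op == "placeholder":
            continue
        if any(i in reached for i in n.inputs):
            reached.add(n.name)
            emit(n.name)
        elif (keep_primal_copies and n.op == "copy_"
              and n.inputs[0].startswith("primals")):
            # Special case: keep a copy_ into a primal and what it needs.
            emit(n.name)
    return kept


without_fix = reorder_from_tangents(GRAPH)
with_fix = reorder_from_tangents(GRAPH, keep_primal_copies=True)
```

In this toy graph the default pass drops both add and copy_, while the special case keeps them ahead of output.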

pytorch-bot bot commented Jun 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155353

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 8ff6544 with merge base ea5b9ec (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

torch.compile(fn, backend="aot_eager", fullgraph=True)(
    dummy, inplace
).sum().backward()
self.assertEqual(ref, inplace)
Can we have the test assert that the inputs are correct at both points:

(1) after running the compiled (and reference) forwards, but before running the backward

(2) after running the backward?

That should help ensure the test actually confirms that we are not, e.g., also moving the backward mutation into the forward graph.
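A minimal sketch of the two-point check requested above, under stated assumptions: MutateInBackward, fn, run_and_snapshot, and assert_same_input_states are hypothetical names, and the custom autograd function is only a stand-in for whatever the PR's test actually uses. The sketch snapshots the mutated input once after the forward and once after the backward, so a mutation that moved into the forward would show up at the first snapshot.

```python
import torch


class MutateInBackward(torch.autograd.Function):
    """Hypothetical op whose backward mutates a forward input in place."""

    @staticmethod
    def forward(ctx, dummy, inplace):
        ctx.save_for_backward(inplace)
        return dummy.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (inplace,) = ctx.saved_tensors
        inplace.add_(1)  # the input mutation happens only in backward
        return grad_output, None


def fn(dummy, inplace):
    return MutateInBackward.apply(dummy, inplace)


def run_and_snapshot(f):
    """Run forward then backward, snapshotting the mutated input at both points."""
    dummy = torch.ones(3, requires_grad=True)
    inplace = torch.zeros(3)
    out = f(dummy, inplace)
    after_forward = inplace.clone()   # point (1): forward done, backward not run
    out.sum().backward()
    after_backward = inplace.clone()  # point (2): backward done
    return after_forward, after_backward


def assert_same_input_states(reference, candidate):
    """Check that candidate leaves the input in the same state as reference."""
    ref_fwd, ref_bwd = run_and_snapshot(reference)
    got_fwd, got_bwd = run_and_snapshot(candidate)
    assert torch.equal(ref_fwd, got_fwd), "input differs after forward"
    assert torch.equal(ref_bwd, got_bwd), "input differs after backward"
```

To mirror the snippet under review, the candidate would be torch.compile(fn, backend="aot_eager", fullgraph=True) and the reference would be fn run eagerly.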

Collaborator

@bdhirsh bdhirsh left a comment

lgtm minus test nit!

@IvanKobzarev
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jun 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4)

@IvanKobzarev
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Labels

ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), Merged, topic: not user facing (topic category)

4 participants