Skip to content

[invoke_subgraph][inductor] Thread graphsafe rng input states for hops#160713

Closed
anijain2305 wants to merge 3 commits intogh/anijain2305/849/basefrom
gh/anijain2305/849/head
Closed

[invoke_subgraph][inductor] Thread graphsafe rng input states for hops#160713
anijain2305 wants to merge 3 commits intogh/anijain2305/849/basefrom
gh/anijain2305/849/head

Conversation

[ghstack-poisoned]
@anijain2305 anijain2305 requested a review from zou3519 as a code owner August 15, 2025 06:15
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160713

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7379228 with merge base fa75ba9 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Aug 15, 2025
@anijain2305 anijain2305 added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 15, 2025
@anijain2305 anijain2305 requested a review from eellison August 15, 2025 19:03
[ghstack-poisoned]
Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one question about recursion

Comment on lines +2619 to +2626
There is a catch: for a short period, the joint graph is in a “bad” state.
The HOP subgraphs expect additional inputs (because of the new
placeholders), but the outer graph call sites don't yet provide them. We
can't fix this in the joint graph because the joint graph's input signature
is fixed (primals, tangents). As a compromise, we keep the joint graph in
somewhat of a bad state for some time and, once the outer forward and
backward graphs are partitioned, insert the corresponding RNG placeholders
and wire up the calls.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be clearer to have a temporary node that represents the yet to be added placeholder node ? Then both the non-hop and hop could have a pass to lift them to placeholders ?

Copy link
Contributor Author

@anijain2305 anijain2305 Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about that a little more, but I could not figure out how to add that temporary node. It can not be a placholder on the joint graph because the signature is primals, tangents where both of them are lists. And other parts of the stack (mostly partitioner) assumes this signature. If we make it (primals, tangents, fwd_rng_state, bw_rng_state), we will have to make changes at many many places. At that point, it might end up being more hacky.

"""

rng_count = 0
rng_string = "bwd_rng_state" if is_backward else "fwd_rng_state"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect this function to be recursive, for when we have

module:
hop:
hop (rng)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will work because the run_joint_graph_passes_on_hops runs recursively. So there will be a sequence of partition_hop_level2_joint -> partition_hop_level1_joint -> partition_main_joint.

Overall, we have not very thoroughly tested nested hops. But I can try to add a few more tests in the followup PR. Some ideas in my mind are AC + invoke_subgraph.

@anijain2305
Copy link
Contributor Author

anijain2305 commented Aug 21, 2025

Talked offline with @eellison - he suggested to insert a custom op in the joint graph that is later lifted as input during the partitioning. This should keep the joint graph in somewhat reasonable state. This can be done in a follow up PR.

@anijain2305
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
@github-actions github-actions bot deleted the gh/anijain2305/849/head branch September 21, 2025 02:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants