[invoke_subgraph][inductor] Thread graphsafe rng input states for hops #160713

anijain2305 wants to merge 3 commits into gh/anijain2305/849/base
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160713

✅ No Failures as of commit 7379228 with merge base fa75ba9. Note: links to docs will display an error until the docs builds have completed. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
eellison left a comment: one question about recursion
> There is a catch: for a short period, the joint graph is in a "bad" state. The HOP subgraphs expect additional inputs (because of the new placeholders), but the outer graph call sites don't yet provide them. We can't fix this in the joint graph because the joint graph's input signature is fixed (primals, tangents). As a compromise, we keep the joint graph in somewhat of a bad state for some time and, once the outer forward and backward graphs are partitioned, insert the corresponding RNG placeholders and wire up the calls.
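The two-phase fix described above can be sketched in plain Python, with a toy graph model standing in for the real FX graphs. Everything here (`Graph`, `CallNode`, `insert_rng_placeholders`, the `hop_rng_needs` map) is illustrative, not the actual PyTorch internals: after partitioning, each HOP subgraph already expects an extra RNG input, so the pass appends a matching placeholder to the outer graph and threads it into each call site.

```python
# Toy model of "insert the corresponding RNG placeholders and wire up the
# calls" after partitioning. Names are hypothetical, not PyTorch internals.
from dataclasses import dataclass, field


@dataclass
class CallNode:
    target: str                          # name of the HOP subgraph being called
    args: list = field(default_factory=list)


@dataclass
class Graph:
    placeholders: list = field(default_factory=list)
    calls: list = field(default_factory=list)


def insert_rng_placeholders(graph: Graph, hop_rng_needs: dict, is_backward: bool) -> Graph:
    """Append one RNG placeholder per RNG input a HOP subgraph expects,
    and pass it at the corresponding call site."""
    prefix = "bwd_rng_state" if is_backward else "fwd_rng_state"
    rng_count = 0
    for call in graph.calls:
        for _ in range(hop_rng_needs.get(call.target, 0)):
            name = f"{prefix}_{rng_count}"
            graph.placeholders.append(name)  # new outer-graph input
            call.args.append(name)           # call site now provides it
            rng_count += 1
    return graph


fwd = Graph(placeholders=["primals_1"],
            calls=[CallNode("subgraph_0", ["primals_1"])])
insert_rng_placeholders(fwd, {"subgraph_0": 1}, is_backward=False)
# fwd.placeholders -> ["primals_1", "fwd_rng_state_0"]
# fwd.calls[0].args -> ["primals_1", "fwd_rng_state_0"]
```

This runs only on the partitioned forward/backward graphs, which is why the joint graph can be left untouched: its (primals, tangents) signature never changes.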
Would it be clearer to have a temporary node that represents the yet-to-be-added placeholder node? Then both the non-HOP and HOP graphs could have a pass to lift them to placeholders.
I thought about that a little more, but I could not figure out how to add that temporary node. It cannot be a placeholder on the joint graph, because the signature is (primals, tangents), where both are lists, and other parts of the stack (mostly the partitioner) assume this signature. If we make it (primals, tangents, fwd_rng_state, bwd_rng_state), we would have to make changes in many places; at that point, it might end up being more hacky.
```python
rng_count = 0
rng_string = "bwd_rng_state" if is_backward else "fwd_rng_state"
```
I would expect this function to be recursive, for when we have:

```
module:
  hop:
    hop (rng)
```
This will work because run_joint_graph_passes_on_hops runs recursively, so the partitioning happens in the sequence partition_hop_level2_joint -> partition_hop_level1_joint -> partition_main_joint.

Overall, we have not thoroughly tested nested hops, but I can add a few more tests in a follow-up PR. Some ideas in my mind are AC + invoke_subgraph.
Talked offline with @eellison - he suggested inserting a custom op in the joint graph that is later lifted as an input during partitioning. This should keep the joint graph in a more reasonable state. This can be done in a follow-up PR.
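One way to picture that follow-up idea: the joint graph carries a sentinel op wherever an RNG state is needed, instead of a dangling extra input, and the partitioning pass replaces each sentinel with a real placeholder. A toy sketch under those assumptions (the node encoding and `lift_rng_sentinels` are invented for illustration):

```python
# Toy sketch of "insert a custom op that is later lifted as an input".
# Nodes are (op, name) pairs; a "rng_sentinel" node marks where an RNG
# state will be needed. Lifting turns each sentinel into a placeholder,
# so the graph stays well-formed in between. Names are hypothetical.

def lift_rng_sentinels(nodes: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Replace each ("rng_sentinel", name) node with a ("placeholder", name)
    input at the front of the graph, preserving the order of everything else."""
    placeholders = [n for n in nodes if n[0] == "placeholder"]
    body = []
    for op, name in nodes:
        if op == "rng_sentinel":
            placeholders.append(("placeholder", name))
        elif op != "placeholder":
            body.append((op, name))
    return placeholders + body


joint = [
    ("placeholder", "primals_1"),
    ("rng_sentinel", "fwd_rng_state_0"),   # stand-in until partitioning
    ("call", "subgraph_0"),
]
# lift_rng_sentinels(joint) ->
# [("placeholder", "primals_1"), ("placeholder", "fwd_rng_state_0"), ("call", "subgraph_0")]
```

The attraction over the current compromise is that the joint graph is never in a state where a HOP expects an input no call site provides; the sentinel documents the pending input explicitly.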
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#160713. Approved by: https://github.com/eellison
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben