[while_loop] support closures #123018
Conversation
```python
assert (
    len(additional_inputs) == 0
), "Additional inputs are set automatically by dynamo"
```
Reviewer: If additional inputs are set automatically by dynamo, what is the reason for passing an empty tuple here? Shouldn't the number of arguments match?

Author: Yeah, its schema has changed to `torch.ops.higher_order.while_loop(cond, body, carried_inputs, additional_inputs)`. When dynamo is compiling this higher-order op, `additional_inputs` is expected to be a tuple. I've updated the PR to incorporate `additional_inputs` when it has non-zero length, e.g. when running the tests with `PYTORCH_TEST_WITH_DYNAMO=1` (i.e. `torch.compile` of a function containing `while_loop`).
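For intuition, here is a pure-Python sketch of the semantics the new schema implies (an illustration of the calling convention, not the actual dispatch path):

```python
def while_loop_reference(cond_fn, body_fn, carried_inputs, additional_inputs):
    # Both cond_fn and body_fn see the carried and additional inputs, but
    # only the carried values are rebound each iteration; additional_inputs
    # (lifted closures, parameters, buffers) stay fixed.
    carried = tuple(carried_inputs)
    while cond_fn(*carried, *additional_inputs):
        carried = body_fn(*carried, *additional_inputs)
    return carried
```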
| "body_fn", | ||
| ) | ||
|
|
||
| additional_lifted_inputs = tuple(cond_shared + cond_unique + body_unique) |
Reviewer: Curious, why do we include cond_shared in the lifted inputs here?

Author: cond_shared and body_shared refer to the same proxy in the parent graph, so using either of them is OK. I can add a comment here.
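Something along these lines would capture that (the wording is a suggestion, names taken from the snippet above):

```python
# cond_shared and body_shared refer to the same proxies in the parent
# graph, so lifting either set once is sufficient; we arbitrarily use
# cond_shared here.
additional_lifted_inputs = tuple(cond_shared + cond_unique + body_unique)
```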
```python
# previously: return super().__call__(cond_fn, body_fn, operands)
if not isinstance(additional_inputs, tuple):
    raise RuntimeError(
        "additional_inputs must be a tuple, got " f"{type(additional_inputs)}"
    )
```
Reviewer: UX question: if additional_inputs are generated internally (and are not considered part of the public API), would the errors regarding additional_inputs make sense to the user?

Author: Yeah, you're right. This should probably just be an assertion.
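A minimal sketch of the assertion form, assuming the check stays at the same spot in `__call__`:

```python
# additional_inputs is constructed by dynamo, not supplied by users, so a
# failure here indicates a framework bug rather than a user error.
assert isinstance(
    additional_inputs, tuple
), f"additional_inputs must be a tuple, got {type(additional_inputs)}"
```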
```python
# previously: "operands must be a tuple of tensors, ints, floats, or bools, got {operands}"
raise RuntimeError(
    "carried_inputs must be a tuple, got " f"{type(carried_inputs)}"
)
```
Reviewer: Nit: why a separate f-string here? And below.

Author: Nice catch. I don't remember why I wrote it this way lol.
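For reference, the merged single-f-string form the nit points at:

```python
raise RuntimeError(f"carried_inputs must be a tuple, got {type(carried_inputs)}")
```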
```python
body_outer_inputs
)  # carry over the state from body_fn
```
Reviewer: Could we make this comment a bit more specific, mentioning that we only carry over the carried_inputs part of the inputs, but not the additional ones?
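A possible wording for the more specific comment (phrasing is a suggestion):

```python
# Carry over only the carried_inputs portion of body_fn's outputs into the
# next iteration; additional_inputs (lifted closures, parameters, buffers)
# are not produced by body_fn and stay fixed across iterations.
```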
```python
assert (
    len(carried_inputs) > 0  # previously: len(operands) > 0
)
```
Reviewer: Actually, we can fetch the device from additional_inputs, too. So maybe just check that we have at least one input of either kind?
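That relaxed check might look like this sketch:

```python
# Either list can supply the device, so only require at least one input.
assert (
    len(carried_inputs) + len(additional_inputs) > 0
), "while_loop requires at least one carried or additional input"
```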
```python
fx_carried_inputs = V.graph.current_node.args[-2]
fx_additional_inputs = V.graph.current_node.args[-1]
fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs]  # type: ignore[union-attr]
fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs]  # type: ignore[union-attr]
```
Reviewer: As these are used identically below for carried and additional inputs, would it make sense to combine them upstream into something like all_inputs, build the fx and fake lists once, and use that single list below?
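A sketch of the suggested combination, assuming both node args are sequences of fx nodes as in the snippet above:

```python
fx_all_inputs = [
    *V.graph.current_node.args[-2],
    *V.graph.current_node.args[-1],
]
fake_all_inputs = [x.meta["val"] for x in fx_all_inputs]  # type: ignore[union-attr]
```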
```python
# previously: def while_loop(cond_fn, body_fn, operands):
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
    if any(map(is_triton, carried_inputs)) or any(map(is_triton, additional_inputs)):
```
Reviewer: Nit: combine carried_inputs + additional_inputs instead?
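Assuming both are tuples, the combined form would read:

```python
if any(map(is_triton, carried_inputs + additional_inputs)):
    ...  # triton-specific handling
```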
```
@@ -0,0 +1,234 @@
import logging
```
Reviewer: Seems this file inadvertently made its way into this PR? :)
```python
class WhileLoopOp(HigherOrderOperator):
    # previously: def __call__(self, cond_fn, body_fn, operands):
    def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs):
```
Reviewer: What are the chances you can make this positional-only lol. e.g.:

```python
def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs, /)
```
Author: Yeah, I think we can do that.

I also just found that previously we were not using this WhileLoopOp but the plain HigherOrderOperator; I've updated the code to use WhileLoopOp for checking the inputs. Will add more tests for invalid inputs as a follow-up.
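A minimal sketch of what the positional-only `__call__` with the input checks could look like (an approximation, not the exact upstream implementation):

```python
from torch._ops import HigherOrderOperator


class WhileLoopOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("while_loop")

    def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs, /):
        # Positional-only ("/") prevents keyword binding, so the parameter
        # names can change later without breaking call sites.
        if not isinstance(carried_inputs, tuple):
            raise RuntimeError(
                f"carried_inputs must be a tuple, got {type(carried_inputs)}"
            )
        if not isinstance(additional_inputs, tuple):
            raise RuntimeError(
                f"additional_inputs must be a tuple, got {type(additional_inputs)}"
            )
        return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
```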
```python
class WhileLoopOp(HigherOrderOperator):
    # previously: def __call__(self, cond_fn, body_fn, operands):
    def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs):
```
Reviewer: Can you add type annotations for these?
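A hedged sketch of possible annotations, using the element types named in the error messages above (tensors, ints, floats, bools):

```python
from typing import Callable, Tuple, Union

import torch

InputT = Union[torch.Tensor, int, float, bool]


def __call__(
    self,
    cond_fn: Callable,
    body_fn: Callable,
    carried_inputs: Tuple[InputT, ...],
    additional_inputs: Tuple[InputT, ...],
    /,
) -> Tuple[InputT, ...]:
    ...
```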
@ydwu4 thanks for addressing the comments. Lots of tests are failing, could you have a look?
Need to wait for the change to xla's while_loop XLA-key implementation in pytorch/xla#6872 to fix the xla test failures.
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. Raised by workflow job (details for the Dev Infra team).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
#123018 introduces a necessary BC-breaking change and sees a bunch of xla test failures on CI. We made a PR to pytorch/xla to prepare for the breaking change (pytorch/xla#6872), and this PR updates the pytorch/xla pin to pick up that change. Pull Request resolved: #123217. Approved by: https://github.com/clee2000
Pull Request resolved: pytorch#123018. Approved by: https://github.com/aakhundov. ghstack dependencies: pytorch#123217
Stack from ghstack (oldest at bottom):
We add an `additional_inputs` argument to the HOP `while_loop` and rename `operands` to `carried_inputs`, based on offline discussion with @zou3519. This allows us to support closures, parameters, and buffers.

The alternative is to pass the lifted inputs directly to the outputs of `body_fn`. But since we want `body_fn`'s outputs to not alias its inputs, we would need to copy the inputs and remove the copies later, which is a bit more work.
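For illustration, a sketch of the kind of program this unlocks; the import path and exact eager behavior here are assumptions, not a test taken from this PR:

```python
import torch
from torch._higher_order_ops.while_loop import while_loop  # import path assumed


class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.step = torch.nn.Parameter(torch.tensor(1.0))  # closed-over parameter

    def forward(self, x):
        def cond_fn(i, acc):
            return i < 3

        def body_fn(i, acc):
            # self.step is a free variable: with this change, dynamo lifts it
            # into additional_inputs instead of rejecting the closure.
            return i + 1, acc + self.step

        # Users pass only carried_inputs; additional_inputs are set by dynamo.
        return while_loop(cond_fn, body_fn, (torch.tensor(0), x))


out = torch.compile(Counter())(torch.zeros(3))
```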
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang