
[while_loop] support closures#123018

Closed
ydwu4 wants to merge 7 commits intogh/ydwu4/100/basefrom
gh/ydwu4/100/head

Conversation

@ydwu4
Contributor

@ydwu4 ydwu4 commented Mar 30, 2024

Stack from ghstack (oldest at bottom):

We add an additional_inputs argument to the HOP while_loop and rename operands to carried_inputs, based on an offline discussion with @zou3519. This allows us to support closures, parameters, and buffers.

The alternative is to pass the lifted inputs directly through to the outputs of body_fn. But since we want body_fn's outputs not to alias its inputs, we would need to copy the inputs and remove the copies later, which is more work.
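The carried/additional split can be illustrated with a pure-Python sketch of the semantics (this is an illustration of the calling convention described above, not PyTorch's actual implementation):

```python
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=()):
    """Pure-Python sketch of the while_loop HOP semantics.

    carried_inputs are threaded through iterations (each body_fn call
    returns fresh values for them), while additional_inputs (closures,
    parameters, buffers) are passed along unchanged every iteration,
    so they never need to be cloned into the carry.
    """
    carried = tuple(carried_inputs)
    while cond_fn(*carried, *additional_inputs):
        carried = tuple(body_fn(*carried, *additional_inputs))
    return carried

# A closed-over limit becomes an additional input instead of an extra
# carried value that body_fn would have to copy and return:
result = while_loop(
    lambda i, limit: i < limit,   # cond_fn sees carry + additional inputs
    lambda i, limit: (i + 1,),    # body_fn returns only the new carry
    carried_inputs=(0,),
    additional_inputs=(5,),
)
```

Here `result` ends up as `(5,)`: the carry is rebuilt each iteration while the limit rides along read-only.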

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Mar 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123018

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit c2fb37c with merge base 09c72ea:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ydwu4 added a commit that referenced this pull request Mar 30, 2024
ghstack-source-id: e44af3e
Pull Request resolved: #123018
Contributor

@aakhundov aakhundov left a comment


Thanks a lot @ydwu4 for extending the torch.while_loop API to handle additional inputs in a more efficient way (without the need for cloning in each iteration)! Left a few comments, mostly nits.

Comment on lines +796 to +798
assert (
len(additional_inputs) == 0
), "Additional inputs are set automatically by dynamo"
Contributor

If additional inputs are set automatically by dynamo, what is the reason for passing an empty tuple here? Should the number of arguments match?

Contributor Author

Yeah, its schema has changed to torch.ops.higher_order.while_loop(cond, body, carried_inputs, additional_inputs). When dynamo is compiling this higher-order op, additional_inputs is expected to be a tuple.

I've updated the PR to exercise this additional_inputs path when it has non-zero length, e.g. when running the tests with PYTORCH_TEST_WITH_DYNAMO=1 (i.e. torch.compile-ing a function containing while_loop).

"body_fn",
)

additional_lifted_inputs = tuple(cond_shared + cond_unique + body_unique)
Contributor

Curious, why do we include cond_shared into the lifted inputs here?

Contributor Author

cond_shared and body_shared refer to the same proxy in the parent graph. Using either of them is OK. I can add a comment here.

Comment thread torch/_higher_order_ops/while_loop.py Outdated
return super().__call__(cond_fn, body_fn, operands)
if not isinstance(additional_inputs, tuple):
raise RuntimeError(
"additional_inputs must be a tuple, got " f"{type(additional_inputs)}"
Contributor

UX question: if additional_inputs are generated internally (and are not considered a part of public API), would the errors re. additional_inputs make sense to the user?

Contributor Author

Yeah, you're right. This should probably just be an assertion.

Comment thread torch/_higher_order_ops/while_loop.py Outdated
raise RuntimeError(
"operands must be a tuple of tensors, ints, floats, or bools, got "
f"{operands}"
"carried_inputs must be a tuple, got " f"{type(carried_inputs)}"
Contributor

Nit: why separate f-string here? And below.

Contributor Author

Nice catch... I don't remember why I wrote it this way lol

Comment thread torch/_inductor/codegen/wrapper.py Outdated
body_outer_inputs
) # carry over the state from body_fn

# carry over the state from body_fn
Contributor

Could we make this comment a bit more specific mentioning that we only carry over the carried_inputs part of the inputs, but not the additional ones?

Contributor Author

Sg!

Comment thread torch/_inductor/ir.py Outdated

assert (
len(operands) > 0
len(carried_inputs) > 0
Contributor

Actually, we can fetch the device from additional_inputs, too. So maybe just check that we have at least one input across both groups?

Comment thread torch/_inductor/ir.py Outdated
Comment on lines +7214 to +7217
fx_carried_inputs = V.graph.current_node.args[-2]
fx_additional_inputs = V.graph.current_node.args[-1]
fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs] # type: ignore[union-attr]
fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs] # type: ignore[union-attr]
Contributor

As these are used identically below for carried and additional inputs, would it make sense to combine them upstream to something like all_inputs and then do fx, fake, and use that single list below?

Contributor Author

updated.

Comment thread torch/_inductor/lowering.py Outdated
def while_loop(cond_fn, body_fn, operands):
if any(map(is_triton, operands)):
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
if any(map(is_triton, carried_inputs)) or any(map(is_triton, additional_inputs)):
Contributor

Nit: combine carried_inputs + additional_inputs instead?

Contributor Author

done!

Comment thread torch/_library/fake_class_registry.py Outdated
@@ -0,0 +1,234 @@
import logging
Contributor

Seems this file inadvertently made its way to this PR? :)

Contributor Author

oops!

ydwu4 added a commit that referenced this pull request Apr 1, 2024
ghstack-source-id: 6deb60e
Pull Request resolved: #123018
@ydwu4 ydwu4 requested a review from zou3519 April 1, 2024 18:20
Comment thread torch/_higher_order_ops/while_loop.py Outdated

class WhileLoopOp(HigherOrderOperator):
def __call__(self, cond_fn, body_fn, operands):
def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs):
Contributor

What are the chances you can make this positional-only lol.

e.g.:

def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs, /)
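For anyone unfamiliar with the `/` marker, a small standalone demonstration of what positional-only parameters enforce (plain Python, not the actual op):

```python
def call(cond_fn, body_fn, carried_inputs, additional_inputs, /):
    # Everything before the `/` is positional-only: callers cannot pass
    # these by keyword, which keeps the HOP's calling convention stable
    # even if the parameter names change later.
    return (cond_fn, body_fn, carried_inputs, additional_inputs)

call(None, None, (), ())  # OK: positional arguments

try:
    call(cond_fn=None, body_fn=None, carried_inputs=(), additional_inputs=())
except TypeError:
    pass  # keyword use is rejected at call time
```

Since the tracing machinery always invokes the op positionally, making the signature positional-only rules out a class of keyword-argument mismatches for free.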

Contributor Author

Yeah, I think we can do that.

I just found that previously we weren't using this WhileLoopOp but the plain HigherOrderOperator. Updated it to use WhileLoopOp for checking the inputs. Will add more tests for invalid inputs as a follow-up.

Comment thread torch/_higher_order_ops/while_loop.py Outdated

class WhileLoopOp(HigherOrderOperator):
def __call__(self, cond_fn, body_fn, operands):
def __call__(self, cond_fn, body_fn, carried_inputs, additional_inputs):
Contributor

can you add type annotations for these?

Contributor Author

Sure!

@aakhundov
Contributor

@ydwu4 thanks for addressing the comments. Lots of tests are failing, could you have a look?

ydwu4 added a commit that referenced this pull request Apr 1, 2024
ghstack-source-id: 374b170
Pull Request resolved: #123018
@ydwu4
Contributor Author

ydwu4 commented Apr 2, 2024

Need to wait for the change to the XLA dispatch key implementation of while_loop in pytorch/xla#6872 to fix the XLA test failures.

@ydwu4 ydwu4 mentioned this pull request Apr 2, 2024
ydwu4 added a commit that referenced this pull request Apr 2, 2024
ghstack-source-id: dfee1fe
Pull Request resolved: #123018
@ydwu4
Contributor Author

ydwu4 commented Apr 3, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 3, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@ydwu4
Contributor Author

ydwu4 commented Apr 3, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Apr 3, 2024
#123018 introduces a necessary BC-breaking change and sees a bunch of XLA test failures on CI. We made a PR to pytorch/xla to prepare for the breaking change (pytorch/xla#6872). We update the pin of pytorch/xla to reflect that change in this PR.

Pull Request resolved: #123217
Approved by: https://github.com/clee2000
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
Pull Request resolved: pytorch#123018
Approved by: https://github.com/aakhundov
ghstack dependencies: pytorch#123217
@github-actions github-actions Bot deleted the gh/ydwu4/100/head branch May 4, 2024 01:56