Skip to content

[Fori_loop|While_loop] Enable fori_loop with add/sub test case#6603

Merged
ManfeiBai merged 64 commits intomasterfrom
ManfeiBai-patch-73
Mar 8, 2024
Merged

[Fori_loop|While_loop] Enable fori_loop with add/sub test case#6603
ManfeiBai merged 64 commits intomasterfrom
ManfeiBai-patch-73

Conversation

@ManfeiBai
Copy link
Copy Markdown
Collaborator

@ManfeiBai ManfeiBai commented Feb 23, 2024

For fori_loop implementation with while_loop, this PR is for lowering body/cond to replace formal placeholder

This is the step two PR, and father PR(#6532), child PR(#6529), source PR(#6563)


some issue fixed:

  • (before)body fn is -, not a torch func, will test later (after)tried torch.sub(a, b), passed too locally
  • current code has changed many logic of lowering, let's move these logics to a new function without affecting the existing functions
  • (before)input are limited to list/tuple (after)this match torch._higher_order_ops.while_loop required
  • (before)input was trans from list to not list after torch.compile, TODO, add the same logic like torch.compile to use inputs, not like currently create a duplicated tensor in the fori_loop.py file

@ManfeiBai ManfeiBai changed the title Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py [Do Not Merge] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Feb 23, 2024
@ManfeiBai ManfeiBai changed the title [Do Not Merge] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py [Do Not Merge][Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Feb 24, 2024
Comment thread torch_xla/csrc/lowering_context.cpp Outdated
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
Copy link
Copy Markdown
Collaborator Author

@ManfeiBai ManfeiBai Feb 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain: we need to skip tuple for cond/body computation creation to match the xla::While format check for cond, error log

@ManfeiBai ManfeiBai marked this pull request as ready for review February 26, 2024 19:34
@ManfeiBai
Copy link
Copy Markdown
Collaborator Author

ManfeiBai commented Feb 26, 2024

Hi, @JackCaoG, since this PR would add new function to PyLoweringContext, do we want to request review from aws too?

@ManfeiBai ManfeiBai changed the title [Do Not Merge][Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py [Fori_loop|While_loop] Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Feb 26, 2024
@JackCaoG
Copy link
Copy Markdown
Collaborator

@amithrm FYI

@ManfeiBai
Copy link
Copy Markdown
Collaborator Author

kokoro failure should be fixed on master branch, let's skip it now

Comment thread torch_xla/csrc/init_python_bindings.cpp Outdated
Comment thread torch_xla/experimental/fori_loop.py Outdated
@ManfeiBai ManfeiBai requested a review from yeounoh February 29, 2024 18:07
Comment thread torch_xla/csrc/lowering_context.h Outdated
Comment thread torch_xla/experimental/fori_loop.py Outdated
Copy link
Copy Markdown
Contributor

@yeounoh yeounoh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some suggestions

@ManfeiBai ManfeiBai requested a review from yeounoh February 29, 2024 22:27
@ManfeiBai ManfeiBai force-pushed the ManfeiBai-patch-73 branch from c35ac53 to c7f09d5 Compare March 4, 2024 17:47
&PyLoweringContext::GetParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
.def("set_name_string", &PyLoweringContext::SetNameString)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good, thanks!

Comment thread torch_xla/csrc/lowering_context.cpp Outdated
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should condition on get_name_string(). Add this check at the top and build for while loop `if get_name_string() == "condctx" or get_name_string() == "bodyctx"; otherwise, you can keep the original build logic.

Have your logic for while loop build in a separate private method, and call it if ``if get_name_string() == "condctx" or get_name_string() == "bodyctx"` is true.

So you can keep BuildXla() simple.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, make sense and updated in the newest commit, due to the logic for while loop is one simple line code, we run it directly without warping it in a separate private method

Comment thread torch_xla/csrc/lowering_context.h Outdated
Comment thread test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py Outdated
@ManfeiBai ManfeiBai force-pushed the ManfeiBai-patch-73 branch from 96ce688 to 173ff44 Compare March 8, 2024 18:44
@ManfeiBai ManfeiBai merged commit 6170df5 into master Mar 8, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants