[Fori_loop|While_loop] Enable fori_loop with add/sub test case #6603
Conversation
force-pushed from 7a8f9f5 to c0d3359
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
  xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
  xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
Explanation: we need to skip the tuple when creating the cond/body computations so they match the xla::While format check for cond (see the error log).
Hi @JackCaoG, since this PR would add a new function to …
@amithrm FYI
The kokoro failure should be fixed on the master branch; let's skip it for now.
force-pushed from c35ac53 to c7f09d5
  &PyLoweringContext::GetParameterIdTensorMapping)
- .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
+ .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
+ .def("set_name_string", &PyLoweringContext::SetNameString)
if (!root_tuple_.empty() & (root_tuple_.size() > 1)) {
  xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
  xla = builder()->Build(root);
} else if (!root_tuple_.empty() & (root_tuple_.size() == 1)) {
I think we should condition on get_name_string(). Add this check at the top and build for the while loop if `get_name_string() == "condctx" or get_name_string() == "bodyctx"`; otherwise, you can keep the original build logic.
Put your while-loop build logic in a separate private method, and call it when `get_name_string() == "condctx" or get_name_string() == "bodyctx"` is true.
That way you can keep BuildXla() simple.
Thanks, that makes sense; updated in the newest commit. Since the while-loop logic is a single line of code, we run it directly rather than wrapping it in a separate private method.
force-pushed from 96ce688 to 173ff44
For the fori_loop implementation with while_loop, this PR lowers body/cond to replace the formal placeholder.
This is the step-two PR; parent PR: #6532, child PR: #6529, source PR: #6563.
Some issues fixed:
- (before) body fn is …; (after) tried torch.sub(a, b), which also passed locally; a non-torch function will be tested later
- (before) inputs were limited to list/tuple; (after) this matches what torch._higher_order_ops.while_loop requires
- (before) the input was transformed from a list to a non-list after torch.compile; TODO: add the same logic as torch.compile to use the inputs directly, instead of creating a duplicated tensor in fori_loop.py as is done currently
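As a rough illustration of the add/sub coverage described above, here is a minimal sketch of calling torch._higher_order_ops.while_loop with a subtraction body on an XLA device. It is an assumed example, not the exact test added in this PR.

```python
# Assumed sketch of a sub-based while_loop on XLA; the add case is analogous.
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(x, limit):
    # keep looping while the counter is above the limit
    return x[0] > limit[0]

def body_fn(x, limit):
    # subtraction body, mirroring the torch.sub(a, b) case mentioned above
    return torch.sub(x, 1), limit.clone()

x = torch.tensor([10], dtype=torch.int32, device=device)
limit = torch.tensor([0], dtype=torch.int32, device=device)

# carried inputs are passed as a tuple, matching the
# torch._higher_order_ops.while_loop requirement noted above
result = while_loop(cond_fn, body_fn, (x, limit))
```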