Support variadic returns in Schema's operator<<#23204
Support variadic returns in Schema's operator<<#23204houseroad wants to merge 6 commits intogh/houseroad/7/basefrom
Conversation
old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Differential Revision: [D16433592](https://our.internmc.facebook.com/intern/diff/D16433592/)
old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Differential Revision: [D16433592](https://our.internmc.facebook.com/intern/diff/D16433592/)
old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Differential Revision: [D16433592](https://our.internmc.facebook.com/intern/diff/D16433592/)
old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Differential Revision: [D16433592](https://our.internmc.facebook.com/intern/diff/D16433592/)
| } else if (schema.returns().size() > 1) { | ||
|
|
||
| const auto& returns = schema.returns(); | ||
| if (schema.is_varret()) { |
There was a problem hiding this comment.
if varret is analogous to vararg, it should be appended after the previous args. So it'd be nice to do so for generality
There was a problem hiding this comment.
sure, let's make it nice :-)
old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Differential Revision: [D16433592](https://our.internmc.facebook.com/intern/diff/D16433592/)
| out << schema.returns()[i].type()->str(); | ||
|
|
||
| const auto& returns = schema.returns(); | ||
| out << "("; |
There was a problem hiding this comment.
after some digging, I think this diff is fine with existing function schema parser. Here is the reasoning:
-
when parsing the results, we automatically remove the outer pair of parentheses. related code This code also tells us multiple returns should be something like
(int, float), one example isaten::divmod(int x, int y) -> (int, int)defined here -
since we use the outer pair parentheses to represent multi-returns, to return a tuple, we should use two pairs of parentheses. My updates on the expect files are actually correct, because the returns are all single tuples, one example
The test in #23208 already make sure that all the serialized schemas can be correctly parsed (i.e., parsed schemas are exactly equal to the original ones.)
There was a problem hiding this comment.
also executed the following code:
s1 = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)')
s2 = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> ((Tensor, Tensor))')
print(len(s1.returns))
print(len(s2.returns))
output is
2
1
this also shows that, for tuple, we should use '(())'
There was a problem hiding this comment.
Yeah, existing parsing is a bit non-pythonic (as python would treat it as a tuple). However, I think we can land it avoid changing existing parser and schemas
old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Differential Revision: [D16433592](https://our.internmc.facebook.com/intern/diff/D16433592/)
Summary: old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Pull Request resolved: pytorch/pytorch#23204 ghstack-source-id: 87208343 Reviewed By: zrphercule Differential Revision: D16433592 fbshipit-source-id: 36cbb329188f112e09c3b1708a8090781b830dfe
|
This pull request has been merged in c8c5e11. |
Stack from ghstack:
old: prim::PythonOp(...) ->
new: prim::PythonOp(...) -> ...
Differential Revision: D16433592