[SPMD] Support manual all-reduce#7576
Conversation
JackCaoG
left a comment
There was a problem hiding this comment.
approve to unblock, but I think we should fix the tensor method name
| } | ||
| } | ||
|
|
||
| XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type, |
There was a problem hiding this comment.
can you call it all_reduce _no_token, the only difference in signature is it does not take pin_layout but the main difference in the op is that it does not set token.. It is better to reflect that in the name.
There was a problem hiding this comment.
Sure. I can follow up with that.
|
for array support do you plan to call |
I don't think that's necessary. I'm thinking the compiler should be smart enough to fuse all-reduces if the fusion is necessary. |
|
Thanks Jack for approving. |
Summary:
This is to add manual all-reduce support to SPMD and it currently only supports one input tensor. For array support, we can do that in python layer instead.
Test Plan:
python ./test/spmd/test_xla_sharding.py -v -k test_spmd_all_reduce