Add aten::all_reduce with meta impl #93109
wconstab wants to merge 8 commits into gh/wconstab/77/base from
# Collectives
# TODO: add reduce_op and add some form of ranks instead of processgroup obj
- func: all_reduce(Tensor self, int group_id, str reduce_op) -> Tensor
Quick question: would this group_id be a single process group id shared across all ranks? For example, if we have two process groups and run all_reduce on both together, would the group_id differ between the two groups? If so, I feel this makes it not SPMD, and I'm wondering whether we should make it SPMD instead.
Don't look at this part too closely. We've moved on to a new API proposal with list[rank], but I didn't bother to update this stack until we converged on the design.
  result: auto_linear

- name: all_reduce(Tensor self, int group_id, str reduce_op) -> Tensor
  self: at::ones_like(self)
That sounds wrong!
This should be an all_scatter(), no?
If you don't want it to be differentiable right now, mark it as such with non_differentiable (see the doc at the top of this file for details).
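To make the objection concrete, here is a minimal single-process sketch of why `at::ones_like(self)` cannot be the right formula: it ignores the incoming `grad` entirely. The sketch assumes `reduce_op` is a sum and models each rank as a plain Python list; the helper names (`all_reduce_sum`, `all_reduce_sum_backward`, `numeric_grad`) are hypothetical and are not part of this PR or of PyTorch. Under that assumption every rank's output depends on every rank's input, so the gradient for one rank's input is the sum of the upstream gradients across ranks. Which collective the real backward should call is still the open question in this thread.

```python
# Toy model of a sum all_reduce across ranks, run in one process.
# Assumption: reduce_op == "sum"; ranks are modelled as a list of lists.

def all_reduce_sum(per_rank):
    """Every rank receives the elementwise sum of all ranks' inputs."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def all_reduce_sum_backward(grad_outputs):
    """Each input feeds every rank's output, so its gradient is the
    sum of the upstream gradients over all ranks."""
    return all_reduce_sum(grad_outputs)

def loss(per_rank, weights):
    """Linear loss whose upstream gradient dL/dy on each rank is `weights`."""
    outs = all_reduce_sum(per_rank)
    return sum(w * y for ws, ys in zip(weights, outs) for w, y in zip(ws, ys))

def numeric_grad(per_rank, weights, eps=1e-6):
    """Forward finite differences, used only to check the analytic formula."""
    base = loss(per_rank, weights)
    grads = []
    for r, xs in enumerate(per_rank):
        row = []
        for j in range(len(xs)):
            bumped = [list(v) for v in per_rank]
            bumped[r][j] += eps
            row.append((loss(bumped, weights) - base) / eps)
        grads.append(row)
    return grads

inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]       # one row per rank
weights = [[1.0, 0.5], [2.0, 0.25], [4.0, 0.125]]   # upstream grads per rank

analytic = all_reduce_sum_backward(weights)
numeric = numeric_grad(inputs, weights)
```

In this toy setup `analytic` matches `numeric` up to finite-difference error, and neither equals a tensor of ones, which is what the `ones_like` entry would return regardless of `grad`.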
Stack from ghstack (oldest at bottom):