🚀 Feature
TensorPipeRpcAgent should implicitly use the reverse device mapping specified in "set_device_map" for the backward pass.
Motivation
In the current implementation, users need to specify both the forward and backward pass device mappings themselves, as follows, to run the backward pass correctly:
# Forward pass: map local device `self.rank` to the peer's device.
options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
# Backward pass: the user must also manually register the reverse mapping
# for the worker that will send gradients back.
reverse_rank = (self.rank - 1 + self.world_size) % self.world_size
options.set_device_map(worker_name(reverse_rank), {self.rank: reverse_rank})
This isn't a good user experience, since it is hard for users to predict which nodes will participate in a backward pass and to set up the backward device mapping accordingly. Instead, it would be ideal if the distributed autograd engine remembered the forward mapping (as part of the send/recv autograd functions) and used its inverse during the backward pass, as sketched below.
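For illustration, here is a minimal, self-contained sketch of the intended behavior. The dict inversion below is an assumption about the mechanism, not actual PyTorch code; the point is that the user would only ever register the forward mapping.
# Example values standing in for self.rank / self.world_size.
rank, world_size = 0, 2
# The forward device map, as the user would pass it to set_device_map.
forward_map = {rank: (rank + 1) % world_size}
# On the backward pass, the engine would apply the inverse mapping itself,
# so the user never calls set_device_map for the gradient path.
reverse_map = {dst_dev: src_dev for src_dev, dst_dev in forward_map.items()}
assert reverse_map == {(rank + 1) % world_size: rank}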
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar @jiayisuse @lw @beauby