🚀 Feature
PyTorch/XLA's `xs.mark_sharding` is an in-place operation that adds a sharding annotation to an XLA tensor. However, the gradients that will be applied to that tensor are not annotated with the same sharding.
Motivation
In some cases, GSPMD fails to propagate the sharding annotation from a tensor to its gradient. It would be useful to shard both a tensor and its gradient with the same sharding annotation.
Pitch
We could write a `torch.autograd.Function` implementation that applies the same sharding annotation in both the forward and backward passes.
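A minimal sketch of the idea, using a hypothetical `annotate` helper as a stand-in for `xs.mark_sharding` (in PyTorch/XLA the real call would be `xs.mark_sharding(tensor, mesh, partition_spec)`); the class name, helper, and `_sharding_spec` attribute are assumptions for illustration, not the actual API:

```python
import torch


def annotate(tensor, spec):
    # Stand-in for xs.mark_sharding: records the annotation on the tensor.
    # The _sharding_spec attribute is purely for illustration.
    tensor._sharding_spec = spec
    return tensor


class MarkSharding(torch.autograd.Function):
    """Applies the same sharding annotation to a tensor and its gradient."""

    @staticmethod
    def forward(ctx, tensor, spec):
        # Save the spec so the backward pass can reuse it.
        ctx.spec = spec
        return annotate(tensor.clone(), spec)

    @staticmethod
    def backward(ctx, grad_output):
        # Annotate the incoming gradient with the same spec as the
        # forward tensor; None matches the non-tensor `spec` input.
        return annotate(grad_output, ctx.spec), None
```

Usage would look like `y = MarkSharding.apply(x, spec)`; gradients flowing back to `x` then carry the same annotation, so GSPMD does not need to infer it.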
Additional context
JAX's equivalent of `mark_sharding` shards the gradients too.
cc @bhavya01