Currently, we rely on `AllGatherGrad` to compute gather for GPUs.

TODO:

- [ ] Extend this class to support TPU
- [ ] Add tests
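
For the "Add tests" item, here is a rough, framework-free sketch of what a gradient-aware gather is expected to do. It does not use `AllGatherGrad` itself. The helper names `all_gather` and `all_gather_backward` are hypothetical, and the backward rule (sum the incoming gradients across ranks, then keep each rank's own slice) is an assumption about the intended semantics, not something confirmed here.

```python
# Framework-free simulation: each "rank" holds one list of floats.
# Hypothetical helpers for illustration only; not the AllGatherGrad API.


def all_gather(per_rank_inputs):
    """Forward: every rank receives a copy of all ranks' inputs."""
    return [[list(x) for x in per_rank_inputs] for _ in per_rank_inputs]


def all_gather_backward(per_rank_grad_outputs):
    """Backward (assumed semantics): sum grads over ranks, keep own slice."""
    world_size = len(per_rank_grad_outputs)
    grads = []
    for rank in range(world_size):
        # Take the slice belonging to `rank` from every rank's gradient.
        own = [g[rank] for g in per_rank_grad_outputs]
        # Element-wise sum across ranks.
        grads.append([sum(vals) for vals in zip(*own)])
    return grads


inputs = [[1.0, 2.0], [3.0, 4.0]]  # two simulated ranks
outputs = all_gather(inputs)

# Every rank sends back a gradient of ones for the gathered output.
grad_outputs = [[[1.0, 1.0], [1.0, 1.0]] for _ in inputs]
grad_inputs = all_gather_backward(grad_outputs)
```

A real test would presumably compare the class's output and gradients against a reference like this on each supported accelerator.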