Referring to the "SPMD user guide" here: https://github.com/pytorch/xla/blob/535d398b9a7d2952d04abe6395307897352664d2/docs/spmd.md
What's the right approach to using `torch.distributed.checkpoint` (and the distributed checkpoint manager) on Cloud TPUs?

I understand I should call `init_process_group` to use the `torch.distributed.*` functions. When I use:

```python
xr.use_spmd()
dist.init_process_group("xla", init_method="pjrt://")
```

I hit this assertion failure:

```
AssertionError: XLA backend is not supported with SPMD. Please use a CPU process group instead.
```

Do I need to initialize a CPU process group (with rank, world size, etc.) in parallel with how I normally use `torch_xla.core.xla_model` for distribution on TPUs?
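For context, here is a minimal sketch of what the assertion message seems to be asking for: a CPU (`gloo`) process group initialized alongside SPMD, using the `xla://` init_method that the PyTorch/XLA backend registers to supply rank and world size from the runtime. This is my reading of the error and the SPMD guide, not a confirmed recipe, and it assumes a TPU environment with `gloo` available:

```python
import torch.distributed as dist
import torch_xla.runtime as xr
# Importing the XLA backend module registers the 'xla://' init_method
import torch_xla.distributed.xla_backend  # noqa: F401

xr.use_spmd()

# Instead of the 'xla' backend (rejected under SPMD), create a CPU
# process group; 'xla://' discovers rank/world size from the XLA runtime.
dist.init_process_group("gloo", init_method="xla://")
```

Is something along these lines the intended pattern, or is a manually configured rank/world size expected here?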