Pytorch XLA SPMD CheckpointManager Optimizer Loading Error #6546

@mfatih7

Description

Hello

I am using CheckpointManager to save and load checkpoints for models trained with SPMD.
In my training setup, the model is trained with lossA only during the first epoch.
After the first epoch, the model is trained with both lossA and lossB.
Checkpoints saved in later epochs load without any problem.
However, when loading the checkpoint saved after the first epoch, I get the error below.

  File "/home/THEUSER/17_featureMatching/run_train_TPU.py", line 101, in <module>
    training_results = train_1_1_each_sample_in_single_batch_TPU_spmd_train_and_val(   
  File "/home/THEUSER/17_featureMatching/process/train_1_1_each_sample_in_single_batch_TPU_spmd.py", line 81, in train_and_val
    loss_checkpoint, proc_time_checkpoint = checkpoint.load_checkpoint( config, device, model, optimizer, chkpt_mgr=chkpt_mgr)    
  File "/home/THEUSER/17_featureMatching/checkpoint.py", line 100, in load_checkpoint
    chkpt_mgr.restore( max(tracked_steps), checkpoint )
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch_xla/experimental/distributed_checkpoint/manager.py", line 319, in restore
    dist_cp.load_state_dict(
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 31, in load_state_dict
    return _load_state_dict(
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 185, in _load_state_dict
    central_plan = distW.reduce_scatter("plan", local_step, global_step)
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/utils.py", line 185, in reduce_scatter
    raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/utils.py", line 158, in reduce_scatter
    local_data = map_fun()
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 175, in local_step
    local_plan = planner.create_local_plan()
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch_xla/experimental/distributed_checkpoint/planners.py", line 228, in create_local_plan
    plan = create_default_local_load_plan(self.unsharded_state_dict,
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/default_planner.py", line 228, in create_default_local_load_plan
    md = metadata.state_dict_metadata[fqn]
KeyError: 'optimizer_state_dict'

Generating lossB does not add any extra parameters to the computation graph.
Checkpoint saving and loading work correctly in GPU, single-core TPU, and multi-core TPU runs. The problem occurs only when loading the model and optimizer in TPU SPMD runs.
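The KeyError suggests that the checkpoint written after the first epoch has no 'optimizer_state_dict' entry in its on-disk metadata, so `metadata.state_dict_metadata[fqn]` fails when `restore()` asks for that key: the state_dict passed to `chkpt_mgr.restore()` has to contain the same top-level keys as the one that was passed to `chkpt_mgr.save()` at that step. As a minimal sketch (the helper name and the surrounding key names are assumptions for illustration, not part of the torch_xla API), a pre-restore check can surface the mismatch before the distributed load raises:

```python
def missing_restore_keys(requested_state_dict, saved_keys):
    """Return top-level keys requested at restore time that the checkpoint lacks.

    `requested_state_dict` stands in for the dict passed to
    CheckpointManager.restore(); `saved_keys` stands in for the top-level
    entries recorded in the checkpoint's metadata at save time.
    """
    return sorted(set(requested_state_dict) - set(saved_keys))


# Hypothetical first-epoch checkpoint that was written without optimizer state:
saved_keys = {"model_state_dict"}
requested = {"model_state_dict": {}, "optimizer_state_dict": {}}

print(missing_restore_keys(requested, saved_keys))  # ['optimizer_state_dict']
```

If the first-epoch save path builds its state_dict differently from later epochs (for example, omitting the optimizer before lossB becomes active), making both paths save the same set of keys would be one way to avoid the KeyError.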

I can provide a repo if needed.
