I am using CheckpointManager to save and load checkpoints for models with SPMD.
In my training setup, in the first epoch, the model is trained with lossA only.
After the first epoch, the model is trained using lossA and lossB.
There is no problem while loading checkpoints in the later epochs.
But specifically when loading a checkpoint saved after the first epoch, I get the error below.
```
  File "/home/THEUSER/17_featureMatching/run_train_TPU.py", line 101, in <module>
    training_results = train_1_1_each_sample_in_single_batch_TPU_spmd_train_and_val(
  File "/home/THEUSER/17_featureMatching/process/train_1_1_each_sample_in_single_batch_TPU_spmd.py", line 81, in train_and_val
    loss_checkpoint, proc_time_checkpoint = checkpoint.load_checkpoint( config, device, model, optimizer, chkpt_mgr=chkpt_mgr)
  File "/home/THEUSER/17_featureMatching/checkpoint.py", line 100, in load_checkpoint
    chkpt_mgr.restore( max(tracked_steps), checkpoint )
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch_xla/experimental/distributed_checkpoint/manager.py", line 319, in restore
    dist_cp.load_state_dict(
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 31, in load_state_dict
    return _load_state_dict(
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 185, in _load_state_dict
    central_plan = distW.reduce_scatter("plan", local_step, global_step)
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/utils.py", line 185, in reduce_scatter
    raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/utils.py", line 158, in reduce_scatter
    local_data = map_fun()
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 175, in local_step
    local_plan = planner.create_local_plan()
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch_xla/experimental/distributed_checkpoint/planners.py", line 228, in create_local_plan
    plan = create_default_local_load_plan(self.unsharded_state_dict,
  File "/home/THEUSER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/default_planner.py", line 228, in create_default_local_load_plan
    md = metadata.state_dict_metadata[fqn]
KeyError: 'optimizer_state_dict'
```
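For what it's worth, the `KeyError` is raised while the load planner looks up each key of the state_dict passed to `restore` in the metadata recorded by the saved checkpoint, so it looks like the epoch-1 SPMD checkpoint simply does not contain an `optimizer_state_dict` entry. A framework-free sketch of that lookup (the dictionaries here are hypothetical; only the control flow mirrors `create_default_local_load_plan`):

```python
# Keys the checkpoint metadata recorded at save time (assumption: the
# epoch-1 SPMD checkpoint only recorded the model weights).
state_dict_metadata = {"model_state_dict": "tensor metadata ..."}

# Keys in the state_dict we ask CheckpointManager.restore to fill in.
checkpoint_to_restore = {"model_state_dict": {}, "optimizer_state_dict": {}}

# The planner iterates over the state_dict being restored and looks each
# fully-qualified name (fqn) up in the saved metadata; any key absent at
# save time raises KeyError (the line `md = metadata.state_dict_metadata[fqn]`).
missing = [fqn for fqn in checkpoint_to_restore if fqn not in state_dict_metadata]

print(missing)  # ['optimizer_state_dict']
```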
No extra parameters are introduced when lossB is added to the computation graph, so the optimizer state should be identical before and after the first epoch.
The checkpoint saving and loading operations work well in GPU, single-core TPU, and multi-core TPU runs. The problem occurs only when loading models and optimizers in TPU SPMD runs.
I can provide a repo if needed.