@@ -393,7 +393,7 @@ def __init__(
         # Create clone of distant repo and output directory if needed
         if self.args.push_to_hub:
             self.init_git_repo()
-        if self.is_world_process_zero():
+        if self.args.should_save:
             os.makedirs(self.args.output_dir, exist_ok=True)
 
         if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
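
This diff's recurring change swaps the `is_world_process_zero()` check for a new `self.args.should_save` property. The property's definition is not part of this diff; based on the `--save_on_each_node` flag referenced in the warning added to `train()` below, a plausible sketch of what it does on `TrainingArguments` is:

```python
# Hypothetical sketch of the `should_save` property on TrainingArguments.
# Assumptions: `save_on_each_node` is the new bool flag this PR introduces,
# and `process_index` / `local_process_index` are the existing global and
# node-local rank accessors.
@property
def should_save(self):
    """Whether the current process should write checkpoints and other artifacts to disk."""
    if self.save_on_each_node:
        # One writer per node: the node-local main process saves.
        return self.local_process_index == 0
    # Default behavior, equivalent to the old is_world_process_zero() check:
    # only the global main process saves.
    return self.process_index == 0
```

Read this way, nothing changes for single-node runs; the substitution only matters once `save_on_each_node` is enabled on a multi-node setup.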
@@ -899,7 +899,7 @@ def _tune_save_checkpoint(self):
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
             output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
             self.save_model(output_dir)
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@@ -1357,10 +1357,18 @@ def train(
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
-            # We load the model state dict on the CPU to avoid an OOM error.
-            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
-            # If the model is on the GPU, it still works!
-            self._load_state_dict_in_model(state_dict)
+
+            best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+            if os.path.exists(best_model_path):
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = torch.load(best_model_path, map_location="cpu")
+                # If the model is on the GPU, it still works!
+                self._load_state_dict_in_model(state_dict)
+            else:
+                logger.warn(
+                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                    "on multiple nodes, you should activate `--save_on_each_node`."
+                )
 
             if self.deepspeed:
                 self.deepspeed.load_checkpoint(
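
The comments kept in this hunk explain the `map_location="cpu"` trick: the checkpoint is deserialized into host RAM so the GPU never holds a second full-size copy of the weights, and `load_state_dict` then copies each CPU tensor into the existing, possibly CUDA-resident, parameter. A minimal standalone illustration (the checkpoint path is made up):

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(10, 10).to(device)

# Deserialize to CPU memory first to avoid an OOM on the GPU.
state_dict = torch.load("checkpoint/pytorch_model.bin", map_location="cpu")

# Copies each CPU tensor into the corresponding (GPU) parameter in place,
# which is why it "still works" for a model already on the GPU.
model.load_state_dict(state_dict)
```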
@@ -1500,14 +1508,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 # Consolidate the state dict on all processed of dp_rank 0
                 opt_state_dict = self.optimizer.state_dict()
                 # Save it and the scheduler on the main process
-                if self.is_world_process_zero():
+                if self.args.should_save:
                     torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
                     with warnings.catch_warnings(record=True) as caught_warnings:
                         torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                     reissue_pt_warnings(caught_warnings)
                     if self.use_amp:
                         torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
-        elif self.is_world_process_zero() and not self.deepspeed:
+        elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
             with warnings.catch_warnings(record=True) as caught_warnings:
@@ -1533,7 +1541,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 self.state.best_model_checkpoint = output_dir
 
         # Save the Trainer state
-        if self.is_world_process_zero():
+        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
 
         # Save RNG state in non-distributed training
@@ -1562,7 +1570,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
 
         # Maybe delete some older checkpoints.
-        if self.is_world_process_zero():
+        if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
 
     def _load_optimizer_and_scheduler(self, checkpoint):
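
The context line at the top of this hunk saves a per-process `rng_states` dict to `rng_state_{local_rank}.pth` so that a resumed run can continue the same random streams. The matching restore path is outside this diff; a hedged sketch of it, assuming the dict carries `python`/`numpy`/`cpu`/`cuda` entries:

```python
import os
import random

import numpy as np
import torch


def load_rng_state(checkpoint_dir: str, local_rank: int) -> None:
    # Mirrors the save path used in the hunk above; the key names are assumptions.
    rng_file = os.path.join(checkpoint_dir, f"rng_state_{local_rank}.pth")
    if not os.path.isfile(rng_file):
        return
    rng_states = torch.load(rng_file)
    random.setstate(rng_states["python"])
    np.random.set_state(rng_states["numpy"])
    torch.random.set_rng_state(rng_states["cpu"])
    if torch.cuda.is_available() and "cuda" in rng_states:
        torch.cuda.random.set_rng_state(rng_states["cuda"])
```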
@@ -1831,27 +1839,27 @@ def save_model(self, output_dir: Optional[str] = None):
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
             state_dict = self.model_wrapped.state_dict()
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif (
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
         ):
             state_dict = self.model.state_dict()
 
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif self.deepspeed:
 
             # this takes care of everything as long as we aren't under zero3
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir)
 
             if is_deepspeed_zero3_enabled():
                 # It's too complicated to try to override different places where the weights dump gets
                 # saved, so since under zero3 the file is bogus, simply delete it. The user should
                 # either user deepspeed checkpoint to resume or to recover full weights use
                 # zero_to_fp32.py stored in the checkpoint.
-                if self.is_world_process_zero():
+                if self.args.should_save:
                     file = os.path.join(output_dir, WEIGHTS_NAME)
                     if os.path.isfile(file):
                         # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
@@ -1862,7 +1870,7 @@ def save_model(self, output_dir: Optional[str] = None):
             # This must be called on all ranks
             self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
 
-        elif self.is_world_process_zero():
+        elif self.args.should_save:
             self._save(output_dir)
 
     def _save_tpu(self, output_dir: Optional[str] = None):
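
The ZeRO-3 comments above delete the rank-local `WEIGHTS_NAME` file because under stage 3 each process only holds a shard of the parameters, so the dumped file is not a usable state dict. Recovery goes through `zero_to_fp32.py`, which DeepSpeed stores inside every checkpoint it writes. A hedged sketch of the programmatic route (import path and checkpoint path are assumptions based on DeepSpeed's utilities):

```python
# From inside the checkpoint directory, DeepSpeed's bundled script can also be
# run directly, e.g.:  python zero_to_fp32.py . pytorch_model.bin
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Reconstructs a consolidated fp32 state dict from the sharded ZeRO checkpoint.
state_dict = get_fp32_state_dict_from_zero_checkpoint("output/checkpoint-500")
```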
@@ -1880,7 +1888,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
             if isinstance(unwrap_model(self.model), PreTrainedModel):
                 unwrap_model(self.model).save_pretrained(
                     output_dir,
-                    save_config=self.is_world_process_zero(),
+                    save_config=self.args.should_save,
                     state_dict=self.model.state_dict(),
                     save_function=xm.save,
                 )
@@ -1889,8 +1897,8 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                 state_dict = self.model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
-        if self.tokenizer is not None and self.is_world_process_zero():
+            self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
+        if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
 
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
@@ -1960,7 +1968,7 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return
 
-        # If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
         # we don't do to allow resuming.
         save_total_limit = self.args.save_total_limit
         if (
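
This hunk only fixes a typo in the comment, but the comment describes a real edge case: with `save_total_limit=1` and `load_best_model_at_end=True`, rotation would keep the best checkpoint and delete the most recent one, breaking resumption. The guard opened by the truncated `if (` presumably bumps the effective limit; a hedged reconstruction (the exact condition is cut off by the hunk boundary):

```python
# Hypothetical reconstruction of the guard that follows "if (" above.
save_total_limit = self.args.save_total_limit
if (
    save_total_limit == 1
    and self.state.best_model_checkpoint is not None
    and checkpoints_sorted[-1] != self.state.best_model_checkpoint
):
    # Keep two checkpoints: the best one and the latest one, so that
    # training can still be resumed from where it stopped.
    save_total_limit = 2
```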
@@ -2436,7 +2444,7 @@ def init_git_repo(self):
         """
         Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
         """
-        if not self.is_world_process_zero():
+        if not self.args.should_save:
             return
         use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
         repo_url = PushToHubMixin._get_repo_url_from_name(
@@ -2494,11 +2502,16 @@ def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) ->
         Returns:
             The url of the commit of your model in the given repository.
         """
-        if not self.is_world_process_zero():
+        if not self.args.should_save:
             return
 
         self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
         self.save_model()
+
+        # Only push from one node.
+        if not self.is_world_process_zero():
+            return
+
         return self.repo.push_to_hub(commit_message=commit_message)
 
     #
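
Taken together, the `push_to_hub` change separates two concerns: every process for which `should_save` is true writes the model card and checkpoint (so each node keeps a local copy on multi-node runs), while the actual hub push still happens exactly once, from world process zero. A hedged usage sketch, assuming the new flag is exposed on `TrainingArguments` as `save_on_each_node`:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    save_on_each_node=True,       # assumed flag: each node's main process saves
    load_best_model_at_end=True,  # best checkpoint now exists on every node
    save_total_limit=2,
)
```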