
Commit 7d42ddd

sgugger authored and stas00 committed
Add option to save on each training node (#12421)
* Add option to save on each training node

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
1 parent 22bb717 commit 7d42ddd

2 files changed

Lines changed: 59 additions & 21 deletions
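Below is a short, hedged sketch of how the option introduced by this commit would typically be enabled from user code; the output directory is illustrative, and the flag is only useful when the nodes of a multi-node job do not share storage.

from transformers import TrainingArguments

# Minimal sketch (assumes a transformers version that includes this commit).
# With save_on_each_node=True, the main process of *each* node writes models
# and checkpoints, instead of only the global main process.
args = TrainingArguments(
    output_dir="output",        # illustrative node-local path
    save_on_each_node=True,     # option added by this commit
)
print(args.should_save)         # True on every node's local main process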

File tree

src/transformers/trainer.py

Lines changed: 34 additions & 21 deletions
@@ -393,7 +393,7 @@ def __init__(
         # Create clone of distant repo and output directory if needed
         if self.args.push_to_hub:
             self.init_git_repo()
-        if self.is_world_process_zero():
+        if self.args.should_save:
             os.makedirs(self.args.output_dir, exist_ok=True)

         if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
@@ -899,7 +899,7 @@ def _tune_save_checkpoint(self):
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
             output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
             self.save_model(output_dir)
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@@ -1357,10 +1357,18 @@ def train(
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
-            # We load the model state dict on the CPU to avoid an OOM error.
-            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
-            # If the model is on the GPU, it still works!
-            self._load_state_dict_in_model(state_dict)
+
+            best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+            if os.path.exists(best_model_path):
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = torch.load(best_model_path, map_location="cpu")
+                # If the model is on the GPU, it still works!
+                self._load_state_dict_in_model(state_dict)
+            else:
+                logger.warn(
+                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                    "on multiple nodes, you should activate `--save_on_each_node`."
+                )

             if self.deepspeed:
                 self.deepspeed.load_checkpoint(
@@ -1500,14 +1508,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 # Consolidate the state dict on all processed of dp_rank 0
                 opt_state_dict = self.optimizer.state_dict()
                 # Save it and the scheduler on the main process
-                if self.is_world_process_zero():
+                if self.args.should_save:
                     torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
                     with warnings.catch_warnings(record=True) as caught_warnings:
                         torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                     reissue_pt_warnings(caught_warnings)
                     if self.use_amp:
                         torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
-        elif self.is_world_process_zero() and not self.deepspeed:
+        elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
             with warnings.catch_warnings(record=True) as caught_warnings:
@@ -1533,7 +1541,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 self.state.best_model_checkpoint = output_dir

         # Save the Trainer state
-        if self.is_world_process_zero():
+        if self.args.should_save:
             self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

         # Save RNG state in non-distributed training
@@ -1562,7 +1570,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

         # Maybe delete some older checkpoints.
-        if self.is_world_process_zero():
+        if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

     def _load_optimizer_and_scheduler(self, checkpoint):
@@ -1831,27 +1839,27 @@ def save_model(self, output_dir: Optional[str] = None):
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
             state_dict = self.model_wrapped.state_dict()
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif (
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
         ):
             state_dict = self.model.state_dict()

-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif self.deepspeed:

             # this takes care of everything as long as we aren't under zero3
-            if self.is_world_process_zero():
+            if self.args.should_save:
                 self._save(output_dir)

             if is_deepspeed_zero3_enabled():
                 # It's too complicated to try to override different places where the weights dump gets
                 # saved, so since under zero3 the file is bogus, simply delete it. The user should
                 # either user deepspeed checkpoint to resume or to recover full weights use
                 # zero_to_fp32.py stored in the checkpoint.
-                if self.is_world_process_zero():
+                if self.args.should_save:
                     file = os.path.join(output_dir, WEIGHTS_NAME)
                     if os.path.isfile(file):
                         # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
@@ -1862,7 +1870,7 @@ def save_model(self, output_dir: Optional[str] = None):
             # This must be called on all ranks
             self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)

-        elif self.is_world_process_zero():
+        elif self.args.should_save:
             self._save(output_dir)

     def _save_tpu(self, output_dir: Optional[str] = None):
@@ -1880,7 +1888,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
             if isinstance(unwrap_model(self.model), PreTrainedModel):
                 unwrap_model(self.model).save_pretrained(
                     output_dir,
-                    save_config=self.is_world_process_zero(),
+                    save_config=self.args.should_save,
                     state_dict=self.model.state_dict(),
                     save_function=xm.save,
                 )
@@ -1889,8 +1897,8 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                 state_dict = self.model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
-        if self.tokenizer is not None and self.is_world_process_zero():
+            self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
+        if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)

     def _save(self, output_dir: Optional[str] = None, state_dict=None):
@@ -1960,7 +1968,7 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return

-        # If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
         # we don't do to allow resuming.
         save_total_limit = self.args.save_total_limit
         if (
@@ -2436,7 +2444,7 @@ def init_git_repo(self):
         """
         Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
         """
-        if not self.is_world_process_zero():
+        if not self.args.should_save:
             return
         use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
         repo_url = PushToHubMixin._get_repo_url_from_name(
@@ -2494,11 +2502,16 @@ def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) ->
         Returns:
             The url of the commit of your model in the given repository.
         """
-        if not self.is_world_process_zero():
+        if not self.args.should_save:
             return

         self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
         self.save_model()
+
+        # Only push from one node.
+        if not self.is_world_process_zero():
+            return
+
         return self.repo.push_to_hub(commit_message=commit_message)

     #

src/transformers/training_args.py

Lines changed: 25 additions & 0 deletions
@@ -182,6 +182,12 @@ class TrainingArguments:
         save_total_limit (:obj:`int`, `optional`):
             If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
             :obj:`output_dir`.
+        save_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+            the main one.
+
+            This should not be activated when the different nodes use the same storage as the files will be saved with
+            the same names for each node.
         no_cuda (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to not use CUDA even when it is available or not.
         seed (:obj:`int`, `optional`, defaults to 42):
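As the new docstring notes, the flag is meant for multi-node jobs whose nodes do not share a filesystem. A hedged sketch of how it reaches `TrainingArguments` from the command line via `HfArgumentParser` follows; the argument list is illustrative, and the bare-flag form assumes the parser's usual boolean-flag handling.

from transformers import HfArgumentParser, TrainingArguments

# Sketch: the new dataclass field is exposed as a CLI argument like any other
# TrainingArguments field, so example scripts can receive --save_on_each_node.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--save_on_each_node"]
)
print(training_args.save_on_each_node)  # expected: True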
@@ -455,6 +461,12 @@ class TrainingArguments:
             )
         },
     )
+    save_on_each_node: bool = field(
+        default=False,
+        metadata={
+            "help": "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
+        },
+    )
     no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})

@@ -936,6 +948,19 @@ def should_log(self):
             else:
                 return self.process_index == 0

+    @property
+    def should_save(self):
+        """
+        Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+        """
+        if self.save_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
     def get_process_log_level(self):
         """
         Returns the log level to be used depending on whether this process is the main process of node 0, main process
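For readability, here is a hedged restatement of the decision the new `should_save` property encodes, written as a free function; the SageMaker model-parallel branch is omitted and the parameter names are illustrative.

def should_save(save_on_each_node: bool, process_index: int, local_process_index: int) -> bool:
    # Sketch of the logic only; the real property reads these values from TrainingArguments.
    if save_on_each_node:
        # One writer per node: the node-local main process.
        return local_process_index == 0
    # Default: a single writer for the whole job, the global main process.
    return process_index == 0

# Example: 2 nodes x 2 GPUs (global ranks 0-3, local ranks 0-1 on each node).
# With save_on_each_node=True, global ranks 0 and 2 write; otherwise only rank 0.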
