-
Notifications
You must be signed in to change notification settings - Fork 33.1k
Expand file tree
/
Copy pathtrainer.py
More file actions
executable file
·4418 lines (3824 loc) · 211 KB
/
trainer.py
File metadata and controls
executable file
·4418 lines (3824 loc) · 211 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import contextlib
import functools
import glob
import inspect
import json
import math
import os
import random
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Callable, Iterator, Mapping
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any
# Integrations must be imported before ML frameworks:
# ruff: isort: off
from .integrations import (
get_reporting_integration_callbacks,
)
# ruff: isort: on
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
from huggingface_hub import CommitInfo, ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PreTrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .feature_extraction_sequence_utils import SequenceFeatureExtractor
from .feature_extraction_utils import FeatureExtractionMixin
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .image_processing_utils import BaseImageProcessor
from .integrations.deepspeed import (
deepspeed_init,
deepspeed_load_checkpoint,
deepspeed_sp_compute_loss,
is_deepspeed_available,
propagate_args_to_deepspeed,
)
from .integrations.fsdp import get_fsdp_ckpt_kwargs, update_fsdp_plugin_peft
from .integrations.liger import apply_liger_kernel
from .integrations.neftune import activate_neftune, deactivate_neftune
from .integrations.peft import MIN_PEFT_VERSION
from .integrations.tpu import save_tpu_checkpoint, tpu_spmd_dataloader, wrap_model_xla_fsdp
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from .optimization import GreedyLR, get_scheduler
from .processing_utils import ProcessorMixin
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
ExportableState,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_optimizer import (
_OPTIMIZER_HANDLERS,
OptimizerContext,
_parse_optim_args,
is_optimizer_factory,
)
from .trainer_pt_utils import (
EvalLoopContainer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
distributed_broadcast_scalars,
find_batch_size,
get_model_param_count,
get_parameter_names,
is_attention_mask_causal,
nested_detach,
nested_gather,
reissue_pt_warnings,
remove_dummy_checkpoint,
safe_globals,
set_rng_state_for_device,
)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
HubStrategy,
PredictionOutput,
RemoveColumnsCollator,
SaveStrategy,
TrainerMemoryTracker,
TrainOutput,
_is_peft_model,
align_special_tokens,
compare_trainer_and_checkpoint_args,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
find_executable_batch_size,
get_last_checkpoint,
has_length,
load_sharded_checkpoint,
number_of_arguments,
rotate_checkpoints,
seed_worker,
set_seed,
sort_checkpoints,
speed_metrics,
suppress_progress_bars,
unwrap_peft_model,
validate_quantization_for_training,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
ADAPTER_CONFIG_NAME,
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
GENERATION_CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
XLA_FSDPV2_MIN_VERSION,
PushInProgress,
can_return_loss,
check_torch_load_is_safe,
find_labels,
is_accelerate_available,
is_datasets_available,
is_in_notebook,
is_peft_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_hpu_available,
is_torch_mlu_available,
is_torch_musa_available,
is_torch_npu_available,
is_torch_xla_available,
logging,
)
from .utils.import_utils import requires
from .utils.quantization_config import QuantizationMethod
# Callback classes every Trainer starts from; `Trainer.__init__` extends this list
# with the reporting-integration callbacks selected by `args.report_to`.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Progress callback added by `Trainer.__init__` unless `args.disable_tqdm` is set.
# Replaced by the notebook variant just below when running inside a notebook.
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from .utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
# Optional dependencies: each import below only runs when the matching
# availability check passes, so the names may be undefined otherwise.
if is_datasets_available():
import datasets
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
from torch_xla import __version__ as XLA_VERSION
# True when the installed torch_xla is at least XLA_FSDPV2_MIN_VERSION.
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
else:
# torch_xla is not available, so the flag is forced to False.
IS_XLA_FSDPV2_POST_2_2 = False
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_nested_concat
if is_peft_available():
from peft import PeftModel
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate.state import AcceleratorState
from accelerate.utils import (
DataLoaderConfiguration,
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
release_memory,
save_fsdp_model,
save_fsdp_optimizer,
)
from accelerate.utils.memory import clear_device_cache
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
# Imported for static type checking only; not available at runtime.
if TYPE_CHECKING:
import optuna
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCALER_NAME = "scaler.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
@requires(
backends=(
"torch",
"accelerate",
)
)
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
Args:
model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
<Tip>
[`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
models.
</Tip>
args ([`TrainingArguments`], *optional*):
The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
`output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
data_collator (`DataCollator`, *optional*):
The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
default to [`default_data_collator`] if no `processing_class` is provided, an instance of
[`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer.
train_dataset (`torch.utils.data.Dataset` | `torch.utils.data.IterableDataset` | `datasets.Dataset`, *optional*):
The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
sets the seed of the RNGs used.
eval_dataset (`torch.utils.data.Dataset` | dict[str, `torch.utils.data.Dataset`] | `datasets.Dataset`, *optional*):
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
dataset prepending the dictionary key to the metric name.
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
model_init (`Callable[[], PreTrainedModel]`, *optional*):
A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
from a new instance of the model as given by this function.
The function may have zero arguments, or a single one containing the optuna/Ray Tune trial object, to
be able to choose different architectures according to hyperparameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
compute_loss_func (`Callable`, *optional*):
A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`].
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
`True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
after the last eval batch to signal that the function needs to calculate and return the global summary
statistics rather than accumulating the batch-level statistics
callbacks (List of [`TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], dict[str, Any]]`, *optional*):
A tuple containing the optimizer class and keyword arguments to use.
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
A function that preprocesses the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by `compute_metrics`.
Note that the labels (second parameter) will be `None` if the dataset does not have them.
Important attributes:
- **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
subclass.
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
data parallelism, this means some of the model layers are split on different GPUs).
- **place_model_on_device** -- Whether or not to automatically place the model on the device. Defaults to
`True` unless model parallel, DeepSpeed, FSDP, full fp16/bf16 eval, or SageMaker MP is active. Can be
overridden by subclassing `TrainingArguments` and overriding the `place_model_on_device` property.
- **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
in `train`)
"""
# Those methods are not used in Trainer itself but are available as methods for external use.
from .trainer_pt_utils import (
get_learning_rates,
get_num_trainable_parameters,
get_optimizer_group,
log_metrics,
metrics_format,
save_metrics,
save_state,
)
# ---- Initialization & Validation ----
def __init__(
    self,
    model: PreTrainedModel | nn.Module | None = None,
    args: TrainingArguments | None = None,
    data_collator: DataCollator | None = None,
    train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
    eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
    processing_class: PreTrainedTokenizerBase
    | BaseImageProcessor
    | FeatureExtractionMixin
    | ProcessorMixin
    | None = None,
    model_init: Callable[..., PreTrainedModel] | None = None,
    compute_loss_func: Callable | None = None,
    compute_metrics: Callable[[EvalPrediction], dict] | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
    optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
    preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
):
    """
    Set up the trainer: resolve arguments and model, create the accelerator, store the user-supplied
    datasets/callables/optimizers, register callbacks and initialize the training state.

    The parameters are described in the class docstring.

    Raises:
        RuntimeError: If neither `model` nor `model_init` is given (further incompatible argument
            combinations are reported by `_validate_args`).
        ValueError: If both `model` and `model_init` are given, if the model class is a bare base model
            listed in `MODEL_MAPPING_NAMES`, if FSDP is combined with DeepSpeed or used outside distributed
            training, or if XLA FSDPv2 is requested with an unsupported `torch_xla` version.
    """
    # Init flow:
    #   1. Args & seed              – defaults, determinism
    #   2. Accelerator & logging    – accelerator, memory tracker, log level, device setup
    #   3. Model resolution         – model / model_init, Liger Kernel, quantization checks
    #   4. Distributed strategy     – model-parallel, FSDP, SageMaker MP flags
    #   5. Device placement         – move model to device, model wrapping
    #   6. Model introspection      – loss kwargs, label names, label smoother
    #   7. Store init arguments     – data, callables, optimizer, scheduler, validation
    #   8. Callbacks                – reporting integrations, JIT checkpoint, progress bar
    #   9. Hub & output             – repo init, output directory
    #  10. Training state           – TrainerState, TrainerControl, internal bookkeeping
    #  11. Finalize                 – use_cache, XLA FSDPv2 mesh, memory tracker stop

    # ---- 1. Args & seed --------------------------------------------------------
    if args is None:
        output_dir = "tmp_trainer"
        logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
        args = TrainingArguments(output_dir=output_dir)
    self.args = args
    # Seed must be set before instantiating the model when using model_init
    if self.args.full_determinism:
        enable_full_determinism(self.args.seed)
    else:
        set_seed(self.args.seed)

    # ---- 2. Accelerator & logging ----------------------------------------------
    # `create_accelerator_and_postprocess` reads self.model and self.args,
    # and may set self.deepspeed — store temporary refs before calling it.
    self.deepspeed = None
    self.model = model
    self.create_accelerator_and_postprocess()

    self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
    self._memory_tracker.start()

    log_level = args.get_process_log_level()
    logging.set_verbosity(log_level)

    args._setup_devices  # force device and distributed setup init explicitly

    # ---- 3. Model resolution ----------------------------------------------------
    if model is None:
        if model_init is not None:
            self.model_init = model_init
            model = self.call_model_init()
        else:
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
    else:
        if model_init is not None:
            raise ValueError("`Trainer` requires either a `model` or `model_init` argument, but not both.")
        self.model_init = model_init

    if model.__class__.__name__ in MODEL_MAPPING_NAMES:
        raise ValueError(
            f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
            "computes hidden states and does not accept any labels. You should choose a model with a head "
            "suitable for your task like any of the `AutoModelForXxx` listed at "
            "https://huggingface.co/docs/transformers/model_doc/auto"
        )

    validate_quantization_for_training(model)

    # ---- 4. Distributed strategy ------------------------------------------------
    self.is_model_parallel = False
    if getattr(model, "hf_device_map", None) is not None:
        # Only accelerator devices count; "cpu" and "disk" entries are ignored.
        devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
        if len(devices) > 1:
            self.is_model_parallel = True
        elif len(devices) == 1:
            self.is_model_parallel = self.args.device != torch.device(devices[0])

    self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
    if len(args.fsdp) > 0:
        if self.is_deepspeed_enabled:
            raise ValueError(
                "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
            )
        if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
            raise ValueError("Using fsdp only works in distributed training.")

    # Postpone switching model to cuda when MP, DeepSpeed, full bf16/fp16 eval, or FSDP
    if args.place_model_on_device is not None:
        self.place_model_on_device = args.place_model_on_device
    elif (
        self.is_model_parallel
        or self.is_deepspeed_enabled
        or (args.fp16_full_eval or args.bf16_full_eval)
        or self.is_fsdp_xla_enabled
        or self.is_fsdp_enabled
        or is_sagemaker_mp_enabled()
    ):
        self.place_model_on_device = False
    else:
        self.place_model_on_device = True

    # ---- 5. Device placement ----------------------------------------------------
    # Bnb Quantized models don't support `.to` operation.
    if (
        self.place_model_on_device
        and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
    ):
        self._move_model_to_device(model, args.device)

    # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
    if self.is_model_parallel:
        self.args._n_gpu = 1

    # `self.model is self.model_wrapped` is used later to check if it's wrapped
    self.model_wrapped = model
    self.model = model

    # ---- 6. Model introspection -------------------------------------------------
    unwrapped_model = unwrap_peft_model(self.accelerator.unwrap_model(model))
    if hasattr(unwrapped_model, "accepts_loss_kwargs"):
        self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
    else:
        # Fall back to checking whether `forward` declares a `**kwargs` parameter.
        forward_params = inspect.signature(unwrapped_model.forward).parameters
        self.model_accepts_loss_kwargs = any(
            k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
        )

    # Sequence Parallelism computes its own good_tokens count
    pc = getattr(self.accelerator, "parallelism_config", None)
    if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled:
        self.model_accepts_loss_kwargs = False

    model_to_inspect = unwrap_peft_model(self.model)
    default_label_names = find_labels(model_to_inspect.__class__)
    self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
    self.can_return_loss = can_return_loss(model_to_inspect.__class__)

    if self.args.label_smoothing_factor != 0:
        # A plain `nn.Module` may not expose `config`; guard the lookup the same way the
        # `use_cache` handling in step 11 does instead of raising `AttributeError`.
        model_config = getattr(self.model, "config", None)
        if getattr(model_config, "problem_type", None) == "multi_label_classification":
            warnings.warn(
                "Label smoothing is not compatible with multi-label classification. "
                "Disabling label smoothing for this training run.",
                UserWarning,
            )
            self.label_smoother = None
        else:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
    else:
        self.label_smoother = None

    # ---- 7. Store init arguments ------------------------------------------------
    # Data
    default_collator = (
        DataCollatorWithPadding(processing_class)
        if processing_class is not None
        and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
        else default_data_collator
    )
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.processing_class = processing_class
    self.neftune_noise_alpha = args.neftune_noise_alpha

    # Callables
    self.compute_loss_func = compute_loss_func
    self.compute_metrics = compute_metrics
    self.preprocess_logits_for_metrics = preprocess_logits_for_metrics

    # Optimizer & scheduler
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs

    self._validate_args()

    # ---- 8. Callbacks -----------------------------------------------------------
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    if self.args.enable_jit_checkpoint:
        from .trainer_jit_checkpoint import JITCheckpointCallback

        jit_callback = JITCheckpointCallback()
        default_callbacks = default_callbacks + [jit_callback]
        jit_callback.set_trainer(self)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

    # ---- 9. Hub & output ---------------------------------------------------------
    self.hub_model_id = None  # Set by init_hf_repo() when push_to_hub is enabled
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)

    # ---- 10. Training state -----------------------------------------------------
    self.control = TrainerControl()
    self.state = TrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.is_in_train = False  # True between train() entry and exit
    self.hp_name = None  # Set by hyperparameter_search() to label the trial
    self.hp_search_backend = None  # Set by hyperparameter_search() (optuna / ray / wandb)
    # Per-process FLOP counter; accumulated into self.state.total_flos then reset
    self.current_flos = 0
    # Set True by _setup_loggers() on first call to self.log()
    self._loggers_initialized = False
    # Lazily filled by _set_signature_columns_if_needed(); caches model.forward param names
    self._signature_columns = None
    # Effective batch size; may be reduced by find_executable_batch_size
    self._train_batch_size = args.train_batch_size
    # Guards one-time LR scheduler creation in create_optimizer_and_scheduler
    self._created_lr_scheduler = False

    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    # ---- 11. Finalize -----------------------------------------------------------
    if getattr(self.model, "config", None) is not None:
        self.model.config.use_cache = self.args.use_cache

    self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
    if self.is_fsdp_xla_v2_enabled:
        if not IS_XLA_FSDPV2_POST_2_2:
            raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
    self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled

    self._memory_tracker.stop_and_update_metrics()
def _validate_args(self) -> None:
    """
    Validate constructor arguments and fail fast on incompatible combinations.

    Reads `self.args`, `self.compute_metrics`, `self.eval_dataset`, `self.train_dataset`, `self.data_collator`,
    `self.model`, `self.model_init`, `self.optimizer`, `self.lr_scheduler` and `self.optimizer_cls_and_kwargs`,
    so it must be called after `__init__` has stored them. The only mutation it performs is aligning
    `args.fp16` with the SageMaker Model Parallel configuration when the two disagree.

    Raises:
        ValueError: For unsupported SageMaker MP precision, a `compute_metrics` signature incompatible with
            `batch_eval_metrics`, a missing `eval_dataset` or `metric_for_best_model`, model and optimizer
            parameters living on different devices under XLA, or a length-less `train_dataset` without
            `max_steps`.
        RuntimeError: When `optimizers` is combined with `optimizer_cls_and_kwargs`, `model_init` or FSDP.
        TypeError: When `data_collator` is not callable but exposes a callable `collate_batch`.
    """
    args = self.args

    # --- SageMaker Model Parallel mixed-precision validation ---
    if is_sagemaker_mp_enabled():
        if args.bf16:
            raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
        if args.fp16 != smp.state.cfg.fp16:
            logger.warning(
                f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                f"but FP16 provided in trainer argument is {args.fp16}, "
                f"setting to {smp.state.cfg.fp16}"
            )
            # The SageMaker configuration wins over the trainer argument.
            args.fp16 = smp.state.cfg.fp16

    # --- Training-argument validations ---
    if args.batch_eval_metrics and self.compute_metrics is not None:
        if "compute_result" not in inspect.signature(self.compute_metrics).parameters:
            raise ValueError(
                "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
                " boolean argument which will be triggered after the last batch of the eval set to signal that the"
                " summary statistics should be returned by the function."
            )
    if args.eval_strategy is not None and args.eval_strategy != "no" and self.eval_dataset is None:
        raise ValueError(
            f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
        )
    if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
        if args.metric_for_best_model is None:
            raise ValueError(
                "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
            )

    # --- Optimizer validations ---
    if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
        raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
    if self.model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
        raise RuntimeError(
            "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    if is_torch_xla_available() and self.optimizer is not None:
        # Device of the first model parameter / first optimizer parameter, or None when there is
        # none to inspect (parameter-less model, optimizer whose param groups are all empty).
        model_device = next((param.device for param in self.model.parameters()), None)
        optimizer_device = next(
            (
                param_group["params"][0].device
                for param_group in self.optimizer.param_groups
                if len(param_group["params"]) > 0
            ),
            None,
        )
        # Only compare when both devices are known; otherwise there is nothing to validate.
        if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
            raise ValueError(
                "The model and the optimizer parameters are not on the same device, which probably means you"
                " created an optimizer around your model **before** putting on the device and passing it to the"
                " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
            )
    if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
        self.optimizer is not None or self.lr_scheduler is not None
    ):
        raise RuntimeError(
            "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )

    # --- Dataset validations ---
    # Rejects collator objects that only expose `collate_batch` and are not callable themselves.
    if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
        raise TypeError("The `data_collator` should be a simple callable (function, class with `__call__`).")
    if args.max_steps > 0 and args.num_train_epochs > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")
    if self.train_dataset is not None and not has_length(self.train_dataset) and args.max_steps <= 0:
        raise ValueError(
            "The train_dataset does not implement __len__, max_steps has to be specified. "
            "The number of steps needs to be known in advance for the learning rate scheduler."
        )
    if self.train_dataset is not None and isinstance(self.train_dataset, torch.utils.data.IterableDataset):
        logger.info(
            f"The `train_sampling_strategy='{args.train_sampling_strategy}'` option is ignored when using an `IterableDataset`. "
            "Samplers cannot be used with IterableDataset as they require indexed access to the dataset."
        )
def _build_accelerator_args(self, **kwargs) -> dict[str, Any]:
    """
    Assemble the keyword arguments used to construct the `Accelerator`.

    Starts from the mixed-precision and DeepSpeed settings in `self.args`, merges in the caller-supplied
    `kwargs`, then adds the DDP kwargs handler, the parallelism config (either the one given in the training
    arguments or one derived from the model's `tp_size`) and, when the installed accelerate is recent enough,
    the TorchDynamo plugin.

    Args:
        **kwargs:
            Extra entries merged into the result. On key collision they replace the `mixed_precision` /
            `deepspeed_plugin` defaults; the `kwargs_handlers`, `parallelism_config` and `dynamo_plugin` keys
            set further down are written after the merge.

    Returns:
        `dict[str, Any]`: the keyword arguments to pass to `Accelerator`.

    Raises:
        ImportError: If `args.parallelism_config` is set but accelerate is older than 1.12.0.
        ValueError: If the model reports `tp_size > 1`, no parallelism config was given, and accelerate is
            older than 1.12.0.
    """
    # Single source of truth for the accelerate version that introduced `ParallelismConfig`.
    min_parallelism_accelerate_version = "1.12.0"

    args = {
        "mixed_precision": self.args.mixed_precision,
        "deepspeed_plugin": self.args.deepspeed_plugin,
    }
    args.update(kwargs)

    # An explicit user setting always wins; otherwise pick a default based on the model type.
    if self.args.ddp_find_unused_parameters is not None:
        find_unused = self.args.ddp_find_unused_parameters
    elif isinstance(self.model, PreTrainedModel):
        # find_unused_parameters breaks checkpointing as per
        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
        find_unused = not (self.model.is_gradient_checkpointing or self.args.gradient_checkpointing)
    else:
        find_unused = True

    # Only forward the optional DDP settings that were explicitly configured.
    ddp_kwargs = {"find_unused_parameters": find_unused}
    if self.args.ddp_bucket_cap_mb is not None:
        ddp_kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
    if self.args.ddp_broadcast_buffers is not None:
        ddp_kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
    if self.args.ddp_static_graph is not None:
        ddp_kwargs["static_graph"] = self.args.ddp_static_graph
    args["kwargs_handlers"] = [DistributedDataParallelKwargs(**ddp_kwargs)]

    # We defer compatibility checks to accelerator
    if self.args.parallelism_config is not None:
        if not is_accelerate_available(min_parallelism_accelerate_version):
            raise ImportError(
                f"ParallelismConfig requires accelerate>={min_parallelism_accelerate_version}. Please upgrade accelerate to use this feature."
            )
        args["parallelism_config"] = self.args.parallelism_config

    model_tp_size = getattr(self.model, "tp_size", None)
    if model_tp_size is not None and model_tp_size > 1:
        if self.args.parallelism_config is None:
            # No config was supplied: derive one from the model's tensor-parallel size.
            if not is_accelerate_available(min_parallelism_accelerate_version):
                raise ValueError(
                    f"Requires accelerate>={min_parallelism_accelerate_version} to use Tensor Parallelism."
                )
            from accelerate import ParallelismConfig

            args["parallelism_config"] = ParallelismConfig(tp_size=model_tp_size)
        elif args["parallelism_config"].tp_size != model_tp_size:
            # The model's tensor-parallel size takes precedence over the value in the supplied config.
            args["parallelism_config"].tp_size = model_tp_size

    if is_accelerate_available("1.2.0"):
        # If we don't have the correct version, we rely instead on the env vars that were set in TrainingArguments
        from accelerate.utils import TorchDynamoPlugin

        dynamo_plugin = TorchDynamoPlugin(
            backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode
        )
        args["dynamo_plugin"] = dynamo_plugin

    return args
def create_accelerator_and_postprocess(self) -> None:
    """
    Create the accelerator and perform post-creation setup (FSDP, DeepSpeed, etc.).

    Builds the gradient-accumulation plugin, the dataloader configuration and the optional FSDP plugin from
    `self.args`, instantiates `self.accelerator`, then sets `self.gather_function`, `self.is_deepspeed_enabled`
    and `self.is_fsdp_enabled` and rejects option combinations that are not supported.

    Raises:
        ValueError: If `num_steps` is given in the accelerator config's `gradient_accumulation_kwargs` while
            `gradient_accumulation_steps` is greater than 1; if FSDP activation checkpointing and
            `gradient_checkpointing` are both enabled; if `save_only_model` is combined with
            `load_best_model_at_end` under DeepSpeed or FSDP; if `auto_find_batch_size` is used with DeepSpeed
            ZeRO stage 3; or if `save_only_model` is used with an FSDP `SHARDED_STATE_DICT` state dict type.
    """
    # We explicitly don't rely on the `Accelerator` to do gradient accumulation
    grad_acc_kwargs = {}
    if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
        grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

    # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
    if "num_steps" in grad_acc_kwargs:
        if self.args.gradient_accumulation_steps > 1:
            # raise because we do not know which setting is intended.
            raise ValueError(
                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
            )
        else:
            self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

    # The Trainer handles GAS itself, so GAS=1 in Accelerate to avoid any double-division
    # NOTE(review): when the config provided `gradient_accumulation_kwargs`, `grad_acc_kwargs` is that same
    # object (no copy is made above), so this presumably also updates the config in place - confirm intended.
    grad_acc_kwargs["num_steps"] = 1

    # Just making sure that gradient_state have the correct values passed.
    # We don't rely on `accumulate` from accelerate to set sync_gradients in gradient_state.
    # Rather, we do it ourselves by setting self.accelerator.gradient_state._set_sync_gradients.
    gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

    accelerator_config = self.args.accelerator_config.to_dict()

    # Extract dataloader config params from accelerator config
    dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
    dataloader_config = DataLoaderConfiguration(
        **{param: accelerator_config.pop(param) for param in dataloader_params}
    )
    dataloader_config.data_seed = self.args.data_seed

    non_blocking = accelerator_config.pop("non_blocking")
    if non_blocking and not self.args.dataloader_pin_memory:
        logger.warning(
            "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
        )
    dataloader_config.non_blocking = non_blocking

    # this would have been updated above, no need for it anymore
    accelerator_config.pop("gradient_accumulation_kwargs")

    # Only build an FSDP plugin when the training arguments carry FSDP plugin settings.
    fsdp_plugin = None
    if self.args.fsdp_plugin_args is not None:
        from accelerate.utils import FullyShardedDataParallelPlugin

        fsdp_plugin = FullyShardedDataParallelPlugin(**self.args.fsdp_plugin_args)

    args = self._build_accelerator_args(
        dataloader_config=dataloader_config,
        fsdp_plugin=fsdp_plugin,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )

    # create accelerator object
    self.accelerator = Accelerator(**args)

    # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
    self.gather_function = self.accelerator.gather_for_metrics

    # Bind `use_gather_object` only when the gather function actually accepts that parameter.
    if "use_gather_object" in inspect.signature(self.gather_function).parameters:
        self.gather_function = functools.partial(
            self.gather_function, use_gather_object=self.args.eval_use_gather_object
        )

    # deepspeed and accelerate flags covering both trainer args and accelerate launcher
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

    # post accelerator creation setup
    if self.is_fsdp_enabled:
        fsdp_plugin = self.accelerator.state.fsdp_plugin
        # Values from `args.fsdp_config` override the plugin's own; missing keys keep the plugin value.
        for param in ["limit_all_gathers", "activation_checkpointing"]:
            setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
        if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
            raise ValueError(
                "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                "when using FSDP."
            )

    if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
        propagate_args_to_deepspeed(self.accelerator, self.args)

    # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
    if (
        self.args.save_only_model
        and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
        and self.args.load_best_model_at_end
    ):
        wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
        raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")

    # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
    if (
        self.is_deepspeed_enabled
        and self.accelerator.state.deepspeed_plugin.zero_stage == 3
        and self.args.auto_find_batch_size
    ):
        raise ValueError(
            "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
        )

    # `save_only_model` is rejected when the FSDP state dict type is sharded.
    if (
        self.args.save_only_model
        and self.is_fsdp_enabled
        and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
    ):
        raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")
# ---- Data Loading ----
def get_train_dataloader(self) -> DataLoader:
    """
    Build and return the [`~torch.utils.data.DataLoader`] used for training.

    The dataloader is created by `self._get_dataloader` from `self.train_dataset`, using
    `self._train_batch_size` as batch size and `self._get_train_sampler` as the sampler factory, with
    `is_training=True`.

    Subclass and override this method if you want to inject some custom behavior.

    Raises:
        ValueError: If `self.train_dataset` is `None`.
    """
    train_dataset = self.train_dataset
    if train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    loader_kwargs = {
        "dataset": train_dataset,
        "description": "Training",
        "batch_size": self._train_batch_size,
        "sampler_fn": self._get_train_sampler,
        "is_training": True,
    }
    return self._get_dataloader(**loader_kwargs)
def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
# If we have persistent workers, don't do a fork bomb especially as eval datasets
# don't change during training
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
if (
hasattr(self, "_eval_dataloaders")
and dataloader_key in self._eval_dataloaders
and self.args.dataloader_persistent_workers
):
return self._eval_dataloaders[dataloader_key]
eval_dataset = (
self.eval_dataset[eval_dataset]
if isinstance(eval_dataset, str)
else eval_dataset
if eval_dataset is not None
else self.eval_dataset
)
return self._get_dataloader(
dataset=eval_dataset,
description="Evaluation",
batch_size=self.args.eval_batch_size,
sampler_fn=self._get_eval_sampler,
dataloader_key=dataloader_key,
)
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    """
    Build and return the [`~torch.utils.data.DataLoader`] used for testing.

    The dataloader is created by `self._get_dataloader` with `self.args.eval_batch_size` as batch size and
    `self._get_eval_sampler` as the sampler factory.

    Subclass and override this method if you want to inject some custom behavior.

    Args:
        test_dataset (`torch.utils.data.Dataset`):
            The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed. It must implement `__len__`.
    """
    loader_kwargs = {
        "dataset": test_dataset,
        "description": "test",
        "batch_size": self.args.eval_batch_size,
        "sampler_fn": self._get_eval_sampler,
    }
    return self._get_dataloader(**loader_kwargs)
def num_examples(self, dataloader: DataLoader) -> int:
    """
    Return the number of samples behind a [`~torch.utils.data.DataLoader`].

    The count is read from `dataloader.dataset` (or from the dataset wrapped by an `IterableDatasetShard`).
    If that lookup fails because the dataset is missing or has no length, the count is estimated as
    `len(dataloader) * self.args.per_device_train_batch_size`.
    """
    try:
        source = dataloader.dataset
        # Special case for IterableDatasetShard, we need to dig deeper
        measured = source.dataset if isinstance(source, IterableDatasetShard) else source
        return len(measured)
    except (NameError, AttributeError, TypeError):
        # no dataset or length, estimate by length of dataloader
        return len(dataloader) * self.args.per_device_train_batch_size
def _get_dataloader(
self,
dataset: Dataset,
description: str,
batch_size: int,
sampler_fn: Callable[[Dataset], torch.utils.data.Sampler] | None = None,
is_training: bool = False,
dataloader_key: str | None = None,
) -> DataLoader:
"""Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
data_collator = self.data_collator
if is_datasets_available() and isinstance(dataset, datasets.Dataset):
dataset = self._remove_unused_columns(dataset, description=description)
else:
data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)
# MPS requires forking if multiple workers are specified
should_fork = torch.backends.mps.is_available() and self.args.dataloader_num_workers > 1
dataloader_params = {
"batch_size": batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"multiprocessing_context": "fork" if should_fork else None,
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
if sampler_fn is not None:
dataloader_params["sampler"] = sampler_fn(dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if is_training:
dataloader_params["worker_init_fn"] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
)
dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
# Store the prepared dataloader for subsequent evaluations if using persistent workers.
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader