CUDA OOM at self.optimizer.consolidate_state_dict() in Trainer when using sharded_ddp  #14542

@yana-xuyan

Environment info

  • transformers version: 4.12.3
  • Platform: Linux-5.4.0-1057-aws-x86_64-with-debian-buster-sid
  • Python version: 3.7.10
  • PyTorch version (GPU?): 1.7.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: 8 GPUs
  • Using distributed or parallel set-up in script?: sharded_ddp (fairscale 0.4.2)

Who can help

@sgugger

Information

Model I am using (Bert, XLNet ...): BART-base

The problem arises when using:

The task I am working on is:

  • my own task or dataset: (give details below)
  • I'm using the Wikipedia corpus.

To reproduce

Steps to reproduce the behavior:

  1. Use the script run_mlm.py (https://github.com/huggingface/transformers/blob/v4.12.3/examples/pytorch/language-modeling/run_mlm.py).
  2. Run the script with the following command line:
python -m torch.distributed.launch --nproc_per_node=8 --master_port=10000 run_mlm.py \
    --model_name_or_path roberta-base \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --cache_dir /tmp/test-mlm \
    --output_dir /tmp/test-mlm \
    --sharded_ddp simple \
    --overwrite_output_dir \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 4

Traceback (most recent call last):
File "run_mlm.py", line 538, in <module>
main()
File "run_mlm.py", line 487, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/transformers/trainer.py", line 1383, in train
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/transformers/trainer.py", line 1495, in _maybe_log_save_evaluate
self._save_checkpoint(model, trial, metrics=metrics)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/transformers/trainer.py", line 1565, in _save_checkpoint
self.optimizer.consolidate_state_dict()
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/fairscale/optim/oss.py", line 358, in consolidate_state_dict
obj_list, src=self._local_to_global_rank[rank], group=self.group,
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1403, in broadcast_object_list
object_list[i] = _tensor_to_object(obj_view, obj_size)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1187, in _tensor_to_object
out = pickle.loads(buf)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/storage.py", line 141, in _load_from_bytes
return torch.load(io.BytesIO(b))
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/serialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/serialization.py", line 774, in _legacy_load
result = unpickler.load()
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/serialization.py", line 730, in persistent_load
deserialized_objects[root_key] = restore_location(obj, location)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
result = fn(storage, location)
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/serialization.py", line 155, in _cuda_deserialize
return storage_type(obj.size())
File "/home/ubuntu/anaconda3/envs/pytorch_p37/lib/python3.7/site-packages/torch/cuda/__init__.py", line 462, in _lazy_new
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
RuntimeError: CUDA error: out of memory
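
For a rough sense of scale, the amount of optimizer state that passes through consolidate_state_dict() can be estimated with simple arithmetic. The sketch below is not a measurement from this run. It assumes roughly 125 million parameters for roberta-base, an Adam-style optimizer that keeps two fp32 buffers per parameter, and shards that are evenly balanced across the 8 processes. The helper name adam_state_bytes is hypothetical and exists only for this illustration.

```python
# Rough estimate of the optimizer state gathered by consolidate_state_dict().
# Every input here is an assumption, not a value measured in the run above.

def adam_state_bytes(num_params, bytes_per_value=4, buffers_per_param=2):
    """Bytes held by Adam-style per-parameter buffers.

    Assumes two fp32 buffers (first and second moment) per parameter.
    """
    return num_params * bytes_per_value * buffers_per_param


NUM_PARAMS = 125_000_000  # assumed: approximate parameter count of roberta-base
WORLD_SIZE = 8            # taken from --nproc_per_node=8

full_state = adam_state_bytes(NUM_PARAMS)
per_rank_shard = full_state // WORLD_SIZE  # assumes evenly balanced shards

GIB = 1024 ** 3
print(f"full optimizer state: {full_state / GIB:.2f} GiB")
print(f"one rank's shard:     {per_rank_shard / GIB:.2f} GiB")
```

If these assumptions hold, the consolidated state is on the order of 1 GB in total. The traceback shows the received objects being deserialized through _cuda_deserialize, so that memory would be requested on a GPU rather than on the CPU, which may matter when the device is already close to full.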

Expected behavior

Could you please tell me how to fix this issue?
