
[BUG] Regression in 0.7.1 for optimizer nvme offload - expected input to be on cuda #2263

@timohear

Describe the bug
Finetuning EleutherAI/gpt-neox-20b using HF Transformers 4.21.1 works correctly with DeepSpeed 0.7.0 but crashes with 0.7.1

The finetuning is lauched using run_clm on 2 RTX a6000 GPUs. Stage 3 nvme optimizer offload is used

The error occurs immediately after the PartitionedOptimizerSwapper setup; a simplified version of the trace is:

  File "/deepspeed/runtime/zero/stage3.py", line 363, in _setup_for_real_optimizer
    self.initialize_optimizer_states()
  File "/deepspeed/runtime/zero/stage3.py", line 923, in initialize_optimizer_states
    self._optimizer_step(i)
File "/deepspeed/runtime/zero/stage3.py", line 843, in _optimizer_step
    self.optimizer.step()
  File "/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/deepspeed/ops/adam/fused_adam.py", line 169, in step
    multi_tensor_applier(self.multi_tensor_adam,
  File "deepspeed/ops/adam/multi_tensor_apply.py", line 14, in __call__
    return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
RuntimeError: expected input to be on cuda

The error is raised in fused_adam, yet AdamW (CPU) is specified in the config file, so it is surprising that execution ends up there at all. Since FusedAdam is GPU-only, that would also explain the error.
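For context, the config combines an AdamW optimizer with stage-3 NVMe optimizer offload, roughly along the following lines (an illustrative sketch shown as a Python dict, not the exact contents of the attached ds_config.json; the lr value is a placeholder, while nvme_path, pin_memory and buffer_count match the PartitionedOptimizerSwapper log below):

    # Illustrative excerpt only -- approximates the optimizer/offload sections of ds_config.json
    ds_config_excerpt = {
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},  # placeholder lr
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "nvme",
                "nvme_path": "/media/tim/DeepSpeed/local_nvme",
                "pin_memory": True,
                "buffer_count": 4,
            },
        },
    }

With a config like this, the expectation is that DeepSpeed uses a CPU-capable Adam implementation for the offloaded optimizer states rather than the GPU-only FusedAdam.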

Full Error log:

[2022-08-25 09:58:59,879] [INFO] [utils.py:28:print_object] PartitionedOptimizerSwapper:
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   aligned_bytes ................ 1024
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   dtype ........................ torch.float32
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   largest_numel ................ 1060577280
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   min_aio_bytes ................ 1048576
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   numel_alignment .............. 256
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   swap_config .................. device='nvme' nvme_path=PosixPath('/media/tim/DeepSpeed/local_nvme') buffer_count=4 pin_memory=True pipeline=False pipeline_read=False pipeline_write=False fast_init=False
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   swap_element_size ............ 4
[2022-08-25 09:58:59,880] [INFO] [utils.py:32:print_object]   swap_folder .................. /media/tim/DeepSpeed/local_nvme/zero_stage_3/optimizer/rank0
Traceback (most recent call last):
  File "./run_clm.py", line 579, in <module>
    main()
  File "./run_clm.py", line 527, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/transformers/trainer.py", line 1498, in train
    return inner_training_loop(
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/transformers/trainer.py", line 1567, in _inner_training_loop
    deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/transformers/deepspeed.py", line 344, in deepspeed_init
    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/__init__.py", line 124, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 320, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1136, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1439, in _configure_zero_optimizer
    optimizer = DeepSpeedZeroOptimizer_Stage3(
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 308, in __init__
    self._setup_for_real_optimizer()
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 363, in _setup_for_real_optimizer
    self.initialize_optimizer_states()
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 923, in initialize_optimizer_states
    self._optimizer_step(i)
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 843, in _optimizer_step
    self.optimizer.step()
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/ops/adam/fused_adam.py", line 169, in step
    multi_tensor_applier(self.multi_tensor_adam,
  File "/home/tim/Documents/Experiments/.env/lib/python3.8/site-packages/deepspeed/ops/adam/multi_tensor_apply.py", line 14, in __call__
    return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
RuntimeError: expected input to be on cuda

To Reproduce

  1. Use the attached DeepSpeed config file: ds_config.zip
  2. Launch run_clm from Transformers:
    CUDA_VISIBLE_DEVICES=1,2 python -u -m torch.distributed.launch --nproc_per_node=2 ./run_clm.py --do_train --model_name_or_path EleutherAI/gpt-neox-20b --overwrite_output_dir --train_file data/train.txt --output_dir models/EleutherAI/gpt-neox-20b/ --gradient_accumulation_steps 8 --per_device_train_batch_size 1 --num_train_epochs 1 --bf16 --deepspeed ds_config.json
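A hypothetical, untested minimal repro sketch that skips run_clm and gpt-neox-20b entirely: since the traceback shows the crash happens inside initialize_optimizer_states() during deepspeed.initialize(), a tiny model with the same stage-3 NVMe offload settings should hit the same code path (assuming the regression does not depend on the model):

    # repro.py -- hypothetical sketch, launch with e.g.: deepspeed --num_gpus=1 repro.py
    import torch
    import deepspeed

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 8,
        "bf16": {"enabled": True},
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},  # placeholder lr
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "nvme",
                "nvme_path": "/media/tim/DeepSpeed/local_nvme",  # any local NVMe mount
                "pin_memory": True,
            },
        },
    }

    model = torch.nn.Linear(4096, 4096)  # tiny stand-in for gpt-neox-20b

    # With deepspeed==0.7.1 this is expected to raise
    # "RuntimeError: expected input to be on cuda" during initialization;
    # with deepspeed==0.7.0 it should complete and print the optimizer type.
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )
    print(type(engine.optimizer))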

System info:

  • OS: Ubuntu 22.04
  • 1 machine with 2x RTX A6000 (no NVLink)
  • Python 3.8
  • HuggingFace Transformers 4.21.1
  • CUDA 11.6
  • PyTorch 1.12 (cu116)
