[BUG] impossible to set up optimizer through config file or dict #2361

@vince62s

Description

Describe the bug
Unable to set up the optimizer through the DeepSpeed config (file or dict): calling deepspeed.initialize() with optimizer=None fails even though the config defines an optimizer.
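
The title says "config file or dict" because the failure looks the same whichever form the config takes. For the file variant, here is a minimal sketch, assuming a local ds_config.json containing the same "optimizer"/"fp16"/"zero_optimization" keys as the dict used in the reproduction below (the file name is only illustrative):

import torch
import deepspeed

# Hypothetical file-based equivalent of the dict config from the reproduction:
# deepspeed.initialize() takes either a dict or a path to a JSON file through
# its config argument.
model = torch.nn.Linear(10, 10)
model, optimizer, _, _ = deepspeed.initialize(
    args=None,
    model=model,
    config="ds_config.json",  # assumed path; same keys as the dict below
)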

To Reproduce

import torch
import deepspeed

def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset
    
def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader

class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim,
                             hidden_dim) for i in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)
        
def test_none_args():
    config = {
        "train_batch_size": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
        "stage": 1
        }
    }

    def _helper():
        model = SimpleModel(hidden_dim=10)
        # The optimizer is defined only in the config; neither an optimizer
        # object nor model_parameters is passed to initialize().
        model, optimizer, _, _ = deepspeed.initialize(args=None, model=model, config=config)
        data_loader = random_dataloader(model=model,
                                        total_samples=5,
                                        hidden_dim=10,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])

    _helper()
    
test_none_args()

Running this script fails with:

Traceback (most recent call last):
  File "mini_ds.py", line 66, in <module>
    test_none_args()        
  File "mini_ds.py", line 64, in test_none_args
    _helper()
  File "mini_ds.py", line 56, in _helper
    model, optimizer, _, _ = deepspeed.initialize(args=None, model=model, config=config)
  File "/home//.local/lib/python3.8/site-packages/deepspeed/__init__.py", line 124, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/home/.local/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 325, in __init__
    self.optimizer = self._configure_zero_optimizer(optimizer=None)
  File "/home//.local/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1380, in _configure_zero_optimizer
    assert not isinstance(optimizer, DummyOptim), "zero stage 2 requires an optimizer"
AssertionError: zero stage 2 requires an optimizer
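
The assertion seems to fire because neither an optimizer object nor model_parameters is passed to initialize(), so the engine falls back to a DummyOptim before configuring ZeRO. A possible workaround sketch, assuming the intent is to keep the optimizer definition in the config: hand the model's parameters to deepspeed.initialize() so the engine can construct the Adam optimizer described there.

# Workaround sketch (assumption, not a confirmed fix): supplying model_parameters
# lets the engine build the config-defined Adam optimizer before ZeRO stage 1
# is set up.
model = SimpleModel(hidden_dim=10)
model, optimizer, _, _ = deepspeed.initialize(
    args=None,
    model=model,
    model_parameters=model.parameters(),  # without this, ZeRO sees a DummyOptim
    config=config,
)

Even with that workaround, the message itself is confusing for this config: it complains about "zero stage 2" while the config requests stage 1.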

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home//.local/lib/python3.8/site-packages/torch']
torch version .................... 1.12.1+cu113
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed install path ........... ['/home//.local/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.7.3, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.12, cuda 10.2

Metadata

Labels

bug (Something isn't working), training
