Status: Closed
Labels: bug (Something isn't working)
Description
Describe the bug
I pass a tensor with shape (1024,) and dtype float16 to initialize DeepSpeedCPUAdam. When I call step(), the program aborts.
If the tensor's shape is (1024, 1024) instead, or its dtype is float32, the program works fine.
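Because the abort kills the Python interpreter, the shape/dtype comparison above is easiest to confirm by running each configuration in its own child process and checking its exit status. This is a minimal sketch that assumes POSIX semantics: a negative returncode from subprocess.run means the child was killed by that signal, so an abort appears as SIGABRT. The helper names CHILD, describe, run_case and CASES are hypothetical. The child script is the reproduction from this report, parameterized by shape and dtype.

```python
import signal
import subprocess
import sys

# Hypothetical harness. The child script is the repro from this issue,
# parameterized by shape (e.g. "1024" or "1024x1024") and dtype name.
# Running it in a separate interpreter means an abort inside the C++
# extension kills only the child, not this script.
CHILD = """
import sys
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# With `python -c CODE a b`, sys.argv is ['-c', 'a', 'b'].
shape = tuple(int(d) for d in sys.argv[1].split("x"))
dtype = getattr(torch, sys.argv[2])
param = torch.nn.Parameter(torch.randn(shape).to(dtype))
optimizer = DeepSpeedCPUAdam([param])
param.grad = torch.randn(shape).to(dtype)
optimizer.step()
"""


def describe(returncode):
    """Turn a child return code (or None on timeout) into a short status."""
    if returncode is None:
        return "timed out"
    if returncode == 0:
        return "ok"
    if returncode < 0:
        # POSIX: a return code of -N means the child was killed by signal N.
        try:
            return f"killed by {signal.Signals(-returncode).name}"
        except ValueError:
            return f"killed by signal {-returncode}"
    return f"exited with code {returncode}"


def run_case(shape, dtype, timeout=600):
    """Run one shape/dtype combination in a fresh interpreter."""
    arg = "x".join(str(d) for d in shape)
    try:
        proc = subprocess.run(
            [sys.executable, "-c", CHILD, arg, dtype],
            capture_output=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return None
    return proc.returncode


# The three configurations described in this report.
CASES = [((1024,), "float16"), ((1024, 1024), "float16"), ((1024,), "float32")]

if __name__ == "__main__":
    for shape, dtype in CASES:
        print(shape, dtype, describe(run_case(shape, dtype)))
```

Run it in the same environment where DeepSpeed is installed. If the behavior in this report reproduces, only the (1024,) float16 case should report a signal.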
To Reproduce
Run the code below:
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam
N = 1024
device = torch.device("cpu")
# 1-D float16 parameter of shape (1024,): the combination that aborts
tmp = torch.randn(N, device=device).half()
param_bias = torch.nn.Parameter(tmp)
param = [param_bias]
optimizer = DeepSpeedCPUAdam(param)
# float16 gradient with the same shape as the parameter
param_bias.grad = torch.randn(N, device=device).half()
optimizer.step()  # the program aborts here
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [YES] ...... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
[WARNING] please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['xxx']
torch version .................... 1.10.0+cu111
torch cuda version ............... 11.1
torch hip version ................ None
nvcc version ..................... 11.1
deepspeed install path ........... ['xxx']
deepspeed info ................... 0.6.6+3da84185, 3da84185, master
deepspeed wheel compiled w. ...... torch 1.10, cuda 11.1
System info:
- OS: Ubuntu 18.04
- CPU: Intel Xeon E5-2620 v4 @ 16x 3GHz
- Python version: 3.8