Description
Describe the bug
The following runtime error occurs when enabling the fp16 config for tasks with FloatTensor inputs, such as image classification or speech recognition:
Traceback (most recent call last):
File "examples/cv_example.py", line 213, in <module>
main()
File "examples/cv_example.py", line 209, in main
training_function(config, args)
File "examples/cv_example.py", line 162, in training_function
outputs = model(inputs)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sourab/test/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
return func(*args, **kwargs)
File "/home/sourab/test/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1616, in forward
loss = self.module(*inputs, **kwargs)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sourab/test/lib/python3.8/site-packages/timm/models/resnet.py", line 685, in forward
x = self.forward_features(x)
File "/home/sourab/test/lib/python3.8/site-packages/timm/models/resnet.py", line 673, in forward_features
x = self.conv1(x)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/sourab/test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

This happens because self.module.half() is called in _configure_distributed_model when the DeepSpeedEngine object is initialized. For language models this caused no issues, because the inputs are LongTensors fed to embedding layers. For other modalities, where the input is a FloatTensor, it leads to the error above. Please handle this scenario in the forward call of the DeepSpeedEngine module object. Converting the inputs with inputs.half() fixes the issue, but I feel this check and casting logic should be part of the forward call of the DeepSpeedEngine module object.
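As a sketch of what such an engine-side check could look like (a hypothetical helper, not existing DeepSpeed API): recursively cast only floating-point tensors to fp16, leaving integer tensors such as token ids untouched so language-model inputs keep working.

```python
import torch

def cast_float_inputs(obj, dtype=torch.half):
    """Recursively cast floating-point tensors in obj to `dtype`.

    Integer tensors (e.g. LongTensor token ids for embedding layers)
    are left as-is, so the existing language-model path is unaffected.
    """
    if torch.is_tensor(obj):
        return obj.to(dtype) if torch.is_floating_point(obj) else obj
    if isinstance(obj, (list, tuple)):
        return type(obj)(cast_float_inputs(o, dtype) for o in obj)
    if isinstance(obj, dict):
        return {k: cast_float_inputs(v, dtype) for k, v in obj.items()}
    return obj
```

DeepSpeedEngine.forward could apply such a cast to its inputs when fp16 is enabled, mirroring what self.module.half() already does to the weights.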
To Reproduce
Steps to reproduce the behavior:
- Use the official HF Accelerate cv_example.py script.
- Set up DeepSpeed ZeRO-2 via the accelerate config command. The output config yaml:
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false

The resulting DeepSpeed config being:
{
"train_batch_size": 128,
"train_micro_batch_size_per_gpu": 64,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none"
},
"offload_param": {
"device": "none"
},
"stage3_gather_16bit_weights_on_model_save": false
},
"steps_per_print": inf,
"fp16": {
"enabled": true
},
"zero_allow_untested_optimizer": true
}

- Run the script with the below command, pointing to a folder with images (information on downloading images here). The output shows the error stated above.
accelerate launch examples/cv_example.py --mixed_precision fp16 --data_dir /path/to/images/folder

Expected behavior
No errors.
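For reference, the dtype mismatch can be reproduced in isolation, without DeepSpeed or a GPU. A minimal CPU sketch that mimics what self.module.half() does to the weights:

```python
import torch
import torch.nn as nn

# Half-precision weights, as after self.module.half() in DeepSpeedEngine:
conv = nn.Conv2d(3, 8, kernel_size=3).half()
images = torch.randn(1, 3, 32, 32)  # FloatTensor batch, as in the CV example

try:
    conv(images)  # input/weight dtype mismatch raises RuntimeError
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Casting the inputs to the weight dtype (the workaround above) makes them match:
assert images.half().dtype == conv.weight.dtype
```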
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
[WARNING] please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/sourab/test/lib/python3.8/site-packages/torch']
torch version .................... 1.11.0+cu102
torch cuda version ............... 10.2
torch hip version ................ None
nvcc version ..................... 10.2
deepspeed install path ........... ['/home/sourab/test/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.6.5, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.11, cuda 10.2

System info (please complete the following information):
- OS: Ubuntu 20.04.3 LTS (Focal Fossa)
- GPU count and types: 1 machine with 2x NVIDIA TITAN RTX
- Python version: Python 3.8.10
Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
The Accelerate launcher, which in turn triggers the deepspeed launcher.