Skip to content

Conversation

@karlonw
Copy link
Collaborator

@karlonw karlonw commented Jun 29, 2023

Description

When running pip install deepspeed on AMD based systems, with either DS_BUILD_OPS=1 or DS_BUILD_AIO=1, then method simd_width(self) returns an empty string as the result. When gcc encounters this empty string on the compile line, it returns a No such file or directory error, shown below, with the empty string seen following the -laio option:

# gcc -pthread -B /opt/conda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-p
rototypes -fPIC -Icsrc/aio/py_lib -Icsrc/aio/common -I/opt/conda/lib/python3.8/site-packages/torch/include -I/opt/conda/lib/python3.8/site-packages/torch/include
/torch/csrc/api/include -I/opt/conda/lib/python3.8/site-packages/torch/include/TH -I/opt/conda/lib/python3.8/site-packages/torch/include/THC -I/opt/conda/include
/python3.8 -c csrc/aio/common/deepspeed_aio_common.cpp -o build/temp.linux-x86_64-cpython-38/csrc/aio/common/deepspeed_aio_common.o -g -Wall -O0 -std=c++14 -shar
ed -fPIC -Wno-reorder -march=native -fopenmp -laio ""  -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIN
D11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=async_io_op -D_GLIBCXX_USE_CXX11_ABI=0
gcc: error: : No such file or directory

Since AMD and Intel use the same instruction set, this change includes Vendor ID AuthenticAMD as accepting the -D__AVX512__ or -D__AVX256__ flags. A future change, to support running on PowerPC or ARM chips, might be to replace the empty string with -D__SCALAR__ instead, but I don't have access to a machine I can test that on.

With this change in place, in working branch karlonw/PAD-43-fix-for-AMD I was able to run

environments$ git diff
diff --git a/Makefile b/Makefile
index a203712..9f4f162 100644
--- a/Makefile
+++ b/Makefile
@@ -240,7 +240,7 @@ build-gpt-neox-deepspeed-gpu: build-gpu-cuda-113-base
                --build-arg TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0" \
                --build-arg APEX_GIT="https://github.com/determined-ai/apex.git@3caf0f40c92e92b40051d3afff8568a24b8be28d" \
                --build-arg "$(NCCL_BUILD_ARG)" \
-               --build-arg DEEPSPEED_PIP="git+https://github.com/determined-ai/deepspeed.git@eleuther_dai" \
+               --build-arg DEEPSPEED_PIP="git+https://github.com/determined-ai/deepspeed.git@karlonw/PAD-43-fix-for-amd" \
                -t $(DOCKERHUB_REGISTRY)/$(GPU_GPT_NEOX_DEEPSPEED_ENVIRONMENT_NAME)-$(SHORT_GIT_HASH) \
                -t $(DOCKERHUB_REGISTRY)/$(GPU_GPT_NEOX_DEEPSPEED_ENVIRONMENT_NAME)-$(VERSION) \
                -t $(NGC_REGISTRY)/$(GPU_GPT_NEOX_DEEPSPEED_ENVIRONMENT_NAME)-$(SHORT_GIT_HASH) \

environments$ make  build-gpt-neox-deepspeed-gpu WITH_MPI=0

successfully, to generate an environments image on an AMD system with all OPS pre-compiled for deepspeed.

Testing

--------------------------------------------------                                                                                                      [0/72647]
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [YES] ...... [OKAY]
fused_adam ............. [YES] ...... [OKAY]
fused_lamb ............. [YES] ...... [OKAY]
sparse_attn ............ [YES] ...... [OKAY]
transformer ............ [YES] ...... [OKAY]
stochastic_transformer . [YES] ...... [OKAY]
utils .................. [YES] ...... [OKAY]
async_io ............... [YES] ...... [OKAY]
--------------------------------------------------
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.8/site-packages/torch']
torch version .................... 1.10.2+cu113
torch cuda version ............... 11.3
nvcc version ..................... 11.3
deepspeed install path ........... ['/opt/conda/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.3.15+298c044, 298c044, karlonw/PAD-43-fix-for-amd
deepspeed wheel compiled w. ...... torch 1.10, cuda 11.3
Removing intermediate container 026067f035dc
 ---> f37deafaf725
Step 52/52 : RUN rm -r /tmp/*
 ---> Running in 4ce4df8d6072
Removing intermediate container 4ce4df8d6072
 ---> 1e4bf7689af6
Successfully built 1e4bf7689af6
Successfully tagged determinedai/environments:cuda-11.3-pytorch-1.10-gpt-neox-deepspeed-gpu-9d14054
Successfully tagged determinedai/environments:cuda-11.3-pytorch-1.10-gpt-neox-deepspeed-gpu-0.22.1

@karlonw karlonw requested a review from liamcli June 29, 2023 15:42
Copy link

@liamcli liamcli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@karlonw karlonw merged commit ff1b554 into eleuther_dai Jul 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants