Skip to content

Commit c120fdc

Browse files
malfetfacebook-github-bot
authored andcommitted
Unify torch/csrc/cuda/shared/cudnn.cpp include path (#40525)
Summary: Pull Request resolved: #40525 Move `USE_CUDNN` define under `USE_CUDA` guard, add `cuda/shared/cudnn.cpp` to filelist if either USE_ROCM or USE_CUDNN is set. This is a prep change for PyTorch CUDA src filelist unification change. Test Plan: CI Differential Revision: D22214899 fbshipit-source-id: b71b32fc603783b41cdef0e7fab2cc9cbe750a4e
1 parent cef35e3 commit c120fdc

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

torch/CMakeLists.txt

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ if(USE_CUDA)
124124
${GENERATED_THNN_CXX_CUDA}
125125
)
126126
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDA)
127+
if(USE_CUDNN)
128+
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN)
129+
endif()
127130

128131
if(MSVC)
129132
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib)
@@ -137,18 +140,6 @@ if(USE_CUDA)
137140

138141
endif()
139142

140-
if(USE_CUDNN)
141-
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN)
142-
143-
list(APPEND TORCH_PYTHON_SRCS
144-
${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp
145-
)
146-
endif()
147-
148-
if(USE_NUMPY)
149-
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NUMPY)
150-
endif()
151-
152143
if(USE_ROCM)
153144
list(APPEND TORCH_PYTHON_SRCS
154145
${TORCH_SRC_DIR}/csrc/cuda/Module.cpp
@@ -159,7 +150,6 @@ if(USE_ROCM)
159150
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
160151
${TORCH_SRC_DIR}/csrc/cuda/serialization.cpp
161152
${TORCH_SRC_DIR}/csrc/cuda/shared/cudart.cpp
162-
${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp
163153
${TORCH_SRC_DIR}/csrc/cuda/shared/nvtx.cpp
164154
${GENERATED_THNN_CXX_CUDA}
165155
)
@@ -172,6 +162,16 @@ if(USE_ROCM)
172162
list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${roctracer_INCLUDE_DIRS})
173163
endif()
174164

165+
if(USE_CUDNN OR USE_ROCM)
166+
list(APPEND TORCH_PYTHON_SRCS
167+
${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp
168+
)
169+
endif()
170+
171+
if(USE_NUMPY)
172+
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NUMPY)
173+
endif()
174+
175175
if(USE_DISTRIBUTED)
176176
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED)
177177
if(NOT MSVC)

0 commit comments

Comments
 (0)