
Commit 69eef5a

Aidyn-A authored and pytorchmergebot committed
[CUDA12] set_device change (#94864)
This PR adds a workaround for the CUDA 12 [`cudaSetDevice` change](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb), which now always creates a primary context on the target device. Without the workaround, an operation like

```python
import torch
x = torch.randn(1, device="cuda:1")
```

would create a primary context on device `cuda:1` (where the tensor is created) and also on device `cuda:0`, because the destructor of the CUDA device guard calls `cudaSetDevice(0)`. After this PR, the CUDA device guard no longer calls `cudaSetDevice(0)` if no primary context exists on `cuda:0`.

Pull Request resolved: #94864
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/ezyang
1 parent 3fcc5ff commit 69eef5a
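
For intuition, here is a minimal sketch of the workaround's core idea, written against the raw CUDA APIs. The helper names `has_primary_context` and `maybe_restore_device` are hypothetical; the real logic lives in the `c10/cuda` device-guard code that the new linter below explicitly exempts.

```cpp
#include <cuda.h>          // driver API: cuDevicePrimaryCtxGetState
#include <cuda_runtime.h>  // runtime API: cudaGetDevice / cudaSetDevice

// Hypothetical helper: true if `device` already has an active primary
// context. Assumes the driver is initialized, which holds once any
// runtime-API call has been made in the process.
static bool has_primary_context(int device) {
  unsigned int flags = 0;
  int active = 0;
  return cuDevicePrimaryCtxGetState(device, &flags, &active) == CUDA_SUCCESS &&
      active != 0;
}

// What a guard destructor can do under CUDA 12: restore the original device
// only when doing so cannot create a fresh primary context there.
static void maybe_restore_device(int original_device) {
  int current = -1;
  if (cudaGetDevice(&current) != cudaSuccess || current == original_device) {
    return;  // nothing to restore
  }
  if (has_primary_context(original_device)) {
    cudaSetDevice(original_device);  // safe: context already exists
  }
  // Otherwise skip the call: under CUDA 12, cudaSetDevice would eagerly
  // initialize a primary context on `original_device`.
}
```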

26 files changed: 282 additions and 63 deletions

.lintrunner.toml

Lines changed: 31 additions & 0 deletions
@@ -637,6 +637,37 @@ command = [
     '@{{PATHSFILE}}'
 ]
 
+[[linter]]
+code = 'RAWCUDADEVICE'
+include_patterns = [
+    'aten/**',
+    'c10/**',
+    'torch/csrc/**',
+]
+exclude_patterns = [
+    'aten/src/ATen/cuda/CUDAContext.cpp',
+    'aten/src/ATen/cuda/CUDAGeneratorImpl.cpp',
+    'aten/src/ATen/test/**',
+    'c10/core/impl/InlineDeviceGuard.h',
+    'c10/cuda/CUDAFunctions.cpp',
+    'c10/cuda/CUDAGuard.h',
+    'c10/cuda/impl/CUDATest.cpp',
+    'torch/csrc/cuda/nccl.cpp',
+]
+command = [
+    'python3',
+    'tools/linter/adapters/grep_linter.py',
+    '--pattern=cudaSetDevice',
+    '--pattern=cudaGetDevice',
+    '--linter-name=RAWCUDADEVICE',
+    '--error-name=raw CUDA API usage',
+    """--error-description=\
+        This line calls raw CUDA APIs directly; please use c10::cuda wrappers instead.
+    """,
+    '--',
+    '@{{PATHSFILE}}'
+]
+
 [[linter]]
 code = 'ROOT_LOGGING'
 include_patterns = [
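
For illustration, a hypothetical before/after of the pattern this linter flags; `current_device_example` is made up, but the wrapper call matches the migrations in the hunks below.

```cpp
#include <c10/cuda/CUDAException.h>  // C10_CUDA_CHECK
#include <c10/cuda/CUDAFunctions.h>  // c10::cuda::GetDevice

int current_device_example() {
  int device = -1;
  // Before (flagged as "raw CUDA API usage"):
  //   C10_CUDA_CHECK(cudaGetDevice(&device));
  // After (the c10::cuda wrapper the linter asks for):
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  return device;
}
```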

aten/src/ATen/cuda/CuSparseHandlePool.cpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ using CuSparsePoolType = DeviceThreadHandlePool<cusparseHandle_t, createCusparse
 
 cusparseHandle_t getCurrentCUDASparseHandle() {
   int device;
-  AT_CUDA_CHECK(cudaGetDevice(&device));
+  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
 
   // Thread local PoolWindows are lazily-initialized
   // to avoid initialization issues that caused hangs on Windows.
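
The call site above implies that `c10::cuda::GetDevice` mirrors `cudaGetDevice` and returns a `cudaError_t`, which is why `AT_CUDA_CHECK` still wraps it. A plausible shape for the wrapper, offered as an assumption (its real definition is in `c10/cuda/CUDAFunctions.cpp`, which the linter exempts):

```cpp
#include <cuda_runtime.h>

namespace c10 {
namespace cuda {

// Sketch only: a thin pass-through. Routing device queries through one
// wrapper gives a single place to add CUDA-12-aware behavior (such as
// avoiding context-creating cudaSetDevice calls) instead of patching
// every call site.
cudaError_t GetDevice(int* device) {
  return cudaGetDevice(device);
}

} // namespace cuda
} // namespace c10
```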

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ at::DataPtr getNewWorkspace() {
 
 cublasHandle_t getCurrentCUDABlasHandle() {
   int device;
-  AT_CUDA_CHECK(cudaGetDevice(&device));
+  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
 
   // Thread local PoolWindows are lazily-initialized
   // to avoid initialization issues that caused hangs on Windows.

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 #include <ATen/native/cuda/CuFFTPlanCache.h>
 #include <c10/util/Exception.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/cuda/CUDAFunctions.h>
 #include <c10/util/irange.h>
 
 #if AT_CUDNN_ENABLED()
@@ -225,7 +226,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const {
 
 int64_t current_device() {
   int device;
-  cudaError_t err = cudaGetDevice(&device);
+  cudaError_t err = c10::cuda::GetDevice(&device);
   if (err == cudaSuccess) {
     return device;
   }

aten/src/ATen/cudnn/Handle.cpp

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ using CudnnPoolType = at::cuda::DeviceThreadHandlePool<cudnnHandle_t, createCuDN
 
 cudnnHandle_t getCudnnHandle() {
   int device;
-  AT_CUDA_CHECK(cudaGetDevice(&device));
+  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
 
   // Thread local PoolWindows are lazily-initialized
   // to avoid initialization issues that caused hangs on Windows.

aten/src/ATen/native/cuda/RNN.cu

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ bool allContiguous(at::TensorList tensors) {
 
 void getLaunchConfig(dim3* block, dim3* grid, int64_t numel) {
   int curDevice = -1;
-  cudaGetDevice(&curDevice);
+  c10::cuda::GetDevice(&curDevice);
   *block = cuda::getApplyBlock();
   TORCH_INTERNAL_ASSERT(cuda::getApplyGrid(numel, *grid, curDevice),
                         "Could not get grid size for pointwise apply.");

aten/src/ATen/native/cuda/UniqueCub.cu

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ std::tuple<Tensor, Tensor, Tensor> compute_unique(
       dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), num_inp));
   dim3 grid;
   int curDevice = -1;
-  cudaGetDevice(&curDevice);
+  c10::cuda::GetDevice(&curDevice);
   cuda::getApplyGrid(num_inp, grid, curDevice);
   adjacent_difference_kernel<<<grid, block, 0, stream>>>(
       num_inp, data, inv_loc_ptr);

aten/src/ATen/native/cuda/linalg/CusolverDnHandlePool.cpp

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ using CuSolverDnPoolType = DeviceThreadHandlePool<cusolverDnHandle_t, createCuso
 
 cusolverDnHandle_t getCurrentCUDASolverDnHandle() {
   int device;
-  AT_CUDA_CHECK(cudaGetDevice(&device));
+  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
 
   // Thread local PoolWindows are lazily-initialized
   // to avoid initialization issues that caused hangs on Windows.

aten/src/ATen/native/cudnn/Conv_v8.cpp

Lines changed: 1 addition & 1 deletion
@@ -326,7 +326,7 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
 
 int64_t get_available_workspace() {
   int device;
-  C10_CUDA_CHECK(cudaGetDevice(&device));
+  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
   return static_cast<int64_t>(max_block_size);

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu

Lines changed: 3 additions & 3 deletions
@@ -314,7 +314,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
   const dim3 block = cuda::getApplyBlock();
   dim3 grid;
   int curDevice = -1;
-  cudaGetDevice(&curDevice);
+  c10::cuda::GetDevice(&curDevice);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
   if (sparse.dense_dim() == 0) {
     TORCH_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
@@ -606,7 +606,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
   }
   else {
     int curDevice = -1;
-    cudaGetDevice(&curDevice);
+    c10::cuda::GetDevice(&curDevice);
     cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
     at::cuda::ThrustAllocator allocator;
     auto policy = thrust::cuda::par(allocator).on(stream);
@@ -711,7 +711,7 @@ __global__ void search_end_matrix_indices_cuda_kernel(
 // indices to find the end index for each matrix
 void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const Tensor& indices_1D) {
   int curDevice = -1;
-  cudaGetDevice(&curDevice);
+  c10::cuda::GetDevice(&curDevice);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
 
   auto indices_1D_ti = getTensorInfo<int64_t, int64_t>(indices_1D);
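
The three hunks above share one pattern: query the current device through the wrapper, then fetch the current stream for that device. Extracted as a standalone sketch (kernel and tensor details omitted):

```cpp
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAFunctions.h>  // c10::cuda::GetDevice

void launch_on_current_device() {
  int curDevice = -1;
  c10::cuda::GetDevice(&curDevice);  // wrapper replaces raw cudaGetDevice
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
  // ... configure grid/block and enqueue kernels on `stream` ...
}
```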
