Sparse CSR CUDA: Support mixed memory format input for triangular_solve#66401

Closed
IvanYashchuk wants to merge 30 commits into gh/ivanyashchuk/40/base from gh/ivanyashchuk/40/head

Conversation

@IvanYashchuk
Collaborator

@IvanYashchuk IvanYashchuk commented Oct 11, 2021

Stack from ghstack:

This PR fixes the case where the result and input tensors have different
strides.
cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use the correct strides to
write the result. This is worked around in PyTorch by copying the input
tensor into a tensor with the same strides as the result tensor.
The bug can be reproduced using the code given in #66401 (comment).

cc @nikitaved @pearu @cpuhrsch @IvanYashchuk @ngimel

Differential Revision: D32177966
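For illustration, the workaround amounts to copying the right-hand side into a buffer whose layout matches the output before calling the solver. A minimal NumPy sketch of that idea (the helper name `match_strides` is hypothetical, and NumPy stands in for the CUDA tensors):

```python
import numpy as np

def match_strides(src, like):
    # cuSPARSE (CUDA 11.3.1) writes the output using the input's strides,
    # so hand it a copy of src laid out exactly like the result buffer.
    if src.strides == like.strides:
        return src
    out = np.empty_like(like)  # same shape, dtype, and layout as `like`
    out[...] = src             # element-wise copy of the values
    return out

B = np.arange(16.0).reshape(4, 4)[::2]  # non-contiguous view, strides (64, 8)
result = np.zeros((2, 4))               # contiguous result, strides (32, 8)
B_fixed = match_strides(B, result)
print(B_fixed.strides == result.strides)  # True
```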

@pytorch-probot

pytorch-probot Bot commented Oct 11, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/2a8c131d318319004f698c910003d41130e02688/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default,ciflow/cuda

Workflows Labels (bold enabled) Status
Triggered Workflows
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux ✅ triggered
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux ✅ triggered
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow ✅ triggered
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-dynamic ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck ✅ triggered
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
linux-xenial-py3-clang5-mobile-code-analysis ciflow/all, ciflow/linux, ciflow/mobile 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

IvanYashchuk added a commit that referenced this pull request Oct 11, 2021
This PR fixes the case when result and input tensors have different
strides.
cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to
write the result. This is "fixed" in PyTorch code by copying the input
tensor to a tensor with same strides as result tensor has.

ghstack-source-id: 798fcab
Pull Request resolved: #66401
@IvanYashchuk IvanYashchuk added the module: sparse Related to torch.sparse label Oct 11, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Oct 11, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 2a8c131 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@IvanYashchuk IvanYashchuk added module: cuda Related to torch.cuda, and CUDA support in general ciflow/cuda labels Oct 11, 2021
@IvanYashchuk IvanYashchuk requested a review from cpuhrsch November 4, 2021 12:52
@cpuhrsch
Contributor

cpuhrsch commented Nov 4, 2021

@cpuhrsch has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator

ngimel commented Nov 4, 2021

@IvanYashchuk do you know if this bug is fixed in 11.5? Did you report it to nvidia? cc @xwang233

@xwang233
Collaborator

xwang233 commented Nov 4, 2021

Hi @ngimel @IvanYashchuk, I haven't received any reports yet. I can help report this to the cuSPARSE team if a standalone C++ repro is available. Thanks.

@IvanYashchuk
Collaborator Author

IvanYashchuk commented Nov 4, 2021

Let's use a slightly modified version of the sample code available in the CUDALibrarySamples repo.

I'll give two patches here to verify the problem on CUDA 11.3.1 (the bug is still present in 11.5.50). First, let's change ldb from 4 to 8 and modify the B array accordingly; if everything were correct, the result would not change.

import torch
r = [1.0, 8.0, 23.0, 52.0]
o = [0.0, 0.0, 0.0, 0.0]
B = torch.tensor([r, o, r, o])
# gives
# tensor([[ 1.,  8., 23., 52.],
#         [ 0.,  0.,  0.,  0.],
#         [ 1.,  8., 23., 52.],
#         [ 0.,  0.,  0.,  0.]])
print(B.stride()) # (4, 1)
print(B[::2])
# tensor([[ 1.,  8., 23., 52.],
#         [ 1.,  8., 23., 52.]])
print(B[::2].stride()) # (8, 1)
# so that's why we pick ldb = 8
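More generally, the leading dimension handed to cuSPARSE is the outer-axis stride measured in elements, not bytes. A small NumPy sketch of that rule (the `leading_dim` helper is illustrative, not PyTorch's actual code):

```python
import numpy as np

def leading_dim(a):
    # Leading dimension = outer stride in elements (strides are in bytes).
    return a.strides[0] // a.itemsize

B = np.arange(32.0).reshape(8, 4)[::2]  # every other row, like B[::2] above
print(B.shape)         # (4, 4)
print(leading_dim(B))  # 8, matching the ldb = 8 choice in Patch 1
```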
Patch 1:
diff --git a/cuSPARSE/spsm_csr/spsm_csr_example.c b/cuSPARSE/spsm_csr/spsm_csr_example.c
index 8dfbab4..7336689 100644
--- a/cuSPARSE/spsm_csr/spsm_csr_example.c
+++ b/cuSPARSE/spsm_csr/spsm_csr_example.c
@@ -77,14 +77,16 @@ int main(void) {
   const int A_num_cols      = 4;
   const int A_nnz           = 9;
   const int nrhs            = 2;
-    const int ldb             = A_num_cols;
+    const int ldb             = 2 * A_num_cols;
   const int ldc             = A_num_rows;
   int       hA_csrOffsets[] = { 0, 3, 4, 7, 9 };
   int       hA_columns[]    = { 0, 2, 3, 1, 0, 2, 3, 1, 3 };
   float     hA_values[]     = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                 6.0f, 7.0f, 8.0f, 9.0f };
   float     hB[]            = { 1.0f, 8.0f, 23.0f, 52.0f,
-                                  1.0f, 8.0f, 23.0f, 52.0f };
+                                  0.0f, 0.0f, 0.0f, 0.0f,
+                                  1.0f, 8.0f, 23.0f, 52.0f,
+                                  0.0f, 0.0f, 0.0f, 0.0f };
   float     hC[]            = { 0.0f, 0.0f, 0.0f, 0.0f,
                                 0.0f, 0.0f, 0.0f, 0.0f };
   float     hY_result[]     = { 1.0f, 2.0f, 3.0f, 4.0f,
@@ -98,7 +100,7 @@ int main(void) {
                          (A_num_rows + 1) * sizeof(int)) )
   CHECK_CUDA( cudaMalloc((void**) &dA_columns, A_nnz * sizeof(int))       )
   CHECK_CUDA( cudaMalloc((void**) &dA_values,  A_nnz * sizeof(float))     )
-    CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * A_num_cols * sizeof(float)) )
+    CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * ldb * sizeof(float)) )
   CHECK_CUDA( cudaMalloc((void**) &dC, nrhs * A_num_rows * sizeof(float)) )

   CHECK_CUDA( cudaMemcpy(dA_csrOffsets, hA_csrOffsets,
@@ -108,7 +110,7 @@ int main(void) {
                          cudaMemcpyHostToDevice) )
   CHECK_CUDA( cudaMemcpy(dA_values, hA_values, A_nnz * sizeof(float),
                          cudaMemcpyHostToDevice) )
-    CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * A_num_cols * sizeof(float),
+    CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * ldb * sizeof(float),
                          cudaMemcpyHostToDevice) )
   CHECK_CUDA( cudaMemcpy(dC, hC, nrhs * A_num_rows * sizeof(float),
                          cudaMemcpyHostToDevice) )
@@ -173,11 +175,13 @@ int main(void) {
                          cudaMemcpyDeviceToHost) )
   int correct = 1;
   for (int i = 0; i < nrhs * A_num_rows; i++) {
+        printf("%f ", hC[i]);
       if (hC[i] != hY_result[i]) { // direct floating point comparison is not
           correct = 0;             // reliable
           break;
       }
   }
+    printf("\n");
   if (correct)
       printf("spsm_csr_example test PASSED\n");
   else

Compiling Patch 1 and running the program with cuda-memcheck reveals several out-of-bounds write errors:

cuda-memcheck on Patch 1
========= CUDA-MEMCHECK
========= Invalid __global__ write of size 4
=========     at 0x000011b0 in void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *)
=========     by thread (43,0,0) in block (0,0,0)
=========     Address 0x7fad0520082c is out of bounds
=========     Device Frame:void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) (void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) : 0x11b0)
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x25428a]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x74d54b]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x7a0c70]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b8940]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b76ea]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5fc601]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x61b29d]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 (cusparseSpSM_solve + 0x172) [0xbf262]
=========     Host Frame:./spsm_csr_example [0x148f]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
=========     Host Frame:./spsm_csr_example [0x1c29]
=========
========= Invalid __global__ write of size 4
=========     at 0x000011b0 in void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *)
=========     by thread (42,0,0) in block (0,0,0)
=========     Address 0x7fad05200828 is out of bounds
=========     Device Frame:void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) (void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) : 0x11b0)
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x25428a]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x74d54b]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x7a0c70]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b8940]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b76ea]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5fc601]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x61b29d]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 (cusparseSpSM_solve + 0x172) [0xbf262]
=========     Host Frame:./spsm_csr_example [0x148f]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
=========     Host Frame:./spsm_csr_example [0x1c29]
=========
========= Invalid __global__ write of size 4
=========     at 0x000011b0 in void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *)
=========     by thread (41,0,0) in block (0,0,0)
=========     Address 0x7fad05200824 is out of bounds
=========     Device Frame:void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) (void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) : 0x11b0)
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x25428a]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x74d54b]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x7a0c70]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b8940]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b76ea]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5fc601]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x61b29d]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 (cusparseSpSM_solve + 0x172) [0xbf262]
=========     Host Frame:./spsm_csr_example [0x148f]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
=========     Host Frame:./spsm_csr_example [0x1c29]
=========
========= Invalid __global__ write of size 4
=========     at 0x000011b0 in void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *)
=========     by thread (40,0,0) in block (0,0,0)
=========     Address 0x7fad05200820 is out of bounds
=========     Device Frame:void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) (void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) : 0x11b0)
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x25428a]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x74d54b]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x7a0c70]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b8940]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b76ea]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5fc601]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x61b29d]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 (cusparseSpSM_solve + 0x172) [0xbf262]
=========     Host Frame:./spsm_csr_example [0x148f]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
=========     Host Frame:./spsm_csr_example [0x1c29]
=========
========= Program hit cudaErrorLaunchFailure (error 719) due to "unspecified launch failure" on CUDA API call to cudaMemcpy.
CUDA API failed at line 174 with error: unspecified launch failure (719)
=========     Saved host backtrace up to driver entry point at error
=========     Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x355b43]
=========     Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcudart.so.11.0 (cudaMemcpy + 0x17d) [0x5d1cd]
=========     Host Frame:./spsm_csr_example [0x1524]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
=========     Host Frame:./spsm_csr_example [0x1c29]
=========
========= ERROR SUMMARY: 5 errors

Let's extend the dC array to remove the out-of-bounds errors and reveal that the ldb value is used to write the result to dC when ldc should be used instead (let's set ldc = 99999).

Patch 2:
diff --git a/cuSPARSE/spsm_csr/spsm_csr_example.c b/cuSPARSE/spsm_csr/spsm_csr_example.c
index 8dfbab4..026a968 100644
--- a/cuSPARSE/spsm_csr/spsm_csr_example.c
+++ b/cuSPARSE/spsm_csr/spsm_csr_example.c
@@ -77,16 +77,21 @@ int main(void) {
   const int A_num_cols      = 4;
   const int A_nnz           = 9;
   const int nrhs            = 2;
-    const int ldb             = A_num_cols;
-    const int ldc             = A_num_rows;
+    const int ldb             = 2 * A_num_cols;
+    // ldc can now be set to any value, since it is not used
+    const int ldc             = 99999;
   int       hA_csrOffsets[] = { 0, 3, 4, 7, 9 };
   int       hA_columns[]    = { 0, 2, 3, 1, 0, 2, 3, 1, 3 };
   float     hA_values[]     = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
                                 6.0f, 7.0f, 8.0f, 9.0f };
   float     hB[]            = { 1.0f, 8.0f, 23.0f, 52.0f,
-                                  1.0f, 8.0f, 23.0f, 52.0f };
-    float     hC[]            = { 0.0f, 0.0f, 0.0f, 0.0f,
+                                  0.0f, 0.0f, 0.0f, 0.0f,
+                                  1.0f, 8.0f, 23.0f, 52.0f,
                                 0.0f, 0.0f, 0.0f, 0.0f };
+    float     hC[]            = { 0.0f, 0.0f, 0.0f, 0.0f,
+                                  0.0f, 0.0f, 0.0f, 0.0f,
+                                  0.0f, 0.0f, 0.0f, 0.0f,
+                                  0.0f, 0.0f, 0.0f, 0.0f, };
   float     hY_result[]     = { 1.0f, 2.0f, 3.0f, 4.0f,
                                 1.0f, 2.0f, 3.0f, 4.0f };
   float     alpha           = 1.0f;
@@ -98,8 +103,8 @@ int main(void) {
                          (A_num_rows + 1) * sizeof(int)) )
   CHECK_CUDA( cudaMalloc((void**) &dA_columns, A_nnz * sizeof(int))       )
   CHECK_CUDA( cudaMalloc((void**) &dA_values,  A_nnz * sizeof(float))     )
-    CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * A_num_cols * sizeof(float)) )
-    CHECK_CUDA( cudaMalloc((void**) &dC, nrhs * A_num_rows * sizeof(float)) )
+    CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * ldb * sizeof(float)) )
+    CHECK_CUDA( cudaMalloc((void**) &dC, nrhs * A_num_rows*2 * sizeof(float)) )

   CHECK_CUDA( cudaMemcpy(dA_csrOffsets, hA_csrOffsets,
                          (A_num_rows + 1) * sizeof(int),
@@ -108,9 +113,9 @@ int main(void) {
                          cudaMemcpyHostToDevice) )
   CHECK_CUDA( cudaMemcpy(dA_values, hA_values, A_nnz * sizeof(float),
                          cudaMemcpyHostToDevice) )
-    CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * A_num_cols * sizeof(float),
+    CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * ldb * sizeof(float),
                          cudaMemcpyHostToDevice) )
-    CHECK_CUDA( cudaMemcpy(dC, hC, nrhs * A_num_rows * sizeof(float),
+    CHECK_CUDA( cudaMemcpy(dC, hC, nrhs * A_num_rows*2 * sizeof(float),
                          cudaMemcpyHostToDevice) )
   //--------------------------------------------------------------------------
   // CUSPARSE APIs
@@ -169,15 +174,17 @@ int main(void) {
   CHECK_CUSPARSE( cusparseDestroy(handle) )
   //--------------------------------------------------------------------------
   // device result check
-    CHECK_CUDA( cudaMemcpy(hC, dC, nrhs * A_num_rows * sizeof(float),
+    CHECK_CUDA( cudaMemcpy(hC, dC, nrhs * A_num_rows*2 * sizeof(float),
                          cudaMemcpyDeviceToHost) )
   int correct = 1;
-    for (int i = 0; i < nrhs * A_num_rows; i++) {
+    for (int i = 0; i < nrhs * A_num_rows*2; i++) {
+        printf("%f ", hC[i]);
       if (hC[i] != hY_result[i]) { // direct floating point comparison is not
           correct = 0;             // reliable
-            break;
+            // break;
       }
   }
+    printf("\n");
   if (correct)
       printf("spsm_csr_example test PASSED\n");
   else

Compiling and running Patch 2 gives:

========= CUDA-MEMCHECK
1.000000 2.000000 3.000000 4.000000 0.000000 0.000000 0.000000 0.000000 1.000000 2.000000 3.000000 4.000000 0.000000 0.000000 0.000000 0.000000
spsm_csr_example test FAILED: wrong result
========= ERROR SUMMARY: 0 errors
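The dump above is exactly what a column-major write produces when the columns of the 4 x 2 solution are spaced ldb (8) elements apart instead of ldc. A tiny pure-Python simulation of that buggy write pattern (hypothetical, just to illustrate the layout):

```python
def write_cols(buf, cols, ld):
    # Column-major write of an n x nrhs result with leading dimension ld:
    # column j starts at offset j * ld in the flat buffer.
    for j, col in enumerate(cols):
        for i, v in enumerate(col):
            buf[j * ld + i] = v

solution = [[1.0, 2.0, 3.0, 4.0]] * 2  # the correct two-column solution
buf = [0.0] * 16
write_cols(buf, solution, ld=8)        # bug: ldb (8) used where ldc belongs
print(buf)
# [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0,
#  1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]
```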

@facebook-github-bot
Contributor

@cpuhrsch merged this pull request in d5d342b.

IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Nov 8, 2021

ghstack-source-id: 7bb46bd
Pull Request resolved: pytorch#66401
@facebook-github-bot facebook-github-bot deleted the gh/ivanyashchuk/40/head branch November 8, 2021 15:16
@yukini2009

Correct.zip
Hi, we checked the issue internally. Note that matrix B must be ldb * nrhs (for column-major); a modified version is attached above. The 'cusparse_transpose_readWrite_alignment_kernel' error will be fixed in the next major CUDA release after 11.5. Thanks.
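Per that reply, the dense right-hand side must be allocated as a full ldb * nrhs column-major buffer (and, by the same convention, the output as ldc * nrhs). A small sketch of the sizing rule (the helper is illustrative, not part of the cuSPARSE API):

```python
def spsm_buffer_elems(num_rows, nrhs, ldb, ldc):
    # Element counts for the dense operands of cusparseSpSM_solve:
    # each of the nrhs columns occupies a full leading dimension.
    assert ldb >= num_rows and ldc >= num_rows
    return {"B": ldb * nrhs, "C": ldc * nrhs}

print(spsm_buffer_elems(num_rows=4, nrhs=2, ldb=8, ldc=4))
# {'B': 16, 'C': 8}
```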

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
…ve (pytorch#66401)

Summary:
Pull Request resolved: pytorch#66401

This PR fixes the case when result and input tensors have different
strides.
cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to
write the result. This is "fixed" in PyTorch code by copying the input
tensor to a tensor with same strides as result tensor has.

cc nikitaved pearu cpuhrsch IvanYashchuk ngimel

Test Plan: Imported from OSS

Reviewed By: davidberard98

Differential Revision: D32177966

Pulled By: cpuhrsch

fbshipit-source-id: 118437409df147f04dce02763aff9bfd33f87c63
Labels

cla signed · Merged · module: cuda (Related to torch.cuda, and CUDA support in general) · module: sparse (Related to torch.sparse) · open source
