Sparse CSR CUDA: Support mixed memory format input for triangular_solve #66401
IvanYashchuk wants to merge 30 commits into gh/ivanyashchuk/40/base
Conversation
This PR fixes the case when the result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it does not use the correct strides to write the result. This is worked around in PyTorch by copying the input tensor to a tensor with the same strides as the result tensor.
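The workaround described above can be sketched as follows. This is a minimal illustration, not the actual PyTorch internals; the helper name is hypothetical, and the use of torch.empty_strided to materialize a tensor with prescribed strides is an assumption about how such a copy could be done:

```python
import torch

def match_result_strides(input_t: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cuSPARSE 11.3.1's SpSM writes the result with the
    # wrong strides when the input and result layouts differ, so copy the
    # input into a tensor that shares the result's strides before solving.
    if input_t.stride() == result.stride():
        return input_t
    same_layout = torch.empty_strided(
        input_t.size(), result.stride(),
        dtype=input_t.dtype, device=input_t.device)
    return same_layout.copy_(input_t)
```

After this copy, both tensors handed to the cuSPARSE solve describe the same memory layout, which sidesteps the stride bug.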
As of commit 2a8c131, Dr. CI reports no failures.
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch @IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: 1ee0a4e Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch @IvanYashchuk ngimel [ghstack-poisoned]
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch @IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: f9c368b Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch @IvanYashchuk ngimel [ghstack-poisoned]
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: f9fb005 Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: 9eaa420 Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: 74b5305 Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: c0383d3 Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: 54ab717 Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. ghstack-source-id: 3503d48 Pull Request resolved: #66401
…angular_solve" This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel [ghstack-poisoned]
@cpuhrsch has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@IvanYashchuk do you know if this bug is fixed in CUDA 11.5? Did you report it to NVIDIA? cc @xwang233
Hi @ngimel, @IvanYashchuk, I haven't seen this reported yet. I can help report it to the cuSPARSE team if a standalone C++ repro is available. Thanks.
Let's use a slightly modified version of the sample code available from the CUDALibrarySamples repo. Here are two patches that verify the problem on CUDA 11.3.1.

Patch 1:

diff --git a/cuSPARSE/spsm_csr/spsm_csr_example.c b/cuSPARSE/spsm_csr/spsm_csr_example.c
index 8dfbab4..7336689 100644
--- a/cuSPARSE/spsm_csr/spsm_csr_example.c
+++ b/cuSPARSE/spsm_csr/spsm_csr_example.c
@@ -77,14 +77,16 @@ int main(void) {
const int A_num_cols = 4;
const int A_nnz = 9;
const int nrhs = 2;
- const int ldb = A_num_cols;
+ const int ldb = 2 * A_num_cols;
const int ldc = A_num_rows;
int hA_csrOffsets[] = { 0, 3, 4, 7, 9 };
int hA_columns[] = { 0, 2, 3, 1, 0, 2, 3, 1, 3 };
float hA_values[] = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f, 7.0f, 8.0f, 9.0f };
float hB[] = { 1.0f, 8.0f, 23.0f, 52.0f,
- 1.0f, 8.0f, 23.0f, 52.0f };
+ 0.0f, 0.0f, 0.0f, 0.0f,
+ 1.0f, 8.0f, 23.0f, 52.0f,
+ 0.0f, 0.0f, 0.0f, 0.0f };
float hC[] = { 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f };
float hY_result[] = { 1.0f, 2.0f, 3.0f, 4.0f,
@@ -98,7 +100,7 @@ int main(void) {
(A_num_rows + 1) * sizeof(int)) )
CHECK_CUDA( cudaMalloc((void**) &dA_columns, A_nnz * sizeof(int)) )
CHECK_CUDA( cudaMalloc((void**) &dA_values, A_nnz * sizeof(float)) )
- CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * A_num_cols * sizeof(float)) )
+ CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * ldb * sizeof(float)) )
CHECK_CUDA( cudaMalloc((void**) &dC, nrhs * A_num_rows * sizeof(float)) )
CHECK_CUDA( cudaMemcpy(dA_csrOffsets, hA_csrOffsets,
@@ -108,7 +110,7 @@ int main(void) {
cudaMemcpyHostToDevice) )
CHECK_CUDA( cudaMemcpy(dA_values, hA_values, A_nnz * sizeof(float),
cudaMemcpyHostToDevice) )
- CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * A_num_cols * sizeof(float),
+ CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * ldb * sizeof(float),
cudaMemcpyHostToDevice) )
CHECK_CUDA( cudaMemcpy(dC, hC, nrhs * A_num_rows * sizeof(float),
cudaMemcpyHostToDevice) )
@@ -173,11 +175,13 @@ int main(void) {
cudaMemcpyDeviceToHost) )
int correct = 1;
for (int i = 0; i < nrhs * A_num_rows; i++) {
+ printf("%f ", hC[i]);
if (hC[i] != hY_result[i]) { // direct floating point comparison is not
correct = 0; // reliable
break;
}
}
+ printf("\n");
if (correct)
         printf("spsm_csr_example test PASSED\n");
     else

Compiling Patch 1 and running the program under cuda-memcheck reveals several out-of-bounds write errors.

cuda-memcheck output for Patch 1:

========= CUDA-MEMCHECK
========= Invalid __global__ write of size 4
========= at 0x000011b0 in void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *)
========= by thread (43,0,0) in block (0,0,0)
========= Address 0x7fad0520082c is out of bounds
========= Device Frame:void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) (void cusparse_transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cusparseTransposeParams<float>, float const *, float*, float const *) : 0x11b0)
========= Saved host backtrace up to driver entry point at kernel launch time
========= Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x25428a]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x74d54b]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x7a0c70]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b8940]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5b76ea]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x5fc601]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 [0x61b29d]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcusparse.so.11 (cusparseSpSM_solve + 0x172) [0xbf262]
========= Host Frame:./spsm_csr_example [0x148f]
========= Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
========= Host Frame:./spsm_csr_example [0x1c29]
=========
========= (three more identical "Invalid __global__ write of size 4" errors follow, for threads 42, 41, and 40, at addresses 0x7fad05200828, 0x7fad05200824, and 0x7fad05200820)
=========
========= Program hit cudaErrorLaunchFailure (error 719) due to "unspecified launch failure" on CUDA API call to cudaMemcpy.
CUDA API failed at line 174 with error: unspecified launch failure (719)
========= Saved host backtrace up to driver entry point at error
========= Host Frame:/lib/x86_64-linux-gnu/libcuda.so.1 [0x355b43]
========= Host Frame:/home/yashchuk/cuda/cuda-11.3.1/lib64/libcudart.so.11.0 (cudaMemcpy + 0x17d) [0x5d1cd]
========= Host Frame:./spsm_csr_example [0x1524]
========= Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf3) [0x270b3]
========= Host Frame:./spsm_csr_example [0x1c29]
=========
========= ERROR SUMMARY: 5 errors

Let's extend the dC array to remove the out-of-bounds errors and reveal that cuSPARSE ignores ldc when writing the result.

Patch 2:

diff --git a/cuSPARSE/spsm_csr/spsm_csr_example.c b/cuSPARSE/spsm_csr/spsm_csr_example.c
index 8dfbab4..026a968 100644
--- a/cuSPARSE/spsm_csr/spsm_csr_example.c
+++ b/cuSPARSE/spsm_csr/spsm_csr_example.c
@@ -77,16 +77,21 @@ int main(void) {
const int A_num_cols = 4;
const int A_nnz = 9;
const int nrhs = 2;
- const int ldb = A_num_cols;
- const int ldc = A_num_rows;
+ const int ldb = 2 * A_num_cols;
+ // ldc can now be set to any value, since it is not used
+ const int ldc = 99999;
int hA_csrOffsets[] = { 0, 3, 4, 7, 9 };
int hA_columns[] = { 0, 2, 3, 1, 0, 2, 3, 1, 3 };
float hA_values[] = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f, 7.0f, 8.0f, 9.0f };
float hB[] = { 1.0f, 8.0f, 23.0f, 52.0f,
- 1.0f, 8.0f, 23.0f, 52.0f };
- float hC[] = { 0.0f, 0.0f, 0.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f, 0.0f,
+ 1.0f, 8.0f, 23.0f, 52.0f,
0.0f, 0.0f, 0.0f, 0.0f };
+ float hC[] = { 0.0f, 0.0f, 0.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f, 0.0f, };
float hY_result[] = { 1.0f, 2.0f, 3.0f, 4.0f,
1.0f, 2.0f, 3.0f, 4.0f };
float alpha = 1.0f;
@@ -98,8 +103,8 @@ int main(void) {
(A_num_rows + 1) * sizeof(int)) )
CHECK_CUDA( cudaMalloc((void**) &dA_columns, A_nnz * sizeof(int)) )
CHECK_CUDA( cudaMalloc((void**) &dA_values, A_nnz * sizeof(float)) )
- CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * A_num_cols * sizeof(float)) )
- CHECK_CUDA( cudaMalloc((void**) &dC, nrhs * A_num_rows * sizeof(float)) )
+ CHECK_CUDA( cudaMalloc((void**) &dB, nrhs * ldb * sizeof(float)) )
+ CHECK_CUDA( cudaMalloc((void**) &dC, nrhs * A_num_rows*2 * sizeof(float)) )
CHECK_CUDA( cudaMemcpy(dA_csrOffsets, hA_csrOffsets,
(A_num_rows + 1) * sizeof(int),
@@ -108,9 +113,9 @@ int main(void) {
cudaMemcpyHostToDevice) )
CHECK_CUDA( cudaMemcpy(dA_values, hA_values, A_nnz * sizeof(float),
cudaMemcpyHostToDevice) )
- CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * A_num_cols * sizeof(float),
+ CHECK_CUDA( cudaMemcpy(dB, hB, nrhs * ldb * sizeof(float),
cudaMemcpyHostToDevice) )
- CHECK_CUDA( cudaMemcpy(dC, hC, nrhs * A_num_rows * sizeof(float),
+ CHECK_CUDA( cudaMemcpy(dC, hC, nrhs * A_num_rows*2 * sizeof(float),
cudaMemcpyHostToDevice) )
//--------------------------------------------------------------------------
// CUSPARSE APIs
@@ -169,15 +174,17 @@ int main(void) {
CHECK_CUSPARSE( cusparseDestroy(handle) )
//--------------------------------------------------------------------------
// device result check
- CHECK_CUDA( cudaMemcpy(hC, dC, nrhs * A_num_rows * sizeof(float),
+ CHECK_CUDA( cudaMemcpy(hC, dC, nrhs * A_num_rows*2 * sizeof(float),
cudaMemcpyDeviceToHost) )
int correct = 1;
- for (int i = 0; i < nrhs * A_num_rows; i++) {
+ for (int i = 0; i < nrhs * A_num_rows*2; i++) {
+ printf("%f ", hC[i]);
if (hC[i] != hY_result[i]) { // direct floating point comparison is not
correct = 0; // reliable
- break;
+ // break;
}
}
+ printf("\n");
if (correct)
printf("spsm_csr_example test PASSED\n");
     else

Compiling and running Patch 2 under cuda-memcheck gives:

========= CUDA-MEMCHECK
1.000000 2.000000 3.000000 4.000000 0.000000 0.000000 0.000000 0.000000 1.000000 2.000000 3.000000 4.000000 0.000000 0.000000 0.000000 0.000000
spsm_csr_example test FAILED: wrong result
========= ERROR SUMMARY: 0 errors

The solution columns land at offsets 0 and 8, i.e. spaced by ldb (8) elements instead of ldc: cuSPARSE wrote the result using the input's leading dimension, which confirms the wrong-stride bug.
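For reference, the sample's expected values can be checked on the host with a dense solve. This sketch is not part of the sample; the lower-triangular fill mode is an assumption inferred from the sample's expected result:

```python
import numpy as np

# CSR data of the 4x4 matrix from the spsm_csr sample
offsets = [0, 3, 4, 7, 9]
columns = [0, 2, 3, 1, 0, 2, 3, 1, 3]
values  = [1., 2., 3., 4., 5., 6., 7., 8., 9.]

A = np.zeros((4, 4))
for row in range(4):
    for k in range(offsets[row], offsets[row + 1]):
        A[row, columns[k]] = values[k]

L = np.tril(A)                    # lower fill mode: only the lower triangle is read
b = np.array([1., 8., 23., 52.])  # one right-hand-side column of hB
x = np.linalg.solve(L, b)
print(x)                          # matches hY_result: [1. 2. 3. 4.]
```

This confirms that each right-hand-side column should solve to (1, 2, 3, 4), which is exactly what appears in the output above, just at the wrong offsets.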
Correct.zip
…ve (pytorch#66401) Summary: Pull Request resolved: pytorch#66401 This PR fixes the case when result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it doesn't use correct strides to write the result. This is "fixed" in PyTorch code by copying the input tensor to a tensor with same strides as result tensor has. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel Test Plan: Imported from OSS Reviewed By: davidberard98 Differential Revision: D32177966 Pulled By: cpuhrsch fbshipit-source-id: 118437409df147f04dce02763aff9bfd33f87c63
Stack from ghstack:

- triangular_solve_out #62180
- torch.addmm #65606
- torch.add with all inputs sparse #64391
- addmv_out #61536

This PR fixes the case when the result and input tensors have different strides. cuSPARSE from CUDA 11.3.1 has a bug: it does not use the correct strides to write the result. This is worked around in PyTorch by copying the input tensor to a tensor with the same strides as the result tensor.
The bug can be reproduced using the code given in #66401 (comment).
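For context, "mixed memory format" here means, for example, a column-major (Fortran-contiguous) input paired with a row-major result. A small host-side illustration in PyTorch terms:

```python
import torch

b = torch.randn(4, 2)         # row-major storage: strides (2, 1)
b_f = b.t().contiguous().t()  # same values in column-major storage: strides (1, 4)

# Logically identical tensors, physically different layouts -- the
# input/result combination that hit the cuSPARSE stride bug on CUDA 11.3.1.
assert b.stride() == (2, 1)
assert b_f.stride() == (1, 4)
assert torch.equal(b, b_f)
```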
cc @nikitaved @pearu @cpuhrsch @IvanYashchuk @ngimel
Differential Revision: D32177966