Skip to content

Commit e64914e

Browse files
authored
conditionally enable hipsparse const descriptors (#1215)
1 parent 7fce5e1 commit e64914e

2 files changed

Lines changed: 28 additions & 12 deletions

File tree

aten/src/ATen/cuda/CUDASparse.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#pragma once
22

33
#include <ATen/cuda/CUDAContext.h>
4+
#if defined(USE_ROCM)
5+
#include <hipsparse/hipsparse-version.h>
6+
#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
7+
#endif
48

59
// cuSparse Generic API added in CUDA 10.1
610
// Windows support added in CUDA 11.0
@@ -25,20 +29,33 @@
2529
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
2630
#endif
2731

32+
#if defined(USE_ROCM)
33+
34+
// hipSparse const API added in v2.3.6
35+
#if HIPSPARSE_VERSION >= 200306
36+
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
37+
#define AT_USE_HIPSPARSE_GENERIC_52_API() 0
38+
#define AT_USE_HIPSPARSE_GENERIC_API() 1
39+
#else
40+
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
41+
2842
// hipSparse Generic API ROCm 5.2
29-
#if defined(USE_ROCM) && ROCM_VERSION >= 50200
43+
#if ROCM_VERSION >= 50200
3044
#define AT_USE_HIPSPARSE_GENERIC_52_API() 1
3145
#else
3246
#define AT_USE_HIPSPARSE_GENERIC_52_API() 0
3347
#endif
3448

3549
// hipSparse Generic API ROCm 5.1
36-
#if defined(USE_ROCM) && ROCM_VERSION >= 50100
50+
#if ROCM_VERSION >= 50100
3751
#define AT_USE_HIPSPARSE_GENERIC_API() 1
3852
#else
3953
#define AT_USE_HIPSPARSE_GENERIC_API() 0
4054
#endif
4155

56+
#endif // HIPSPARSE_VERSION >= 20306
57+
#endif // USE_ROCM
58+
4259
// cuSparse Generic API spsv function was added in CUDA 11.3.0
4360
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
4461
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1

aten/src/ATen/cuda/CUDASparseDescriptors.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class CuSparseDescriptor {
3737
std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
3838
};
3939

40-
#if AT_USE_CUSPARSE_CONST_DESCRIPTORS()
40+
#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
4141
template <typename T, cusparseStatus_t (*destructor)(const T*)>
4242
struct ConstCuSparseDescriptorDeleter {
4343
void operator()(T* x) {
@@ -60,16 +60,15 @@ class ConstCuSparseDescriptor {
6060
protected:
6161
std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
6262
};
63-
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS
63+
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
6464

6565
#if defined(USE_ROCM)
66-
// hipSPARSE doesn't define this
67-
using cusparseMatDescr = std::remove_pointer<cusparseMatDescr_t>::type;
68-
using cusparseDnMatDescr = std::remove_pointer<cusparseDnMatDescr_t>::type;
69-
using cusparseDnVecDescr = std::remove_pointer<cusparseDnVecDescr_t>::type;
70-
using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
71-
using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
72-
using cusparseSpGEMMDescr = std::remove_pointer<cusparseSpGEMMDescr_t>::type;
66+
using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
67+
using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
68+
using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
69+
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
70+
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
71+
using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
7372
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
7473
using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
7574
using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
@@ -145,7 +144,7 @@ class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
145144

146145
//AT_USE_HIPSPARSE_GENERIC_52_API() || (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())
147146

148-
#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS()
147+
#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
149148
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
150149
: public ConstCuSparseDescriptor<
151150
cusparseDnMatDescr,

0 commit comments

Comments
 (0)