@@ -37,7 +37,7 @@ class CuSparseDescriptor {
3737 std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
3838};
3939
40- #if AT_USE_CUSPARSE_CONST_DESCRIPTORS()
40+ #if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
4141template <typename T, cusparseStatus_t (*destructor)(const T*)>
4242struct ConstCuSparseDescriptorDeleter {
4343 void operator ()(T* x) {
@@ -60,16 +60,15 @@ class ConstCuSparseDescriptor {
6060 protected:
6161 std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
6262};
63- #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS
63+ #endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
6464
6565#if defined(USE_ROCM)
66- // hipSPARSE doesn't define this
67- using cusparseMatDescr = std::remove_pointer<cusparseMatDescr_t>::type;
68- using cusparseDnMatDescr = std::remove_pointer<cusparseDnMatDescr_t>::type;
69- using cusparseDnVecDescr = std::remove_pointer<cusparseDnVecDescr_t>::type;
70- using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
71- using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
72- using cusparseSpGEMMDescr = std::remove_pointer<cusparseSpGEMMDescr_t>::type;
66+ using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
67+ using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
68+ using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
69+ using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
70+ using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
71+ using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
7372#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
7473using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
7574using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
@@ -145,7 +144,7 @@ class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
145144
146145// AT_USE_HIPSPARSE_GENERIC_52_API() || (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())
147146
148- #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS()
147+ #elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
149148 class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
150149 : public ConstCuSparseDescriptor<
151150 cusparseDnMatDescr,
0 commit comments