Skip to content

Commit 17f2f46

Browse files
authored
Test CUB spmv
style in line with the existing tests and #2598
1 parent 1654d38 commit 17f2f46

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

tests/cupyx_tests/scipy_tests/sparse_tests/test_csr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,3 +1524,26 @@ def test_getitem_slice_stop_too_small(self, xp, sp):
15241524
@testing.numpy_cupy_allclose(sp_name='sp')
15251525
def test_getitem_slice_stop_too_large(self, xp, sp):
15261526
return _make(xp, sp, self.dtype)[None:4]
1527+
1528+
1529+
@testing.parameterize(*testing.product({
1530+
'make_method': [
1531+
'_make', '_make_unordered', '_make_empty', '_make_duplicate',
1532+
'_make_shape'],
1533+
'dtype': [numpy.float32, numpy.float64, cupy.complex64, cupy.complex128],
1534+
}))
1535+
@testing.with_requires('scipy')
1536+
@testing.gpu
1537+
@unittest.skipIf(cupy.cuda.cub_enabled is False, 'The CUB module is not built')
1538+
class TestCUBspmv(unittest.TestCase):
1539+
@property
1540+
def make(self):
1541+
return globals()[self.make_method]
1542+
1543+
@testing.numpy_cupy_allclose(sp_name='sp')
1544+
def test_mul_dense_vector(self, xp, sp):
1545+
assert cupy.cuda.cub_enabled
1546+
1547+
m = self.make(xp, sp, self.dtype)
1548+
x = xp.arange(4).astype(self.dtype)
1549+
return m * x

0 commit comments

Comments
 (0)