Skip to content

Commit 43b2da6

Browse files
committed
Set backends in setUp and tearDown of CubTests
1 parent 2f61909 commit 43b2da6

5 files changed

Lines changed: 45 additions & 15 deletions

File tree

cupy/cuda/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
except ImportError:
3636
thrust_enabled = False
3737

38+
try:
39+
from cupy.cuda import cub # NOQA
40+
cub_enabled = True
41+
except ImportError:
42+
cub_enabled = False
43+
3844
try:
3945
from cupy.cuda import nccl # NOQA
4046
nccl_enabled = True

tests/cupy_tests/core_tests/test_ndarray_reduction.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,16 @@ def test_ptp_nan_imag(self, xp, dtype):
216216
'order': ('C', 'F'),
217217
}))
218218
@testing.gpu
219-
@unittest.skipUnless(
220-
cupy.core._backend.get_routine_backends() == ['cub'],
221-
'The CUB routine is not enabled')
219+
@unittest.skipUnless(cupy.cuda.cub_enabled, 'The CUB routine is not enabled')
222220
class TestCubReduction(unittest.TestCase):
221+
222+
def setUp(self):
223+
self.old_backends = cupy.core._backend.get_routine_backends()
224+
cupy.core._backend.set_routine_backends(['cub'])
225+
226+
def tearDown(self):
227+
cupy.core._backend.set_routine_backends(self.old_backends)
228+
223229
@testing.for_contiguous_axes()
224230
@testing.for_all_dtypes(no_bool=True, no_float16=True)
225231
@testing.numpy_cupy_allclose(rtol=1E-5)

tests/cupy_tests/math_tests/test_sumprod.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,16 @@ def test_prod_dtype(self, xp, src_dtype, dst_dtype):
198198
'order': ('C', 'F'),
199199
}))
200200
@testing.gpu
201-
@unittest.skipUnless(
202-
cupy.core._backend.get_routine_backends() == ['cub'],
203-
'The CUB routine is not enabled')
204-
class TestCUBreduction(unittest.TestCase):
201+
@unittest.skipUnless(cupy.cuda.cub_enabled, 'The CUB routine is not enabled')
202+
class TestCubReduction(unittest.TestCase):
203+
204+
def setUp(self):
205+
self.old_backends = cupy.core._backend.get_routine_backends()
206+
cupy.core._backend.set_routine_backends(['cub'])
207+
208+
def tearDown(self):
209+
cupy.core._backend.set_routine_backends(self.old_backends)
210+
205211
@testing.for_contiguous_axes()
206212
# sum supports less dtypes; don't test float16 as it's not as accurate?
207213
@testing.for_dtypes('lLfdFD')

tests/cupy_tests/sorting_tests/test_search.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,16 @@ def test_argmin_zero_size_axis1(self, xp, dtype):
166166
'order': ('C', 'F'),
167167
}))
168168
@testing.gpu
169-
@unittest.skipUnless(
170-
cupy.core._backend.get_routine_backends() == ['cub'],
171-
'The CUB routine is not enabled')
172-
class TestCUBreduction(unittest.TestCase):
169+
@unittest.skipUnless(cupy.cuda.cub_enabled, 'The CUB routine is not enabled')
170+
class TestCubReduction(unittest.TestCase):
171+
172+
def setUp(self):
173+
self.old_backends = cupy.core._backend.get_routine_backends()
174+
cupy.core._backend.set_routine_backends(['cub'])
175+
176+
def tearDown(self):
177+
cupy.core._backend.set_routine_backends(self.old_backends)
178+
173179
@testing.for_dtypes('bhilBHILefdFD')
174180
@testing.numpy_cupy_allclose(rtol=1E-5)
175181
def test_cub_argmin(self, xp, dtype):

tests/cupyx_tests/scipy_tests/sparse_tests/test_csr.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,10 +1534,16 @@ def test_getitem_slice_stop_too_large(self, xp, sp):
15341534
}))
15351535
@testing.with_requires('scipy')
15361536
@testing.gpu
1537-
@unittest.skipUnless(
1538-
cupy.core._backend.get_routine_backends() == ['cub'],
1539-
'The CUB routine is not enabled')
1540-
class TestCUBspmv(unittest.TestCase):
1537+
@unittest.skipUnless(cupy.cuda.cub_enabled, 'The CUB routine is not enabled')
1538+
class TestCubSpmv(unittest.TestCase):
1539+
1540+
def setUp(self):
1541+
self.old_backends = cupy.core._backend.get_routine_backends()
1542+
cupy.core._backend.set_routine_backends(['cub'])
1543+
1544+
def tearDown(self):
1545+
cupy.core._backend.set_routine_backends(self.old_backends)
1546+
15411547
@property
15421548
def make(self):
15431549
return globals()[self.make_method]

0 commit comments

Comments
 (0)