Skip to content

Commit 360a25a

Browse files
author
emcastillo
committed
Merge pull request #2872 from niboshi/argminmax-dtype
Fix `argmin`/`argmax` `dtype` argument
1 parent a5cd457 commit 360a25a

4 files changed

Lines changed: 71 additions & 14 deletions

File tree

cupy/core/_routines_statistics.pyx

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -225,27 +225,27 @@ nanmax = create_reduction_func(
225225

226226
cdef _argmin = create_reduction_func(
227227
'cupy_argmin',
228-
('?->q', 'B->q', 'h->q', 'H->q', 'i->q', 'I->q', 'l->q', 'L->q',
229-
'q->q', 'Q->q',
230-
('e->q', (None, 'my_argmin_float(a, b)', None, None)),
231-
('f->q', (None, 'my_argmin_float(a, b)', None, None)),
232-
('d->q', (None, 'my_argmin_float(a, b)', None, None)),
233-
('F->q', (None, 'my_argmin_complex(a, b)', None, None)),
234-
('D->q', (None, 'my_argmin_complex(a, b)', None, None))),
228+
tuple(['{}->{}'.format(d, r) for r in 'qlihb' for d in '?BhHiIlLqQ'])
229+
+ (
230+
('e->q', (None, 'my_argmin_float(a, b)', None, None)),
231+
('f->q', (None, 'my_argmin_float(a, b)', None, None)),
232+
('d->q', (None, 'my_argmin_float(a, b)', None, None)),
233+
('F->q', (None, 'my_argmin_complex(a, b)', None, None)),
234+
('D->q', (None, 'my_argmin_complex(a, b)', None, None))),
235235
('min_max_st<type_in0_raw>(in0, _J)', 'my_argmin(a, b)', 'out0 = a.index',
236236
'min_max_st<type_in0_raw>'),
237237
None, _min_max_preamble)
238238

239239

240240
cdef _argmax = create_reduction_func(
241241
'cupy_argmax',
242-
('?->q', 'B->q', 'h->q', 'H->q', 'i->q', 'I->q', 'l->q', 'L->q',
243-
'q->q', 'Q->q',
244-
('e->q', (None, 'my_argmax_float(a, b)', None, None)),
245-
('f->q', (None, 'my_argmax_float(a, b)', None, None)),
246-
('d->q', (None, 'my_argmax_float(a, b)', None, None)),
247-
('F->q', (None, 'my_argmax_complex(a, b)', None, None)),
248-
('D->q', (None, 'my_argmax_complex(a, b)', None, None))),
242+
tuple(['{}->{}'.format(d, r) for r in 'qlihb' for d in '?BhHiIlLqQ'])
243+
+ (
244+
('e->q', (None, 'my_argmax_float(a, b)', None, None)),
245+
('f->q', (None, 'my_argmax_float(a, b)', None, None)),
246+
('d->q', (None, 'my_argmax_float(a, b)', None, None)),
247+
('F->q', (None, 'my_argmax_complex(a, b)', None, None)),
248+
('D->q', (None, 'my_argmax_complex(a, b)', None, None))),
249249
('min_max_st<type_in0_raw>(in0, _J)', 'my_argmax(a, b)', 'out0 = a.index',
250250
'min_max_st<type_in0_raw>'),
251251
None, _min_max_preamble)

cupy/core/core.pyx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,14 @@ cdef class ndarray:
728728
keepdims=False):
729729
"""Returns the indices of the maximum along a given axis.
730730
731+
.. note::
732+
``dtype`` and ``keepdim`` arguments are specific to CuPy. They are
733+
not in NumPy.
734+
735+
.. note::
736+
``axis`` argument accepts a tuple of ints, but this is specific to
737+
CuPy. NumPy does not support it.
738+
731739
.. seealso::
732740
:func:`cupy.argmax` for full documentation,
733741
:meth:`numpy.ndarray.argmax`
@@ -749,6 +757,14 @@ cdef class ndarray:
749757
keepdims=False):
750758
"""Returns the indices of the minimum along a given axis.
751759
760+
.. note::
761+
``dtype`` and ``keepdim`` arguments are specific to CuPy. They are
762+
not in NumPy.
763+
764+
.. note::
765+
``axis`` argument accepts a tuple of ints, but this is specific to
766+
CuPy. NumPy does not support it.
767+
752768
.. seealso::
753769
:func:`cupy.argmin` for full documentation,
754770
:meth:`numpy.ndarray.argmin`

cupy/sorting/search.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ def argmax(a, axis=None, dtype=None, out=None, keepdims=False):
1717
Returns:
1818
cupy.ndarray: The indices of the maximum of ``a`` along an axis.
1919
20+
.. note::
21+
``dtype`` and ``keepdim`` arguments are specific to CuPy. They are
22+
not in NumPy.
23+
24+
.. note::
25+
``axis`` argument accepts a tuple of ints, but this is specific to
26+
CuPy. NumPy does not support it.
27+
2028
.. seealso:: :func:`numpy.argmax`
2129
2230
"""
@@ -42,6 +50,14 @@ def argmin(a, axis=None, dtype=None, out=None, keepdims=False):
4250
Returns:
4351
cupy.ndarray: The indices of the minimum of ``a`` along an axis.
4452
53+
.. note::
54+
``dtype`` and ``keepdim`` arguments are specific to CuPy. They are
55+
not in NumPy.
56+
57+
.. note::
58+
``axis`` argument accepts a tuple of ints, but this is specific to
59+
CuPy. NumPy does not support it.
60+
4561
.. seealso:: :func:`numpy.argmin`
4662
4763
"""

tests/cupy_tests/sorting_tests/test_search.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy
44

5+
import cupy
56
from cupy import testing
67

78

@@ -153,6 +154,30 @@ def test_argmin_zero_size_axis1(self, xp, dtype):
153154
return a.argmin(axis=1)
154155

155156

157+
@testing.gpu
158+
@testing.parameterize(*testing.product({
159+
'func': ['argmin', 'argmax'],
160+
'is_module': [True, False],
161+
'shape': [(3, 4), ()],
162+
}))
163+
class TestArgMinMaxDtype(unittest.TestCase):
164+
165+
@testing.for_dtypes(
166+
dtypes=[numpy.int8, numpy.int16, numpy.int32, numpy.int64],
167+
name='result_dtype')
168+
@testing.for_all_dtypes(name='in_dtype')
169+
def test_argminmax_dtype(self, in_dtype, result_dtype):
170+
a = testing.shaped_random(self.shape, cupy, in_dtype)
171+
if self.is_module:
172+
func = getattr(cupy, self.func)
173+
y = func(a, dtype=result_dtype)
174+
else:
175+
func = getattr(a, self.func)
176+
y = func(dtype=result_dtype)
177+
assert y.shape == ()
178+
assert y.dtype == result_dtype
179+
180+
156181
@testing.parameterize(
157182
{'cond_shape': (2, 3, 4), 'x_shape': (2, 3, 4), 'y_shape': (2, 3, 4)},
158183
{'cond_shape': (4,), 'x_shape': (2, 3, 4), 'y_shape': (2, 3, 4)},

0 commit comments

Comments
 (0)