
Signatures and behaviors of argmax and argmin are incompatible with NumPy #2595

@leofang


First, the signature: in CuPy, the signature of argmax() and argmin() is

cupy.argm*(a, axis=None, dtype=None, out=None, keepdims=False)

But in NumPy it's

numpy.argm*(a, axis=None, out=None)

That is, CuPy's dtype and keepdims parameters should be removed.
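A mismatch like this can be caught mechanically in a test. The minimal sketch below compares parameter lists using the standard library's inspect.signature. The two stub functions and the helper extra_parameters are hypothetical: they only reproduce the signatures quoted above. A real compatibility test would pass cupy.argmax and numpy.argmax instead, assuming both expose signatures that inspect can read.

```python
import inspect


# Hypothetical stubs that only reproduce the two signatures quoted above;
# the bodies are empty because only the parameter lists matter here.
def numpy_style_argmax(a, axis=None, out=None):
    """NumPy's argmax signature, as quoted in this issue."""


def cupy_style_argmax(a, axis=None, dtype=None, out=None, keepdims=False):
    """CuPy's argmax signature, as quoted in this issue."""


def extra_parameters(candidate, reference):
    """Return parameter names that `candidate` accepts but `reference` does not."""
    reference_params = inspect.signature(reference).parameters
    candidate_params = inspect.signature(candidate).parameters
    # Preserve the candidate's declaration order in the result.
    return [name for name in candidate_params if name not in reference_params]


print(extra_parameters(cupy_style_argmax, numpy_style_argmax))  # ['dtype', 'keepdims']
```

An empty list in the opposite direction confirms CuPy accepts everything NumPy does, so only the extra parameters need removing.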

Second, the behavior: in CuPy, axis can be a tuple, but in NumPy it must be a single integer (or None):

>>> import cupy as cp
>>> a = cp.arange(60).reshape(3,4,5)
>>> a.argmax(axis=(0,1))
array([11, 11, 11, 11, 11], dtype=int64)
>>> 
>>> import numpy as np
>>> a = np.arange(60).reshape(3,4,5)
>>> a.argmax(axis=(0,1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'tuple' object cannot be interpreted as an integer

I think both are easy to fix: just add a few guards that validate the arguments before calling the actual workhorses.
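One possible shape for such guards, as a sketch only: the names normalize_argm_axis and make_numpy_compatible are hypothetical, and the workhorse is injected as a callable impl(a, axis, out) so the sketch runs without CuPy. Leaving dtype and keepdims out of the public signature means Python itself rejects them with a TypeError, and operator.index() rejects a tuple axis with the same message NumPy showed above.

```python
import operator


def normalize_argm_axis(axis, ndim):
    """Hypothetical guard: accept only None or a single int in [-ndim, ndim)."""
    if axis is None:
        return None
    # operator.index() raises TypeError for tuples, floats, etc., e.g.
    # "'tuple' object cannot be interpreted as an integer".
    axis = operator.index(axis)
    if not -ndim <= axis < ndim:
        # NumPy raises its own AxisError (a ValueError subclass) here.
        raise ValueError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    # Bounds were checked above, so ndim > 0 and the modulo is safe.
    return axis % ndim


def make_numpy_compatible(impl):
    """Wrap a hypothetical workhorse `impl(a, axis, out)` behind NumPy's signature."""

    def wrapper(a, axis=None, out=None):
        # No dtype/keepdims parameters: passing them raises TypeError.
        return impl(a, axis=normalize_argm_axis(axis, a.ndim), out=out)

    return wrapper
```

Validating before dispatch keeps the workhorse unchanged; it simply never receives a tuple or an out-of-range axis.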
