API: Introduce np.isdtype function [Array API] #25054
mhvk left a comment:
Bit of a fly-by review out of curiosity, but one small comment inline.
An argument in favour of adding 'flexible' is that it is fairly common to have functions that work well with either float or complex, and its absence may just have been an oversight of the array API (which initially had no support for complex numbers).
Thanks @mtsokol!
It's not an oversight. Now I can see a case for the opposite point as well, but I think we should start by implementing only what the standard says, and reconsider only if it becomes clear that this is cumbersome. At that point a naming discussion should be had, I think. One of the upsides is that it's often unclear how well complex dtypes are supported by a function. If we read np.isdtype(dtype, ('real floating', 'complex floating')), that is quite clear, and easier to understand than the old numpy names for this.
I think you should enforce the typing of the standard - so it'd either be a dtype object or a string - and if it's a string then it must be one of the strings that is a named collection of dtypes in the standard. NumPy has always had a ton of different ways to say
Yes, makes sense to have a replacement for "flexible" only if it turns out there is an actual need.
@rgommers Sure! I updated the implementation (and tests) to match the standard exactly.
numpy/_core/numerictypes.py (outdated)
    """
    # validate and preprocess arguments
    if not isinstance(dtype, (type, ma.dtype)):
I'm not sure if there's an easy way to spell this that's more accurate than just checking for type, but this would let you do e.g. np.isdtype(np.int64, dict) and return False. I guess that is true but is perhaps against the spirit of this check. I don't know if there's an easy way to get a list of all the dtype scalar types numpy knows about at any given moment (which could include user dtypes or dtypes defined in downstream packages).
I don't know it either. Let's ask about it during today's community call.
Maybe checking issubclass(dtype, np.generic) would make sense?
In [31]: np.int64.mro()
Out[31]:
[numpy.int64,
 numpy.signedinteger,
 numpy.integer,
 numpy.number,
 numpy.generic,
 object]
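As a rough sketch of what such a check could look like (the function name is hypothetical and this is not the code from the PR), note that issubclass needs a class as its first argument, so a guard is required first:

```python
import numpy as np


def is_numpy_scalar_type(obj):
    """Hypothetical check: is obj a class deriving from np.generic?"""
    # issubclass raises TypeError for non-classes, so test for a class first.
    return isinstance(obj, type) and issubclass(obj, np.generic)


print(is_numpy_scalar_type(np.int64))            # a NumPy scalar type
print(is_numpy_scalar_type(dict))                # an unrelated Python class
print(is_numpy_scalar_type(np.dtype("int64")))   # a dtype instance, not a class
```

This check accepts any class that derives from np.generic, including abstract ones such as np.floating, while dtype instances would still need to be handled separately.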
We decided at the triage meeting the other day that we wanted this check to reject all dtypes or scalar types that aren't in the array API standard.
reject all dtypes or scalar types that aren't in the array API standard
That would make it harder to use while supporting, e.g., float16, and seems like an odd suggestion - what is the justification?
Looking at the jax bfloat16 type as a real-world example, I see:
In [10]: np.issubdtype(ml_dtypes.bfloat16, np.floating)
Out[10]: False
In [11]: ml_dtypes.bfloat16.mro()
Out[11]: [ml_dtypes.bfloat16, numpy.generic, object]
Which isn't particularly useful but is correct according to the implementation of issubdtype.
So I think we might need to just not handle third-party dtypes for now, until we have a better story for registering dtypes in a type hierarchy. Does that sound reasonable?
Handling all the builtin numpy dtypes here makes sense, no worries about doing that.
So I think we might need to just not handle third-party dtypes for now, until we have a better story for registering dtypes in a type hierarchy. Does that sound reasonable?
Thanks for checking. That sounds fine to me.
I updated the implementation so that it now accepts all of NumPy's dtypes as inputs.
For what it's worth, the reason that np.issubdtype(ml_dtypes.bfloat16, np.floating) returns False is that in several places NumPy hard-codes the assumption that there is only a single 16-bit floating-point type, so we could not make bfloat16 a subclass of np.floating without causing collisions across the codebase. If you have suggestions for how to do better, we'd love to hear them!
I’d love to relax that assumption inside numpy; I’m sure patches that do that would be reviewed if you can get anyone to work on it. I wasn’t privy to previous discussions about this, but at this point I think numpy should probably support bfloat16 natively. Of course that’s just my opinion. Also, with the new dtype API it should be much more straightforward to upstream a new dtype into numpy; there’s no longer any need to mess with all the complicated custom templating and codegen.
    False
    >>> np.isdtype(np.int64, (np.uint64, "signed integer"))
    True
It'd be nice to explicitly add the example in the PR discussion regarding checking for real-floating only vs. an API that supports both real and complex floating.
Do you mean in the Examples section, or in the PR's description?
I added two more examples to the docstring for checking real-floating only and real and complex floating.
numpy/_core/numerictypes.py (outdated)
    True

    """
    # validate and preprocess arguments
It would be useful to extend this comment to state what is actually happening here. The intended inputs are np.float32 & co, which aren't instances of multiarray.dtype. So I'm not sure that this validation is right.
Also, rejecting non-compliant objects may be better done more explicitly; right now the errors can be a bit obscure:
In [9]: np.isdtype(np.float32, np.ones(1))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 np.isdtype(np.float32, np.ones(1))
File ~/code/numpy/build-install/usr/lib/python3.11/site-packages/numpy/_core/numerictypes.py:436, in isdtype(dtype, kind)
434 if isinstance(kind, ma.dtype):
435 kind = kind.type
--> 436 if kind not in allTypes.values():
437 raise TypeError(
438 "kind argument must be comprised of NumPy dtypes or "
439 f"strings only, but it is a {kind}."
440 )
442 processed_kinds.add(kind)
TypeError: descriptor '__array_wrap__' for 'numpy.generic' objects doesn't apply to a 'numpy.ndarray' object
I added a short docstring for the helper function instead.
Basically, if we have a dtype instance we need to extract its type. Then we check whether that type is in allTypes (I added a check for numpy.ndarray to avoid the error you posted).
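The helper itself is not shown in this conversation, so the following is only a sketch of the two steps described above. The names preprocess_dtype and BUILTIN_SCALAR_TYPES are hypothetical, and the set built from np.typecodes is an assumed stand-in for the private allTypes table.

```python
import numpy as np

# Assumed stand-in for the private allTypes table: the scalar types
# behind every built-in type code that NumPy reports.
BUILTIN_SCALAR_TYPES = {np.dtype(code).type for code in np.typecodes["All"]}


def preprocess_dtype(obj):
    """Hypothetical helper: return the scalar type for a dtype or scalar type."""
    # Step 1: a dtype instance such as np.dtype("float32") carries its type.
    if isinstance(obj, np.dtype):
        obj = obj.type
    # Step 2: reject non-classes (e.g. arrays) and unknown classes up front,
    # so the caller gets a readable error instead of an obscure one.
    if not isinstance(obj, type) or obj not in BUILTIN_SCALAR_TYPES:
        raise TypeError(f"expected a built-in NumPy dtype, got {obj!r}")
    return obj


print(preprocess_dtype(np.dtype("float32")))
print(preprocess_dtype(np.int64))
```

Checking for a class before the membership test keeps arrays and other arbitrary objects from ever reaching the lookup, which is the kind of situation that produced the traceback above.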
Sorry for not merging this earlier in the week; I accidentally dropped it without looking at it again. Thanks for bringing this one home @mtsokol!
Hi @rgommers @ngoldbaum,
This PR adds np.isdtype, mentioned in NEP 52 and the tracking issue #23999. I followed the description in https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html.
Questions:
For the kind argument, a flexible kind is missing (union of floating and complex types), but integral is present. Is that on purpose?
The standard expects the dtype argument to be a dtype, but in the implementation I accept anything that can be consumed by np.dtype, such as strings "int64" etc. Should I enforce dtype instances only?
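To make the second question concrete, here is a small sketch of the two options. Both function names are hypothetical; only np.dtype and isinstance are real calls.

```python
import numpy as np


def coerce_to_dtype(obj):
    """Permissive option: accept anything np.dtype() can consume."""
    return np.dtype(obj)


def is_dtype_instance(obj):
    """Strict option: accept np.dtype instances only."""
    return isinstance(obj, np.dtype)


# The permissive route silently converts strings and Python types.
print(coerce_to_dtype("int64") == np.dtype(np.int64))
print(coerce_to_dtype(float) == np.dtype(np.float64))

# The strict route rejects the string but accepts the dtype instance.
print(is_dtype_instance("int64"))
print(is_dtype_instance(np.dtype("int64")))
```

The permissive route is convenient, but it admits many spellings of the same dtype, which is the trade-off the question asks about.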