MAINT,ENH: Simplify CScalar handling and ready it for arbitrary dtypes#9503
This simplifies the `CScalar`; it does, however, use the NumPy C-API to do so. The thinking is three-fold:
1. Try to get rid of dtype-specific code as much as possible.
2. Remove the NumPy/`CScalar` detection to always prefer `CScalar` for simplicity.
3. Smaller miscellaneous improvements.
```python
x = cupy.empty((16,), dtype=cupy.uint64)
x[:] = -1
x[:] = cupy.int64(-1)  # wrap-around cast to uint
```
Maybe clearer to comment here, but this is an actual behavior change. Previously, things like `uint8_arr + (-1)` raising an error (as it now does in NumPy) were checked in `guess_routine`; now the check happens one level lower.
That means that direct elementwise kernel calls now also see the new behavior. This `-1` behavior especially can be tedious, though. It's certainly possible to restore the old behavior if we think it is likely to create hassle.
(Or move it later if we find it does in practice after a release.)
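To illustrate the behavior being discussed, here is a minimal sketch using NumPy (which CuPy mirrors here; whether older NumPy raises or merely warns depends on the version, so the fallback branch is hedged):

```python
import numpy as np

a = np.zeros(4, dtype=np.uint8)

# A same-kind NumPy scalar is cast unsafely on assignment and wraps around.
a[:] = np.int64(-1)
print(a[0])  # 255

# An out-of-bound Python int raises under NEP 50 rules (NumPy 2).
try:
    a[:] = -1
    print("no error (pre-NEP-50 NumPy)")
except OverflowError as exc:
    print("OverflowError:", exc)
```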
@asi1024 would know this a lot better since he built the whole CuPy JIT machinery, but I think somewhere in the compiler we have a way to change how Python scalars should be interpreted inside a JIT kernel. Maybe it's this: `cupy/cupyx/jit/_cuda_typerules.py`, lines 111 to 137 (at f60edcf).
My thinking is: maybe there is a way to leave the CuPy JIT default unchanged by this PR?
In that case it is unchanged since the JIT kernel discovers a reasonable type here (int64) which is also the final kernel type.
The change only occurs for places where the kernel C-type is explicitly a uint or narrower int.
`cupy/_core/_scalar.pyx` (outdated):

```
        return descr->f;
    }
#endif
"""
```
I'll confirm that this works for NumPy 1.x locally. As mentioned, we can simplify this dance to just the cdef `PyArray_Pack()` if we hard-require NumPy 2 (i.e. with `NPY_TARGET_VERSION=NPY_2_0_API_VERSION`, NumPy will generate an error if importing with NumPy 1.x).
But I wasn't sure we should make that a hard requirement yet, and while you may have to be me to know that this is all fine, I am me :).
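For reference, a hedged sketch of the compile-time gate mentioned above (the macro and function names are real NumPy C-API; the fragment itself is illustrative, not the actual CuPy source):

```c
/* Defining the target API version before including the NumPy headers
 * makes the built extension refuse to import under NumPy 1.x, so
 * PyArray_Pack() can then be used unconditionally. */
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
#include <numpy/arrayobject.h>

/* PyArray_Pack packs a Python object into a buffer described by descr:
 *   int PyArray_Pack(PyArray_Descr *descr, void *item, PyObject *value);
 */
```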
Q: Remind me what's the conclusion here? That we still allow importing with 1.x, we just wouldn't claim full support for it?
Ah, I never removed this :/! I think we said we don't need <2.0; it still felt a bit strange to me to enforce it strictly in this PR. But let me just do this.
It is easy to restore if there is any doubt in the end. (This makes it a lot cleaner; all we'll have left is the Cython define for `PyArray_Pack`.)
```python
# NOTE(seberg): This uses assignment logic, which is very subtly
# different from casting by rejecting nan -> int. This is *only*
# relevant for `casting="unsafe"` passed to ufuncs with `dtype=`.
# It also means we fail for out of bound integers (NEP 50 change).
```
Maybe to explain this: this uses the same logic as `arr_with_dtype[0] = value`, and that is exceedingly subtly different from a cast.
That truly only matters for calls like `cp.add(cp.float64(np.nan), 1, casting="unsafe", dtype=int)`. So I think it's safe to ignore :).
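A minimal NumPy sketch of the assignment-vs-cast difference being described (assuming CuPy mirrors NumPy's behavior here):

```python
import warnings

import numpy as np

a = np.zeros(1, dtype=np.int64)

# Assignment logic rejects nan -> int with an error.
try:
    a[0] = np.float64(np.nan)
except ValueError as exc:
    print("assignment rejected:", exc)

# An unsafe *cast* converts instead; the resulting value is unspecified
# and NumPy emits a RuntimeWarning for the invalid conversion.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    print(np.array(np.nan).astype(np.int64).item())
```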
/test mini

/test mini
```python
try:
    _scalar.get_typename(dtype)  # allow if we know a C typename.
except (ValueError, KeyError):
    if not error:
        return False
    else:
        raise ValueError(f'Unsupported dtype {dtype}') from None
```
Should we return `True` when `_scalar.get_typename(dtype)` does not raise an exception?
Yes, good catch. This was prep for `ml_dtypes` (and would be wrong for those), so that we can funnel them into `get_typename()` rather than checking in many places.
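A self-contained sketch of the corrected control flow (the helper name `dtype_supported` and the stand-in for `_scalar.get_typename` are hypothetical; only the try/except shape follows the diff):

```python
# Stand-in for _scalar.get_typename so the sketch runs on its own.
_C_TYPENAMES = {"uint64": "unsigned long long", "float64": "double"}


def get_typename(dtype):
    try:
        return _C_TYPENAMES[str(dtype)]
    except KeyError:
        raise ValueError(f"no C typename for {dtype}") from None


def dtype_supported(dtype, error=True):
    try:
        get_typename(dtype)  # allow if we know a C typename.
    except (ValueError, KeyError):
        if not error:
            return False
        raise ValueError(f"Unsupported dtype {dtype}") from None
    return True  # the branch the review points out was missing


print(dtype_supported("float64"))                 # True
print(dtype_supported("bfloat16", error=False))   # False
```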
```
@@ -1,5 +1,7 @@
cimport cython  # NOQA

from . cimport _scalar
```
/test mini
LGTM! CI is green too. Though Kenichi-san mentioned there is a code freeze now. Let's merge it after the freeze and also give @asi1024 a bit more time in case he wants to chime in 🙂
@seberg, forgot to ask: is there a simple reproducer for us to check the perf difference before and after this PR? Would be nice to know the expected ballpark improvement.
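For reference, a hypothetical micro-benchmark of the scalar path (array size, iteration count, and the NumPy stand-in are all assumptions; with CuPy installed one would swap `numpy` for `cupy` and synchronize the device around the timing):

```python
import timeit

import numpy as np  # stand-in; replace with `import cupy as np` to test on GPU

a = np.ones(8, dtype=np.float32)  # tiny array so per-call overhead dominates

# The scalar operand exercises the scalar-handling path on every call.
t = timeit.timeit(lambda: a + 1.5, number=10_000)
print(f"{t * 1e6 / 10_000:.2f} us per elementwise call with a scalar")
```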
Code freeze is lifted. Let me get this merged. @asi1024, please let us know if you have any concerns and we can follow up in a separate PR.
Would still be nice to keep a record in this PR for future reference, in case people come and ask what drove the decision of making NumPy a build-time dependency.
Not sure about this, but things like […]. In the grand scheme, the perf differences here are just very small, I think.
This tries to simplify the scalar handling. In part just for maintenance and a small speed boost, but largely to make it easier to support arbitrary dtypes in the scalar code-path (i.e. split out from the `ml_dtypes` draft PR).

The simplification ideas are always preferring `CScalar` and removing the distinction between the "numpy scalar" and `CScalar` paths.

The other change is making NumPy a build-time dependency.

I'll add some comments in-line, so discussions are easier to focus on each topic.