MAINT,ENH: Simplify CScalar handling and ready it for arbitrary dtypes#9503

Merged
leofang merged 14 commits into cupy:main from seberg:cscalar-cleanup
Dec 22, 2025
Conversation

@seberg
Member

@seberg seberg commented Nov 26, 2025

This tries to simplify the scalar handling. In part just for maintenance and a small speed boost, but largely to make it easier to support arbitrary dtypes in the scalar code-path (i.e. split out from the ml_dtypes draft PR).

The simplification ideas are:

  1. Rely on NumPy C-API to convert Python scalars to C data that we can use for kernel launches.
  2. Simplify handling by pushing NEP 50 "weak scalar" handling into the CScalar and removing the distinction between "numpy scalar/CScalar" path.
  3. Make CScalar simply hold the original scalar.

The other change is making NumPy a build-time dependency.

I'll add some comments in-line, so discussions are easier to focus on each topic.

@seberg seberg requested a review from a team as a code owner November 26, 2025 10:19

x = cupy.empty((16,), dtype=cupy.uint64)
x[:] = -1
x[:] = cupy.int64(-1) # wrap-around cast to uint
Member Author

Maybe clearer to comment here, but this is an actual behavior change. Previously, things like uint8_arr + (-1) (which now raises an error in NumPy) were checked in guess_routine; now the check happens one level lower.

That means that direct elementwise kernel calls now also see the new behavior. This -1 behavior especially can be tedious, though. It's certainly possible to restore the old behavior if we think it is likely to create hassle.
(Or move it later if we find it does in practice after a release.)

Member

@asi1024 would know this a lot better since he built the whole CuPy JIT machinery, but I think somewhere in the compiler we have a way to change how Python scalars should be interpreted inside a JIT kernel. Maybe it's this:

def get_ctype_from_scalar(mode: str, x: Any) -> _cuda_types.Scalar:
    if isinstance(x, numpy.generic):
        return _cuda_types.Scalar(x.dtype)
    if mode == 'numpy':
        if isinstance(x, bool):
            return _cuda_types.Scalar(numpy.bool_)
        if isinstance(x, int):
            return _cuda_types.Scalar(numpy.int64)
        if isinstance(x, float):
            return _cuda_types.Scalar(numpy.float64)
        if isinstance(x, complex):
            return _cuda_types.Scalar(numpy.complex128)
    if mode == 'cuda':
        if isinstance(x, bool):
            return _cuda_types.Scalar(numpy.bool_)
        if isinstance(x, int):
            if -(1 << 31) <= x < (1 << 31):
                return _cuda_types.Scalar(numpy.int32)
            return _cuda_types.Scalar(numpy.int64)
        if isinstance(x, float):
            return _cuda_types.Scalar(numpy.float32)
        if isinstance(x, complex):
            return _cuda_types.Scalar(numpy.complex64)
    raise NotImplementedError(f'{x} is not scalar object.')

My thinking is: Maybe there is a way to leave the CuPy JIT default unchanged by this PR?

Member Author

In that case it is unchanged, since the JIT kernel discovers a reasonable type here (int64), which is also the final kernel type.

The change only occurs for places where the kernel C-type is explicitly a uint or narrower int.
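For reference, the integer branch of the 'cuda' mode quoted above can be sketched on its own (the function name here is mine, for illustration):

```python
def cuda_int_dtype(x: int) -> str:
    # Mirrors the quoted 'cuda'-mode rule: Python ints fitting in 32 bits
    # lower to int32, anything larger to int64.
    return 'int32' if -(1 << 31) <= x < (1 << 31) else 'int64'

print(cuda_int_dtype(7), cuda_int_dtype(1 << 40))
```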

return descr->f;
}
#endif
"""
Member Author

I'll confirm that this works for NumPy 1.x locally. As mentioned, we can simplify this dance to just the cdef PyArray_Pack() if we hard-require NumPy 2 (i.e. with NPY_TARGET_VERSION=NPY_2_0_API_VERSION, NumPy will generate an error when importing with NumPy 1.x).

But I wasn't sure we should make that a hard requirement yet, and while you may have to be me to know that this is all fine, I am me :).

Member

Q: Remind me what's the conclusion here? That we still allow importing with 1.x, we just wouldn't claim full support for it?

Member Author

Ah, I never removed this :/! I think we said we don't need <2.0; it still felt a bit strange to me to enforce that strictly in this PR. But let me just do this.

It is easy to restore if there is any doubt in the end. (This makes things a lot cleaner; all we'll have left is the Cython define for PyArray_Pack.)
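For future reference, hard-requiring the NumPy 2 C-API at build time usually looks something like this (a sketch of the common pattern, not CuPy's exact sources):

```c
/* Define the API floor before including the NumPy headers.  With this
 * target version, NumPy refuses to import the built extension under 1.x,
 * and PyArray_Pack (public since NumPy 2.0) is available directly. */
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
#include <numpy/arrayobject.h>
```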

# NOTE(seberg): This uses assignment logic, which is very subtly
# different from casting by rejecting nan -> int. This is *only*
# relevant for `casting="unsafe"` passed to ufuncs with `dtype=`.
# It also means we fail for out of bound integers (NEP 50 change).
Member Author

Maybe to explain this: this uses the same logic as arr_with_dtype[0] = value, which is exceedingly subtly different from a cast.
That truly only matters for calls like cp.add(cp.float64(np.nan), 1, casting="unsafe", dtype=int), so I think it's safe to ignore :).
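The assignment-vs-cast difference can be demonstrated with plain NumPy:

```python
import numpy as np

out = np.empty(1, dtype=np.int64)
try:
    out[0] = np.float64(np.nan)  # assignment logic rejects NaN -> int
    nan_rejected = False
except ValueError:
    nan_rejected = True

with np.errstate(invalid='ignore'):
    casted = np.float64(np.nan).astype(np.int64)  # an unsafe cast goes through
print(nan_rejected, type(casted).__name__)  # (the cast's value is undefined)
```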

@asi1024 asi1024 self-assigned this Nov 28, 2025
@asi1024 asi1024 added the cat:performance (Performance in terms of speed or memory consumption) and prio:high labels Nov 28, 2025
@leofang
Member

leofang commented Dec 11, 2025

/test mini

@seberg
Member Author

seberg commented Dec 11, 2025

/test mini

Comment on lines +67 to +73
try:
    _scalar.get_typename(dtype)  # allow if we know a C typename.
except (ValueError, KeyError):
    if not error:
        return False
    else:
        raise ValueError(f'Unsupported dtype {dtype}') from None
Member

Should we return True when _scalar.get_typename(dtype) does not raise an exception?

Member Author

Yes, good catch. This was prep for ml_dtypes, for which the current code would be wrong; this way we can route them through get_typename() rather than special-casing them in many places.
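A hedged sketch of the corrected control flow, with a stub standing in for _scalar.get_typename (the mapping and function names are illustrative, not CuPy's actual table):

```python
import numpy as np

_TYPENAMES = {np.dtype(np.float32): 'float', np.dtype(np.int64): 'long long'}

def get_typename(dtype):
    # Stand-in for _scalar.get_typename: raises KeyError for unknown dtypes.
    return _TYPENAMES[np.dtype(dtype)]

def is_supported(dtype, error=False):
    # As discussed above: return True once a C typename is known.
    try:
        get_typename(dtype)
    except (ValueError, KeyError):
        if not error:
            return False
        raise ValueError(f'Unsupported dtype {dtype}') from None
    return True

print(is_supported(np.float32), is_supported(np.float16))
```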

Member

@leofang leofang left a comment

Thanks, @seberg! LGTM overall. Left some comments/questions.

@@ -1,5 +1,7 @@
cimport cython # NOQA

from . cimport _scalar
Member

nit: use absolute import



@leofang
Member

leofang commented Dec 20, 2025

/test mini

@leofang
Member

leofang commented Dec 21, 2025

LGTM! CI is green too. Though Kenichi-san mentioned there is a code freeze now. Let's merge it after the freeze and also give @asi1024 a bit more time in case he wants to chime in 🙂

@leofang
Member

leofang commented Dec 21, 2025

@seberg forgot to ask, is there a simple reproducer for us to check the perf difference before and after this PR? Would be nice to know the expected ballpark improvement.

@leofang leofang added this to the v14 milestone Dec 21, 2025
@leofang leofang added the to-be-backported (Pull-requests to be backported to stable branch) label Dec 22, 2025
@leofang
Member

leofang commented Dec 22, 2025

Let's merge it after the freeze and also give @asi1024 a bit more time in case he wants to chime in 🙂

Code freeze is lifted. Let me get this merged. @asi1024 please let us know if you have any concerns and we can follow up in a separate PR.

@seberg forgot to ask, is there a simple reproducer for us to check the perf difference before and after this PR? Would be nice to know the expected ballpark improvement.

Would still be nice to keep a record in this PR for future reference, in case people come and ask what drove the decision to make NumPy a build-time dependency.

@leofang leofang merged commit 4d9486d into cupy:main Dec 22, 2025
61 checks passed
chainer-ci pushed a commit to chainer-ci/cupy that referenced this pull request Dec 22, 2025
MAINT,ENH: Simplify `CScalar` handling and ready it for arbitrary dtypes
@seberg seberg deleted the cscalar-cleanup branch December 23, 2025 12:19
@leofang leofang modified the milestones: v14, v14.0.0, v15 Dec 24, 2025
@seberg
Member Author

seberg commented Jan 8, 2026

is there a simple reproducer for us to check the perf difference before and after this PR?

Not sure about a single reproducer, but things like cp.add(f, f) (with f = np.float32(1.)) are e.g. unchanged compared to v13 (cp.add(1, 1) is maybe very slightly faster).
I think it's basically a wash. The call into NumPy costs a few ns, but OTOH I optimized an allocation away, which may even save some time in the end.

In the grand scheme, the perf differences here are just very small, I think.
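For the record, a micro-benchmark along these lines (the setup is mine; NumPy is used here so the sketch runs anywhere, substitute cupy for the real measurement):

```python
import timeit
import numpy as np

f = np.float32(1.0)  # scalar argument exercising the scalar code path
n = 100_000
t = timeit.timeit(lambda: np.add(f, f), number=n)
print(f'np.add(f, f): {t / n * 1e9:.0f} ns/call')
```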
