
BUG: where keyword argument is ignored in __array_ufunc__/__array_function__ if it is not an instance of numpy.ndarray #23219

@roytsmart

Description

Describe the issue:

If the where keyword argument is an ndarray duck type that implements __array_ufunc__ and __array_function__, I would have expected its implementations of those methods to be used. Instead, neither method is called, and it appears that where is simply ignored if it is not an instance of numpy.ndarray.
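For context, here is a simplified, hypothetical model of how __array_function__ overrides are looked up under NEP 18: NumPy only consults the arguments returned by each function's dispatcher. Assuming, as we read NumPy's source, that the dispatcher for np.sum returns just (a, out), a duck array passed only as where is never inspected. Ufunc overrides (__array_ufunc__) are likewise checked on the inputs and out, not on where. The names sum_dispatcher, relevant_override_types and Duck are illustrative only; this is not NumPy's implementation.

```python
# Simplified, hypothetical model of NEP 18 dispatch. NOT NumPy's real code:
# it only shows that overrides are looked up on the dispatcher's return value.


def sum_dispatcher(a, axis=None, dtype=None, out=None, keepdims=None,
                   initial=None, where=None):
    # Assumed to mirror np.sum's dispatcher: `where` is accepted but not returned.
    return (a, out)


def relevant_override_types(dispatcher, *args, **kwargs):
    """Return the types whose __array_function__ would be consulted."""
    types = []
    for arg in dispatcher(*args, **kwargs):
        arg_type = type(arg)
        if hasattr(arg_type, "__array_function__") and arg_type not in types:
            types.append(arg_type)
    return types


class Duck:
    def __array_function__(self, func, types, args, kwargs):
        return "Duck.__array_function__"


# Duck as the main input: its override is found.
print(Duck in relevant_override_types(sum_dispatcher, Duck()))  # True
# Duck only as `where`: it is never inspected, so no override runs.
print(Duck in relevant_override_types(sum_dispatcher, [1.0, 2.0], where=Duck()))  # False
```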

Reproduce the code example:

import dataclasses
import numpy as np

@dataclasses.dataclass
class DuckArray(np.lib.mixins.NDArrayOperatorsMixin):

    value: np.ndarray

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        print("in DuckArray.__array_ufunc__")
        inputs = [inp.value if isinstance(inp, DuckArray) else inp for inp in inputs]
        kwargs = {k: kwargs[k].value if isinstance(kwargs[k], DuckArray) else kwargs[k] for k in kwargs}
        return DuckArray(getattr(ufunc, method)(*inputs, **kwargs))

    def __array_function__(self, func, types, args, kwargs):
        print("in DuckArray.__array_function__")
        args = [arg.value if isinstance(arg, DuckArray) else arg for arg in args]
        kwargs = {k: kwargs[k].value if isinstance(kwargs[k], DuckArray) else kwargs[k] for k in kwargs}
        return DuckArray(func(*args, **kwargs))

def test_ufunc():
    a = np.random.random(5)
    b = DuckArray(a > 0.5)
    c = np.negative(a, where=b)     # doesn't print "in DuckArray.__array_ufunc__" :(
    print(type(c))                  # prints <class 'numpy.ndarray'>
    print(c)

def test_function():
    a = np.random.random(5)
    b = DuckArray(a > 0.5)
    c = np.sum(a, where=b)          # doesn't print "in DuckArray.__array_function__" :(
    print(type(c))                  # prints <class 'numpy.float64'>
    print(c)

if __name__ == "__main__":
    test_ufunc()
    test_function()
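
One possible caller-side workaround, sketched under the assumption that the exact input may be wrapped too: if the input itself is a DuckArray, NumPy dispatches on it, and the override then receives the DuckArray where in kwargs. The DuckArray below is a trimmed stand-in for the one above (ufunc path only), and _unwrap is a hypothetical helper. This sidesteps the problem rather than fixing it.

```python
import numpy as np


def _unwrap(x):
    # Hypothetical helper: strip the wrapper so NumPy sees a plain ndarray.
    return x.value if isinstance(x, DuckArray) else x


class DuckArray(np.lib.mixins.NDArrayOperatorsMixin):
    """Minimal stand-in for the DuckArray above (ufunc path only)."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        print("in DuckArray.__array_ufunc__")
        inputs = [_unwrap(x) for x in inputs]
        kwargs = {k: _unwrap(v) for k, v in kwargs.items()}
        return DuckArray(getattr(ufunc, method)(*inputs, **kwargs))


a = np.random.random(5)
b = DuckArray(a > 0.5)

# Wrapping the exact input makes NumPy dispatch on it, so the override runs
# and sees `where=b`. Entries where the mask is False are left uninitialized
# because no `out` array is given.
c = np.negative(DuckArray(a), where=b)
print(type(c).__name__)  # DuckArray
```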

Error message:

No error is raised; the overrides are simply not called.

Runtime information:

NumPy 1.24.1
Python 3.11.1 (tags/v3.11.1:a7a450f, Dec  6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
print(numpy.show_runtime())
WARNING: `threadpoolctl` not found in system! Install it by `pip install threadpoolctl`. Once installed, try `np.show_runtime` again for more detailed build information
[{'simd_extensions': {'baseline': ['SSE', 'SSE2', 'SSE3'],
                      'found': ['SSSE3',
                                'SSE41',
                                'POPCNT',
                                'SSE42',
                                'AVX',
                                'F16C',
                                'FMA3',
                                'AVX2'],
                      'not_found': ['AVX512F',
                                    'AVX512CD',
                                    'AVX512_SKX',
                                    'AVX512_CLX',
                                    'AVX512_CNL',
                                    'AVX512_ICL']}}]
None

Context for the issue:

I am working on an array duck type that propagates uncertainties, similar to astropy.uncertainty.Distribution. I would like to make sure that uncertainties propagate properly when the where keyword argument is uncertain but the inputs are exact.
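To make the goal concrete, here is a plain-NumPy sketch of the kind of result such a duck type would presumably need to produce, assuming uncertainty is represented by Monte Carlo samples along a leading axis. The threshold noise, sample count and variable names are made up for illustration. Each sample of the uncertain mask yields its own masked sum, so an exact input acquires a distribution through where alone.

```python
import numpy as np

rng = np.random.default_rng(0)

# Exact input values (no uncertainty).
a = rng.random(5)

# Hypothetical uncertain mask: each row is one Monte Carlo sample of `where`.
n_samples = 1000
threshold = 0.5 + 0.05 * rng.standard_normal((n_samples, 1))
where_samples = a > threshold  # shape (n_samples, 5)

# One masked sum per sample. `a` is broadcast explicitly because a
# reduction's `where` must broadcast to the operand's shape.
sums = np.sum(np.broadcast_to(a, where_samples.shape), axis=-1, where=where_samples)

print(sums.shape)              # (1000,)
print(sums.mean(), sums.std())  # spread caused only by the uncertain mask
```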
