Skip to content

Commit 764249a

Browse files
ogrisellesteve
authored andcommitted
Fix _safe_indexing with non integer arrays on array API inputs (#32840)
1 parent eca5e0a commit 764249a

3 files changed

Lines changed: 65 additions & 9 deletions

File tree

sklearn/utils/_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def move_to(*arrays, xp, device):
469469
`array` may contain `None` entries, these are left unchanged.
470470
471471
Sparse arrays are accepted (as pass through) if the reference namespace is
472-
Numpy, in which case they are returned unchanged. Otherwise a `TypeError`
472+
NumPy, in which case they are returned unchanged. Otherwise a `TypeError`
473473
is raised.
474474
475475
Parameters

sklearn/utils/_indexing.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,24 @@ def _array_indexing(array, key, key_dtype, axis):
3636
"""Index an array or scipy.sparse consistently across NumPy version."""
3737
xp, is_array_api, device_ = get_namespace_and_device(array)
3838
if is_array_api:
39-
key = move_to(key, xp=xp, device=device_)
40-
return xp.take(array, key, axis=axis)
39+
if hasattr(key, "shape"):
40+
key = move_to(key, xp=xp, device=device_)
41+
elif isinstance(key, (int, slice)):
42+
# Passthrough for valid __getitem__ inputs as noted in the array
43+
# API spec.
44+
pass
45+
else:
46+
key = xp.asarray(key, device=device_)
47+
48+
if hasattr(key, "dtype"):
49+
if xp.isdtype(key.dtype, "integral"):
50+
return xp.take(array, key, axis=axis)
51+
elif xp.isdtype(key.dtype, "bool"):
52+
# Array API does not support boolean indexing for n-dim arrays
53+
# yet hence the need to turn to equivalent integer indexing.
54+
indices = xp.arange(array.shape[axis], device=device_)
55+
return xp.take(array, indices[key], axis=axis)
56+
4157
if issparse(array) and key_dtype == "bool":
4258
key = np.asarray(key)
4359
if isinstance(key, tuple):

sklearn/utils/tests/test_indexing.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from sklearn.externals._packaging.version import parse as parse_version
1111
from sklearn.utils import _safe_indexing, resample, shuffle
1212
from sklearn.utils._array_api import (
13+
_convert_to_numpy,
1314
_get_namespace_device_dtype_ids,
15+
device,
16+
move_to,
1417
yield_namespace_device_dtype_combinations,
1518
)
1619
from sklearn.utils._indexing import (
@@ -22,6 +25,7 @@
2225
from sklearn.utils._testing import (
2326
_array_api_for_tests,
2427
_convert_container,
28+
assert_allclose,
2529
assert_allclose_dense_sparse,
2630
assert_array_equal,
2731
skip_if_array_api_compat_not_configured,
@@ -108,22 +112,22 @@ def test_determine_key_type_slice_error():
108112

109113
@skip_if_array_api_compat_not_configured
110114
@pytest.mark.parametrize(
111-
"array_namespace, device, dtype_name",
115+
"array_namespace, device_, dtype_name",
112116
yield_namespace_device_dtype_combinations(),
113117
ids=_get_namespace_device_dtype_ids,
114118
)
115-
def test_determine_key_type_array_api(array_namespace, device, dtype_name):
116-
xp = _array_api_for_tests(array_namespace, device)
119+
def test_determine_key_type_array_api(array_namespace, device_, dtype_name):
120+
xp = _array_api_for_tests(array_namespace, device_)
117121

118122
with sklearn.config_context(array_api_dispatch=True):
119-
int_array_key = xp.asarray([1, 2, 3])
123+
int_array_key = xp.asarray([1, 2, 3], device=device_)
120124
assert _determine_key_type(int_array_key) == "int"
121125

122-
bool_array_key = xp.asarray([True, False, True])
126+
bool_array_key = xp.asarray([True, False, True], device=device_)
123127
assert _determine_key_type(bool_array_key) == "bool"
124128

125129
try:
126-
complex_array_key = xp.asarray([1 + 1j, 2 + 2j, 3 + 3j])
130+
complex_array_key = xp.asarray([1 + 1j, 2 + 2j, 3 + 3j], device=device_)
127131
except TypeError:
128132
# Complex numbers are not supported by all Array API libraries.
129133
complex_array_key = None
@@ -133,6 +137,42 @@ def test_determine_key_type_array_api(array_namespace, device, dtype_name):
133137
_determine_key_type(complex_array_key)
134138

135139

140+
@skip_if_array_api_compat_not_configured
141+
@pytest.mark.parametrize(
142+
"array_namespace, device_, dtype_name",
143+
yield_namespace_device_dtype_combinations(),
144+
ids=_get_namespace_device_dtype_ids,
145+
)
146+
@pytest.mark.parametrize(
147+
"indexing_key",
148+
(
149+
0,
150+
-1,
151+
[1, 3],
152+
np.array([1, 3]),
153+
slice(1, 2),
154+
[True, False, True, True],
155+
np.asarray([False, False, False, False]),
156+
),
157+
)
158+
@pytest.mark.parametrize("axis", [0, 1])
159+
def test_safe_indexing_array_api_support(
160+
array_namespace, device_, dtype_name, indexing_key, axis
161+
):
162+
xp = _array_api_for_tests(array_namespace, device_)
163+
164+
array_to_index_np = np.arange(16).reshape(4, 4)
165+
expected_result = _safe_indexing(array_to_index_np, indexing_key, axis=axis)
166+
array_to_index_xp = move_to(array_to_index_np, xp=xp, device=device_)
167+
168+
with sklearn.config_context(array_api_dispatch=True):
169+
indexed_array_xp = _safe_indexing(array_to_index_xp, indexing_key, axis=axis)
170+
assert device(indexed_array_xp) == device(array_to_index_xp)
171+
assert indexed_array_xp.dtype == array_to_index_xp.dtype
172+
173+
assert_allclose(_convert_to_numpy(indexed_array_xp, xp=xp), expected_result)
174+
175+
136176
@pytest.mark.parametrize(
137177
"array_type", ["list", "array", "sparse", "dataframe", "polars", "pyarrow"]
138178
)

0 commit comments

Comments
 (0)