1010from sklearn .externals ._packaging .version import parse as parse_version
1111from sklearn .utils import _safe_indexing , resample , shuffle
1212from sklearn .utils ._array_api import (
13+ _convert_to_numpy ,
1314 _get_namespace_device_dtype_ids ,
15+ device ,
16+ move_to ,
1417 yield_namespace_device_dtype_combinations ,
1518)
1619from sklearn .utils ._indexing import (
2225from sklearn .utils ._testing import (
2326 _array_api_for_tests ,
2427 _convert_container ,
28+ assert_allclose ,
2529 assert_allclose_dense_sparse ,
2630 assert_array_equal ,
2731 skip_if_array_api_compat_not_configured ,
@@ -108,22 +112,22 @@ def test_determine_key_type_slice_error():
108112
109113@skip_if_array_api_compat_not_configured
110114@pytest .mark .parametrize (
111- "array_namespace, device , dtype_name" ,
115+ "array_namespace, device_ , dtype_name" ,
112116 yield_namespace_device_dtype_combinations (),
113117 ids = _get_namespace_device_dtype_ids ,
114118)
115- def test_determine_key_type_array_api (array_namespace , device , dtype_name ):
116- xp = _array_api_for_tests (array_namespace , device )
119+ def test_determine_key_type_array_api (array_namespace , device_ , dtype_name ):
120+ xp = _array_api_for_tests (array_namespace , device_ )
117121
118122 with sklearn .config_context (array_api_dispatch = True ):
119- int_array_key = xp .asarray ([1 , 2 , 3 ])
123+ int_array_key = xp .asarray ([1 , 2 , 3 ], device = device_ )
120124 assert _determine_key_type (int_array_key ) == "int"
121125
122- bool_array_key = xp .asarray ([True , False , True ])
126+ bool_array_key = xp .asarray ([True , False , True ], device = device_ )
123127 assert _determine_key_type (bool_array_key ) == "bool"
124128
125129 try :
126- complex_array_key = xp .asarray ([1 + 1j , 2 + 2j , 3 + 3j ])
130+ complex_array_key = xp .asarray ([1 + 1j , 2 + 2j , 3 + 3j ], device = device_ )
127131 except TypeError :
128132 # Complex numbers are not supported by all Array API libraries.
129133 complex_array_key = None
@@ -133,6 +137,42 @@ def test_determine_key_type_array_api(array_namespace, device, dtype_name):
133137 _determine_key_type (complex_array_key )
134138
135139
140+ @skip_if_array_api_compat_not_configured
141+ @pytest .mark .parametrize (
142+ "array_namespace, device_, dtype_name" ,
143+ yield_namespace_device_dtype_combinations (),
144+ ids = _get_namespace_device_dtype_ids ,
145+ )
146+ @pytest .mark .parametrize (
147+ "indexing_key" ,
148+ (
149+ 0 ,
150+ - 1 ,
151+ [1 , 3 ],
152+ np .array ([1 , 3 ]),
153+ slice (1 , 2 ),
154+ [True , False , True , True ],
155+ np .asarray ([False , False , False , False ]),
156+ ),
157+ )
158+ @pytest .mark .parametrize ("axis" , [0 , 1 ])
159+ def test_safe_indexing_array_api_support (
160+ array_namespace , device_ , dtype_name , indexing_key , axis
161+ ):
162+ xp = _array_api_for_tests (array_namespace , device_ )
163+
164+ array_to_index_np = np .arange (16 ).reshape (4 , 4 )
165+ expected_result = _safe_indexing (array_to_index_np , indexing_key , axis = axis )
166+ array_to_index_xp = move_to (array_to_index_np , xp = xp , device = device_ )
167+
168+ with sklearn .config_context (array_api_dispatch = True ):
169+ indexed_array_xp = _safe_indexing (array_to_index_xp , indexing_key , axis = axis )
170+ assert device (indexed_array_xp ) == device (array_to_index_xp )
171+ assert indexed_array_xp .dtype == array_to_index_xp .dtype
172+
173+ assert_allclose (_convert_to_numpy (indexed_array_xp , xp = xp ), expected_result )
174+
175+
136176@pytest .mark .parametrize (
137177 "array_type" , ["list" , "array" , "sparse" , "dataframe" , "polars" , "pyarrow" ]
138178)
0 commit comments