Description
Describe the bug
In parameter validation there are many places where we use ["array-like", "sparse matrix"], so I think "array-like" should not be a superset of "sparse matrix", but currently it is. Looking at the class _ArrayLikes, it treats the input as valid as long as it has __len__, shape, or __array__ and is not a scalar. Both sparse matrices and sparse arrays clearly satisfy this condition, though I think they should be excluded. I propose adding the constraint not sp.issparse(array) to "array-like".
For more context please see #27950 which tries to extend parameter validation to the new sparse arrays.
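To make the proposal concrete, here is a minimal standalone sketch of the check described above with the extra sparse exclusion. The helper name is_dense_array_like is hypothetical and this is not the actual _ArrayLikes implementation; it only mirrors the condition as described in this issue.

```python
import numpy as np
import scipy.sparse as sp


def is_dense_array_like(val):
    """Hypothetical helper, not scikit-learn's code: the described check
    plus the proposed exclusion of scipy sparse containers."""
    # Condition described in this issue: the input has __len__, shape, or
    # __array__, and is not a scalar.
    looks_array_like = (
        hasattr(val, "__len__")
        or hasattr(val, "shape")
        or hasattr(val, "__array__")
    ) and not np.isscalar(val)
    # Proposed extra constraint: reject sparse matrices and sparse arrays.
    return looks_array_like and not sp.issparse(val)


print(is_dense_array_like(np.zeros((3, 4))))       # dense ndarray
print(is_dense_array_like(sp.csr_matrix((3, 4))))  # sparse matrix
print(is_dense_array_like(sp.csr_array((3, 4))))   # sparse array
```

With this constraint, sparse inputs would only be accepted where "sparse matrix" is listed explicitly next to "array-like".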
Also quoting the glossary page
Steps/Code to Reproduce
>>> from sklearn.utils._param_validation import validate_params
>>> @validate_params({"X": ["array-like"]}, prefer_skip_nested_validation=False)
... def func(X):
...     return X
...
>>> import scipy.sparse as sp
>>> func(sp.csr_array((3, 4)))
<3x4 sparse array of type '<class 'numpy.float64'>'
with 0 stored elements in Compressed Sparse Row format>
>>> func(sp.csr_matrix((3, 4)))
<3x4 sparse matrix of type '<class 'numpy.float64'>'
with 0 stored elements in Compressed Sparse Row format>
Another example is AgglomerativeClustering, where the validation for connectivity does not include "sparse matrix", but tests such as sklearn/cluster/tests/test_hierarchical.py::test_agglomerative_clustering pass even though connectivity is sparse.
Expected Results
Both should raise sklearn.utils._param_validation.InvalidParameterError: The 'X' parameter of func must be an array-like.
Actual Results
No error.
Versions
System:
python: 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:40:31) [MSC v.1929 64 bit (AMD64)]
executable: D:\Downloads\mambaforge\envs\sklearn-env\python.exe
machine: Windows-10-10.0.19045-SP0
Python dependencies:
sklearn: 1.5.dev0
pip: 23.2.1
setuptools: 68.2.2
numpy: 1.26.0
scipy: 1.11.2
Cython: 3.0.2
pandas: 2.1.4
matplotlib: 3.8.2
joblib: 1.3.2
threadpoolctl: 3.2.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: mkl
num_threads: 6
prefix: libblas
filepath: D:\Downloads\mambaforge\envs\sklearn-env\Library\bin\libblas.dll
version: 2022.1-Product
threading_layer: intel
user_api: openmp
internal_api: openmp
num_threads: 12
prefix: vcomp
filepath: D:\Downloads\mambaforge\envs\sklearn-env\vcomp140.dll
version: None