HalvingRandomSearchCV does not support param_distributions as a list #26885

@pwoller

Describe the bug

Before scikit-learn 1.3.0 (e.g. in 1.2.0), HalvingRandomSearchCV accepted a list[dict] as param_distributions, just as RandomizedSearchCV does.
The type annotation in the documentation lists only dict, but the parameter description also mentions the option of passing a list. As of version 1.3.0 the parameters are validated, and the check accepts only dict, so a list[dict] no longer works.
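
For readers unfamiliar with the list form: the RandomizedSearchCV documentation describes it as first sampling one dict uniformly, then sampling each parameter from that dict. The following is a rough pure-Python sketch of that idea only. `sample_candidates` is a hypothetical helper, not scikit-learn code, and it samples with replacement from plain lists for simplicity. The string values "tree" and "logreg" stand in for the estimator instances used in the reproduction.

```python
import random


def sample_candidates(param_distributions, n_iter, seed=0):
    """Illustrative only: draw n_iter candidates from a dict or a list of dicts."""
    rng = random.Random(seed)
    # Treat a single dict as a one-element list so both forms share one code path.
    if isinstance(param_distributions, dict):
        param_distributions = [param_distributions]
    candidates = []
    for _ in range(n_iter):
        # Step 1: pick one search space uniformly at random.
        space = rng.choice(param_distributions)
        # Step 2: pick one value per parameter from that space only.
        candidates.append(
            {name: rng.choice(values) for name, values in sorted(space.items())}
        )
    return candidates


params = [
    {"clf": ["tree"], "clf__criterion": ["entropy", "gini"]},
    {"clf": ["logreg"], "clf__C": [0.01, 0.1, 1, 10, 100]},
]
candidates = sample_candidates(params, n_iter=5)
```

Each candidate takes its keys from exactly one dict, which is why the list form is useful for searching over alternative pipeline steps whose hyperparameters do not overlap.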

Steps/Code to Reproduce

from sklearn.experimental import enable_halving_search_cv # noqa
from sklearn.model_selection import HalvingRandomSearchCV, RandomizedSearchCV

from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

pipe = Pipeline([("clf", None)])
params = [
    {'clf': (DecisionTreeClassifier(),), 'clf__criterion': ['entropy', 'gini']},
    {'clf': (LogisticRegression(),), 'clf__C': [0.01, 0.1, 1, 10, 100], 'clf__max_iter': [10_000]}
]

search = HalvingRandomSearchCV(
    pipe, param_distributions=params, scoring="accuracy")

search.fit(X, y)

Expected Results

No error is thrown; the search fits as it did in 1.2.0.

Actual Results

---------------------------------------------------------------------------
InvalidParameterError                     Traceback (most recent call last)
Cell In[22], line 20
     12 params = [
     13     {'clf': (DecisionTreeClassifier(),), 'clf__criterion': ['entropy', 'gini']},
     14     {'clf': (LogisticRegression(),), 'clf__C': [0.01, 0.1, 1, 10, 100], 'clf__max_iter': [10_000]}
     15 ]
     17 search = HalvingRandomSearchCV(
     18     pipe, param_distributions=params, scoring="accuracy")
---> 20 search.fit(X, y)

File ~\AppData\Local\miniconda3\envs\automl\lib\site-packages\sklearn\base.py:1144, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1139 partial_fit_and_fitted = (
   1140     fit_method.__name__ == "partial_fit" and _is_fitted(estimator)
   1141 )
   1143 if not global_skip_validation and not partial_fit_and_fitted:
-> 1144     estimator._validate_params()
   1146 with config_context(
   1147     skip_parameter_validation=(
   1148         prefer_skip_nested_validation or global_skip_validation
   1149     )
   1150 ):
   1151     return fit_method(estimator, *args, **kwargs)

File ~\AppData\Local\miniconda3\envs\automl\lib\site-packages\sklearn\base.py:637, in BaseEstimator._validate_params(self)
    629 def _validate_params(self):
    630     """Validate types and values of constructor parameters
    631 
    632     The expected type and values must be defined in the `_parameter_constraints`
   (...)
    635     accepted constraints.
    636     """
--> 637     validate_parameter_constraints(
    638         self._parameter_constraints,
    639         self.get_params(deep=False),
    640         caller_name=self.__class__.__name__,
    641     )

File ~\AppData\Local\miniconda3\envs\automl\lib\site-packages\sklearn\utils\_param_validation.py:95, in validate_parameter_constraints(parameter_constraints, params, caller_name)
     89 else:
     90     constraints_str = (
     91         f"{', '.join([str(c) for c in constraints[:-1]])} or"
     92         f" {constraints[-1]}"
     93     )
---> 95 raise InvalidParameterError(
     96     f"The {param_name!r} parameter of {caller_name} must be"
     97     f" {constraints_str}. Got {param_val!r} instead."
     98 )

InvalidParameterError: The 'param_distributions' parameter of HalvingRandomSearchCV must be an instance of 'dict'. Got [{'clf': (DecisionTreeClassifier(),), 'clf__criterion': ['entropy', 'gini']}, {'clf': (LogisticRegression(),), 'clf__C': [0.01, 0.1, 1, 10, 100], 'clf__max_iter': [10000]}] instead.
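
Until a list is accepted again, one possible workaround is to run a separate search per dict and keep the best one. This is only a sketch under a few assumptions: the estimators are searched directly rather than through the pipeline, `spaces`, `searches` and `best` are names chosen for illustration, and `random_state=0` is added for reproducibility. Comparing `best_score_` across independent halving runs is approximate, since each run schedules its own candidates and resources.

```python
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.model_selection import HalvingRandomSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

# One (estimator, dict) pair per entry of the original list.
spaces = [
    (DecisionTreeClassifier(), {"criterion": ["entropy", "gini"]}),
    (LogisticRegression(max_iter=10_000), {"C": [0.01, 0.1, 1, 10, 100]}),
]

searches = []
for estimator, space in spaces:
    # Each search receives a plain dict, which passes parameter validation.
    search = HalvingRandomSearchCV(
        estimator, param_distributions=space, scoring="accuracy", random_state=0
    )
    search.fit(X, y)
    searches.append(search)

# Keep the search whose best candidate scored highest.
best = max(searches, key=lambda s: s.best_score_)
print(best.best_estimator_, best.best_score_)
```

Unlike a single search over a list, this does not share one candidate budget between the dicts, so it is not a drop-in replacement for the 1.2.0 behaviour.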

Versions

1.3.0

Labels: Bug, High Priority