
MRG clone parameters in gridsearch etc #15096

Merged
adrinjalali merged 5 commits into scikit-learn:master from amueller:grid_search_extra_clone
Oct 29, 2019

Conversation

@amueller (Member)

Fixed #10063 without going through the pain of #8350.

I don't see a case where this could change behavior, as we immediately call fit after setting the steps, but maybe I'm missing something.

This not only removes possible user confusion (see #8350 for how hard the stored estimators are to interpret); it also potentially saves us a lot of memory (imagine grid-searching a neural net and storing the weights for every parameter setting).
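
For context, a minimal sketch (toy data; not code from this PR) of the effect being fixed: before this change, the estimator objects stored in cv_results_['params'] were the very instances that got fitted, so every candidate kept its fitted state alive; with it, they stay unfitted.

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

X, y = make_classification(random_state=0)
pipe = Pipeline([('reduce', PCA()), ('clf', LogisticRegression())])
grid = GridSearchCV(pipe, {'reduce': [PCA(n_components=5), SelectKBest(k=5)]}, cv=2)
grid.fit(X, y)

# With the parameters cloned, the objects stored in cv_results_['params']
# carry no fitted attributes (no PCA.components_, no SelectKBest.scores_).
for params in grid.cv_results_['params']:
    step = params['reduce']
    print(step, hasattr(step, 'components_'), hasattr(step, 'scores_'))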

@amueller added this to the 0.22 milestone Sep 26, 2019
# clone after setting parameters in case any parameters
# are estimators (like pipeline steps)
# because pipeline doesn't clone steps in fit
estimator = clone(estimator.set_params(**parameters))

Member:

There's a problem if the parameter is assumed to be a fitted estimator? But does it ever happen?

Member Author:

It could, in third-party code, if a meta-estimator doesn't call fit on an estimator.
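
As background for the "pipeline doesn't clone steps in fit" comment in the snippet above, here is a minimal sketch (not code from this PR) showing that Pipeline fits its steps in place, which is why an estimator passed as a search parameter would otherwise come back fitted:

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

X, y = make_classification(random_state=0)
pca = PCA(n_components=3)
pipe = Pipeline([('reduce', pca), ('clf', LogisticRegression())])
pipe.fit(X, y)

print(pipe.named_steps['reduce'] is pca)  # True: the step was fitted in place
print(hasattr(pca, 'components_'))        # True: fitted state leaked back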

@amueller (Member Author)

@jnothman @glemaitre might have thoughts?

# we clone again after setting params in case some
# of the params are estimators as well.
self.best_estimator_ = clone(clone(base_estimator).set_params(
    **self.best_params_))

Member:

Is the following good enough?

clone(base_estimator.set_params(**self.best_params_))

Member Author:

I don't think so?
We're changing base_estimator then, right?

Member:

The test_grid_search_pipeline_steps test passes without the double clone. Given

base_estimator = clone(self.estimator)

isn't it already a clone?

Member:

We could indeed clone only once, as @thomasjpfan suggested, since base_estimator is just a local variable that isn't used later.

I guess cloning twice is fine too: no surprises.

@NicolasHug (Member)

NicolasHug commented Sep 27, 2019

Cloning the estimator in _fit_and_score() will make the implementation of #8230 (gridsearch + warm start) impossible now :(

Maybe we could just clone the parameters?

@jnothman (Member) left a comment:

Nice hack! I'm not sure that this fixes #10063 which is about cv_results_, though... it only fixes the best_estimator_ part of that problem.

@amueller (Member Author)

amueller commented Oct 4, 2019

@jnothman no, it definitely fixes that one. It's tested pretty extensively.

@amueller (Member Author)

amueller commented Oct 4, 2019

@NicolasHug good point; now only the parameters are cloned. Note that if a parameter is not an estimator, this does a deep copy, so if the parameters are large arrays we keep copying them. On the other hand, if you pass a mutable structure you're asking for trouble, and copying it is probably a good idea.
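
A small illustration (toy values; not code from this PR) of the clone(v, safe=False) semantics described above: estimators get a fresh unfitted clone, while non-estimators fall back to a deep copy instead of raising:

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression

est = LogisticRegression()
arr = np.arange(5)

est_clone = clone(est, safe=False)  # estimator: fresh unfitted clone
arr_copy = clone(arr, safe=False)   # non-estimator: copied via copy.deepcopy

print(est_clone is est)                                # False
print(arr_copy is arr, np.array_equal(arr, arr_copy))  # False True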

@amueller (Member Author)

amueller commented Oct 4, 2019

@jnothman do you think the current solution is still a hack? Why?

@amueller changed the title from "MRG clone estimator again after setting parameters in gridsearch etc" to "MRG clone parameters in gridsearch etc" Oct 4, 2019

train_scores = {}
if parameters is not None:
    estimator.set_params(**parameters)
    # clone after setting parameters in case any parameters

Member:

I wonder if someone had code relying on the existing behaviour.

Add a test for this wrt cross_validate??

Member Author:

Possibly, but I'm not sure what to do about that.
What do you want tested?

Member:

No, parameters is not set by cross_validate. Could add a test for validation_curve. But I'm okay without.
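
A sketch of the kind of validation_curve test being floated here (hypothetical; not the test added in this PR), checking that estimator-valued parameters passed via param_range stay unfitted:

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import validation_curve
from sklearn.pipeline import Pipeline

X, y = make_classification(random_state=0)
pipe = Pipeline([('reduce', PCA()), ('clf', LogisticRegression())])
candidates = [PCA(n_components=5), SelectKBest(k=5)]
validation_curve(pipe, X, y, param_name='reduce',
                 param_range=candidates, cv=2)

# With parameters cloned in _fit_and_score, the originals stay unfitted.
assert not hasattr(candidates[0], 'components_')
assert not hasattr(candidates[1], 'scores_')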

@adrinjalali (Member)

I was curious about a more nested case, but the good news is that this also passes the tests:

#%%
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.datasets import fetch_openml
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
rng = np.random.RandomState(41)
X['random_cat'] = rng.randint(3, size=X.shape[0])
X['random_num'] = rng.randn(X.shape[0])

categorical_columns = ['pclass', 'sex', 'embarked', 'random_cat']
numerical_columns = ['age', 'sibsp', 'parch', 'fare', 'random_num']

X = X[categorical_columns + numerical_columns]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42)

categorical_pipe = Pipeline([
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])
numerical_pipe = Pipeline([
    ('imputer', SimpleImputer(strategy='mean')),
    ('selector', PCA())
])

preprocessing = ColumnTransformer(
    [('cat', categorical_pipe, categorical_columns),
     ('num', numerical_pipe, numerical_columns)])

pipe = Pipeline([
    ('preprocess', preprocessing),
    ('classifier', RandomForestClassifier(random_state=42))
])

# grid-search over a step nested inside a ColumnTransformer sub-pipeline
param_grid = {'preprocess__num__selector': [PCA(), SelectKBest(k=3)]}
grid_search = GridSearchCV(pipe, param_grid, cv=2)
grid_search.fit(X, y)

@NicolasHug (Member) left a comment:

LGTM!

Comment on lines +494 to +496
cloned_parameters = {}
for k, v in parameters.items():
    cloned_parameters[k] = clone(v, safe=False)

Member:

nit: dict comprehension?
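
The suggested comprehension would presumably read (a sketch; not necessarily the merged code):

cloned_parameters = {k: clone(v, safe=False) for k, v in parameters.items()}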

@jnothman (Member) left a comment:

LGTM besides the validation_curve comment, which may not be essential.

@jnothman (Member)

This needs a what's new entry.

@adrinjalali merged commit 3d606cf into scikit-learn:master Oct 29, 2019

@amueller (Member Author)

amueller commented Nov 1, 2019

Thanks folks!


Development

Successfully merging this pull request may close these issues.

GridSearchCV saves all fitted estimator in cv_results['params'] when params are estimators
