-
-
Notifications
You must be signed in to change notification settings - Fork 26.9k
Serialization error when using parallelism in cross_val_score with GridSearchCV and a custom estimator #12413
Copy link
Copy link
Closed
Description
Minimal example:
import numpy as np
from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.base import ClassifierMixin, BaseEstimator


class Dummy(ClassifierMixin, BaseEstimator):
    """Minimal custom classifier that always predicts the constant ``answer``."""

    def __init__(self, answer=1):
        self.answer = answer

    def fit(self, X, y=None):
        """Learn nothing and return ``self`` unchanged."""
        return self

    def predict(self, X):
        """Return ``answer`` repeated once per row of ``X`` as an int array."""
        return np.ones(X.shape[0], dtype='int') * self.answer


# Random data: 500 samples, 8 features, binary labels.
n_samples, n_features = 500, 8
X = np.random.randn(n_samples, n_features)
y = np.random.randint(0, 2, n_samples)

dummy = Dummy()
# The inner grid search is sequential (n_jobs=1); only the outer
# cross_val_score call below is run with n_jobs=5.
gcv = GridSearchCV(dummy, {'answer': [0, 1]}, cv=5, iid=False, n_jobs=1)
cross_val_score(gcv, X, y, cv=5, n_jobs=5)
# BrokenProcessPool: A task has failed to un-serialize.
# Please ensure that the arguments of the function are all picklable.
Full traceback in details.
Interestingly, it does not fail when:
- calling `cross_val_score` with `n_jobs=1`.
- calling `cross_val_score` directly on `dummy`, without `GridSearchCV`.
- using an imported classifier, such as `LogisticRegression`, or even the same `Dummy` custom classifier but imported from another file.
This is a joblib 0.12 issue, different from #12289 or #12389. @ogrisel @tomMoral
Details
Traceback (most recent call last):
File "/cal/homes/tdupre/work/src/joblib/joblib/externals/loky/process_executor.py", line 393, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/cal/homes/tdupre/miniconda3/envs/py36/lib/python3.6/multiprocessing/queues.py", line 113, in get
return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'Dummy' on <module 'sklearn.externals.joblib.externals.loky.backend.popen_loky_posix' from '/cal/homes/tdupre/work/src/scikit-learn/sklearn/externals/joblib/externals/loky/backend/popen_loky_posix.py'>
'''
The above exception was the direct cause of the following exception:
BrokenProcessPool Traceback (most recent call last)
~/work/src/script_csc/condition_effect/test.py in <module>()
32
33 # fails
---> 34 cross_val_score(gcv, X, y, cv=5, n_jobs=5)
35 """
36 BrokenProcessPool: A task has failed to un-serialize.
~/work/src/scikit-learn/sklearn/model_selection/_validation.py in cross_val_score(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch, error_score)
384 fit_params=fit_params,
385 pre_dispatch=pre_dispatch,
--> 386 error_score=error_score)
387 return cv_results['test_score']
388
~/work/src/scikit-learn/sklearn/model_selection/_validation.py in cross_validate(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch, return_train_score, return_estimator, error_score)
232 return_times=True, return_estimator=return_estimator,
233 error_score=error_score)
--> 234 for train, test in cv.split(X, y, groups))
235
236 zipped_scores = list(zip(*scores))
~/work/src/joblib/joblib/parallel.py in __call__(self, iterable)
996
997 with self._backend.retrieval_context():
--> 998 self.retrieve()
999 # Make sure that we get a last message telling us we are done
1000 elapsed_time = time.time() - self._start_time
~/work/src/joblib/joblib/parallel.py in retrieve(self)
899 try:
900 if getattr(self._backend, 'supports_timeout', False):
--> 901 self._output.extend(job.get(timeout=self.timeout))
902 else:
903 self._output.extend(job.get())
~/work/src/joblib/joblib/_parallel_backends.py in wrap_future_result(future, timeout)
519 AsyncResults.get from multiprocessing."""
520 try:
--> 521 return future.result(timeout=timeout)
522 except LokyTimeoutError:
523 raise TimeoutError()
~/miniconda3/envs/py36/lib/python3.6/concurrent/futures/_base.py in result(self, timeout)
403 raise CancelledError()
404 elif self._state == FINISHED:
--> 405 return self.__get_result()
406 else:
407 raise TimeoutError()
~/miniconda3/envs/py36/lib/python3.6/concurrent/futures/_base.py in __get_result(self)
355 def __get_result(self):
356 if self._exception:
--> 357 raise self._exception
358 else:
359 return self._result
BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
Reactions are currently unavailable