Serialization error when using parallelism in cross_val_score with GridSearchCV and a custom estimator #12413

@TomDLT

Description

Minimal example:

import numpy as np
from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.base import ClassifierMixin, BaseEstimator

class Dummy(ClassifierMixin, BaseEstimator):
    def __init__(self, answer=1):
        self.answer = answer

    def fit(self, X, y=None):
        return self

    def predict(self, X):
        return np.ones(X.shape[0], dtype='int') * self.answer

n_samples, n_features = 500, 8
X = np.random.randn(n_samples, n_features)
y = np.random.randint(0, 2, n_samples)

dummy = Dummy()
gcv = GridSearchCV(dummy, {'answer': [0, 1]}, cv=5, iid=False, n_jobs=1)
cross_val_score(gcv, X, y, cv=5, n_jobs=5)

# BrokenProcessPool: A task has failed to un-serialize.
# Please ensure that the arguments of the function are all picklable.

Full traceback in details.

Interestingly, it does not fail when:

  • calling cross_val_score with n_jobs=1.
  • calling cross_val_score directly on dummy, without GridSearchCV.
  • using an imported classifier, such as LogisticRegression, or even the same custom Dummy classifier imported from another file.
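
The last case suggests a workaround: keep the custom estimator in its own importable module instead of in the script that is run. Below is a minimal, standard-library-only sketch of that idea. The module name `dummy_estimator` is hypothetical, the scikit-learn base classes are left out to keep it self-contained, and it only checks that the class is then pickled by reference to that module. It does not run `cross_val_score` itself.

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

# Hypothetical module name; any importable module other than __main__ would do.
MODULE_NAME = "dummy_estimator"

# Stand-in for the custom classifier (scikit-learn base classes omitted here).
MODULE_SOURCE = '''
class Dummy:
    def __init__(self, answer=1):
        self.answer = answer
'''

# Write the class to its own file and make that directory importable.
module_dir = tempfile.mkdtemp()
Path(module_dir, MODULE_NAME + ".py").write_text(MODULE_SOURCE)
sys.path.insert(0, module_dir)
importlib.invalidate_caches()

dummy_estimator = importlib.import_module(MODULE_NAME)
Dummy = dummy_estimator.Dummy

# The class now belongs to an importable module rather than to __main__.
print(Dummy.__module__)

# Round trip through pickle; the class is looked up in that module on load.
restored = pickle.loads(pickle.dumps(Dummy(answer=0)))
print(type(restored).__module__, restored.answer)
```

In a real project this would simply be a file next to the script, with `from dummy_estimator import Dummy` at the top, assuming the worker processes can import that file too.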

This is a joblib 0.12 issue, different from #12289 or #12389. @ogrisel @tomMoral

Details
Traceback (most recent call last):
  File "/cal/homes/tdupre/work/src/joblib/joblib/externals/loky/process_executor.py", line 393, in _process_worker
    call_item = call_queue.get(block=True, timeout=timeout)
  File "/cal/homes/tdupre/miniconda3/envs/py36/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'Dummy' on <module 'sklearn.externals.joblib.externals.loky.backend.popen_loky_posix' from '/cal/homes/tdupre/work/src/scikit-learn/sklearn/externals/joblib/externals/loky/backend/popen_loky_posix.py'>
'''

The above exception was the direct cause of the following exception:

BrokenProcessPool                         Traceback (most recent call last)
~/work/src/script_csc/condition_effect/test.py in <module>()
     32 
     33     # fails
---> 34     cross_val_score(gcv, X, y, cv=5, n_jobs=5)
     35     """
     36     BrokenProcessPool: A task has failed to un-serialize.

~/work/src/scikit-learn/sklearn/model_selection/_validation.py in cross_val_score(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch, error_score)
    384                                 fit_params=fit_params,
    385                                 pre_dispatch=pre_dispatch,
--> 386                                 error_score=error_score)
    387     return cv_results['test_score']
    388 

~/work/src/scikit-learn/sklearn/model_selection/_validation.py in cross_validate(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch, return_train_score, return_estimator, error_score)
    232             return_times=True, return_estimator=return_estimator,
    233             error_score=error_score)
--> 234         for train, test in cv.split(X, y, groups))
    235 
    236     zipped_scores = list(zip(*scores))

~/work/src/joblib/joblib/parallel.py in __call__(self, iterable)
    996 
    997             with self._backend.retrieval_context():
--> 998                 self.retrieve()
    999             # Make sure that we get a last message telling us we are done
   1000             elapsed_time = time.time() - self._start_time

~/work/src/joblib/joblib/parallel.py in retrieve(self)
    899             try:
    900                 if getattr(self._backend, 'supports_timeout', False):
--> 901                     self._output.extend(job.get(timeout=self.timeout))
    902                 else:
    903                     self._output.extend(job.get())

~/work/src/joblib/joblib/_parallel_backends.py in wrap_future_result(future, timeout)
    519         AsyncResults.get from multiprocessing."""
    520         try:
--> 521             return future.result(timeout=timeout)
    522         except LokyTimeoutError:
    523             raise TimeoutError()

~/miniconda3/envs/py36/lib/python3.6/concurrent/futures/_base.py in result(self, timeout)
    403                 raise CancelledError()
    404             elif self._state == FINISHED:
--> 405                 return self.__get_result()
    406             else:
    407                 raise TimeoutError()

~/miniconda3/envs/py36/lib/python3.6/concurrent/futures/_base.py in __get_result(self)
    355     def __get_result(self):
    356         if self._exception:
--> 357             raise self._exception
    358         else:
    359             return self._result

BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
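
The AttributeError at the top of the traceback looks like what plain pickle raises when a class is resolved by name in a module that does not define it. The sketch below reproduces that kind of error with the standard library only, as a possible illustration of the mechanism and not a confirmed diagnosis of the joblib behaviour. The module name `user_script` is a hypothetical stand-in for the script's `__main__`, and swapping the entry in `sys.modules` is just a way to imitate a worker process whose module of that name has no `Dummy`.

```python
import pickle
import sys
import types

CLASS_SOURCE = '''
class Dummy:
    def __init__(self, answer=1):
        self.answer = answer
'''

# Hypothetical stand-in for the user's script module.
script = types.ModuleType("user_script")
exec(CLASS_SOURCE, script.__dict__)
sys.modules["user_script"] = script

# pickle stores the class by reference: module name plus class name.
payload = pickle.dumps(script.Dummy(answer=0))

# In the defining process the lookup succeeds.
restored_locally = pickle.loads(payload)

# Imitate a worker: a module with the same name exists but lacks Dummy.
sys.modules["user_script"] = types.ModuleType("user_script")
error_message = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    error_message = str(exc)
finally:
    # Put the original module back so later code is unaffected.
    sys.modules["user_script"] = script

print(error_message)
```

If this is indeed the mechanism, it would be consistent with the observation that the same class imported from another file does not trigger the failure.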
