Hello all,
Description
I noticed that BaseEstimator does not support the use of __slots__: you cannot round-trip objects through pickle if they 1) inherit from BaseEstimator and 2) use __slots__ rather than __dict__.
That said, I am not totally certain that I am actually using BaseEstimator correctly, but I thought I'd report it anyway. It may also be outside of the scope of the project...
Steps/Code to Reproduce
Run the following in IPython:

```python
import pickle
import tempfile

from sklearn.base import BaseEstimator


# Define a trivial estimator with only one parameter
# that inherits from BaseEstimator and uses __slots__.
class SlottedEstimator(BaseEstimator):
    __slots__ = ('parameter',)

    def __init__(self, parameter=None):
        if parameter is None:
            parameter = 1
        self.parameter = parameter


# Try round-tripping the new estimator through pickle.
est = SlottedEstimator(parameter=4)
with tempfile.NamedTemporaryFile() as tmp_f:
    pickle.dump(est, tmp_f)
    tmp_f.flush()
    # Re-open by name so the file handle is closed after loading.
    with open(tmp_f.name, 'rb') as f:
        est2 = pickle.load(f)

print(est.parameter)
print(est2.parameter)
```

Expected Results

```
4
4
```
Actual Results
```
4
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-b635b31a76c8> in <module>()
     20
     21 print(est.parameter)
---> 22 print(est2.parameter)

AttributeError: parameter
```
You can verify that this is due to the use of __slots__ by commenting out the __slots__ line and running the snippet again: the round trip then prints 4 twice.
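A minimal way to see why the __slots__ line matters, without sklearn installed. DictOnlyBase and Slotted are hypothetical names, and the sketch assumes the base class pickles only the instance __dict__ (the behavior the __dict__.update(state) fallback quoted under Comments suggests); it is a simplified stand-in, not sklearn's actual code.

```python
import pickle


class DictOnlyBase:
    """Hypothetical base class whose pickling hooks only look at __dict__."""

    def __getstate__(self):
        # Only the instance __dict__ is copied; slot values are not included.
        return self.__dict__.copy()

    def __setstate__(self, state):
        self.__dict__.update(state)


class Slotted(DictOnlyBase):
    __slots__ = ('parameter',)

    def __init__(self, parameter=1):
        self.parameter = parameter


est = Slotted(parameter=4)
# The base class declares no __slots__, so instances still get a __dict__,
# but 'parameter' is stored in the slot, not in that __dict__.
print(est.__dict__)        # {}
print(est.__getstate__())  # {}

est2 = pickle.loads(pickle.dumps(est))
# The empty state restores nothing, so the slot is never set.
print(hasattr(est2, 'parameter'))  # False
```

Because the subclass instance has both an (empty) __dict__ and a filled slot, any state built from __dict__ alone is already empty before __setstate__ ever runs.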
Versions
Darwin-16.7.0-x86_64-i386-64bit
Python 3.5.2 (default, Oct 11 2016, 04:59:56)
[GCC 4.2.1 Compatible Apple LLVM 8.0.0 (clang-800.0.38)]
NumPy 1.12.1
SciPy 0.19.0
Scikit-Learn 0.19.1
Comments
I think the problem is in BaseEstimator.__setstate__, which for some reason ends up receiving an empty dictionary somewhere in the following block:
```python
try:
    super(BaseEstimator, self).__setstate__(state)
except AttributeError:
    self.__dict__.update(state)
```

If this is within the scope of the project and all that is needed is a straightforward check for __slots__ instead of __dict__, I'm happy to do it myself and propose a PR.
Cheers, and thanks for the awesome library!