BaseEstimator does not support __slots__ #10079

@justinrporter

Hello all,

Description

I noticed that BaseEstimator does not support the use of __slots__: objects that 1) inherit from BaseEstimator and 2) use __slots__ rather than __dict__ cannot be round-tripped through pickle.

That said, I am not totally certain that I am actually using BaseEstimator correctly, but I thought I'd report it anyway. It may also be outside of the scope of the project...

Steps/Code to Reproduce

Run the following in ipython:

import pickle
import tempfile
from sklearn.base import BaseEstimator

# define a trivial estimator with only one
# parameter that inherits from BaseEstimator
class SlottedEstimator(BaseEstimator):
    __slots__ = ('parameter',)
    def __init__(self, parameter=None):
        if parameter is None:
            parameter = 1
        self.parameter = parameter

# try round-tripping the new estimator with pickle
est = SlottedEstimator(parameter=4)
with tempfile.NamedTemporaryFile() as tmp_f:
    pickle.dump(est, tmp_f)
    tmp_f.seek(0)  # rewind and read from the same handle instead of reopening by name
    est2 = pickle.load(tmp_f)

print(est.parameter)
print(est2.parameter)

Expected Results

4
4

Actual Results

4
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-b635b31a76c8> in <module>()
     20
     21 print(est.parameter)
---> 22 print(est2.parameter)

AttributeError: parameter

You can verify that this is due to the use of __slots__ by simply commenting out that line and running it again.

Versions

Darwin-16.7.0-x86_64-i386-64bit
Python 3.5.2 (default, Oct 11 2016, 04:59:56)
[GCC 4.2.1 Compatible Apple LLVM 8.0.0 (clang-800.0.38)]
NumPy 1.12.1
SciPy 0.19.0
Scikit-Learn 0.19.1

Comments

I think that the problem is in BaseEstimator.__setstate__, which for some reason ends up receiving an empty dictionary in the following block:

        try:
            super(BaseEstimator, self).__setstate__(state)
        except AttributeError:
            self.__dict__.update(state)
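
One possible explanation, sketched here without scikit-learn: when a base class does not declare __slots__ itself, instances of a slotted subclass still get a __dict__, and the slot value is stored outside of it. Any state that is built from __dict__ alone would then be empty. PlainBase and Slotted below are hypothetical stand-ins that I made up for illustration; they are not scikit-learn classes.

```python
class PlainBase:
    """Hypothetical stand-in for a base class that does not declare __slots__."""


class Slotted(PlainBase):
    __slots__ = ('parameter',)

    def __init__(self, parameter=None):
        self.parameter = parameter


obj = Slotted(parameter=4)

# The base class has no __slots__, so the instance still gets a __dict__ ...
has_dict = hasattr(obj, '__dict__')
# ... but the slot value lives in a slot descriptor, not in that __dict__.
dict_state = dict(obj.__dict__)

print(has_dict)       # True
print(dict_state)     # {}
print(obj.parameter)  # 4
```

If the same thing happens with BaseEstimator, the empty dictionary may already be produced when the object is pickled rather than when it is loaded, but I have not confirmed that.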

If this is inside the scope of the project, and all that is needed is a straightforward check for __slots__ in addition to __dict__, I'm happy to do it myself and propose a PR.
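
To make the idea concrete, here is a rough sketch of what such a check could look like. It is not scikit-learn code: DictOnlyBase is a hypothetical stand-in whose pickling hooks only look at __dict__ (my assumption about the behaviour, not a copy of BaseEstimator), and SlotAwareMixin is a hypothetical helper that also collects slot values.

```python
import pickle


class DictOnlyBase:
    """Hypothetical stand-in: pickling hooks that only consider __dict__."""

    def __getstate__(self):
        # Assumption: state is taken from __dict__ alone, so slots are skipped.
        return self.__dict__.copy()

    def __setstate__(self, state):
        self.__dict__.update(state)


class SlotAwareMixin:
    """Hypothetical mixin that includes slot values in the pickled state."""
    __slots__ = ()

    def _slot_names(self):
        # Collect the slot names declared anywhere in the class hierarchy.
        names = []
        for klass in type(self).__mro__:
            slots = klass.__dict__.get('__slots__', ())
            if isinstance(slots, str):
                slots = (slots,)
            names.extend(s for s in slots
                         if s not in ('__dict__', '__weakref__'))
        return names

    def __getstate__(self):
        # Start from __dict__ (if any), then add every slot that is set.
        state = dict(getattr(self, '__dict__', {}))
        for name in self._slot_names():
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state

    def __setstate__(self, state):
        # setattr works for both slot attributes and __dict__ attributes.
        for name, value in state.items():
            setattr(self, name, value)


class LossyEstimator(DictOnlyBase):
    __slots__ = ('parameter',)

    def __init__(self, parameter=None):
        self.parameter = parameter


class SlottedEstimator(SlotAwareMixin, DictOnlyBase):
    __slots__ = ('parameter',)

    def __init__(self, parameter=None):
        self.parameter = parameter


def round_trip(obj):
    """Pickle and unpickle obj in memory."""
    return pickle.loads(pickle.dumps(obj))
```

With these definitions, round_trip(LossyEstimator(parameter=4)) comes back without a parameter attribute, while round_trip(SlottedEstimator(parameter=4)).parameter is 4. The sketch ignores name-mangled private slots (names starting with two underscores), which a real patch would need to handle.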

Cheers, and thanks for the awesome library!
