Dynamic typing for number of trees in RandomForestClassifier #7146

@fabianegli

Description

I think it would not hurt to allow n_estimators to be passed to RandomForestClassifier as a string. If the input is neither an int nor convertible with int() (that is, int() raises a ValueError), there should be a more informative error message, in the style of "TypeError: n_estimators has to be an integer." Currently a string value fails inside fit with an unhelpful comparison error:

>>> from sklearn.ensemble import RandomForestClassifier
>>> forest = RandomForestClassifier(n_estimators=10, oob_score=True, n_jobs=1) 
>>> forest.fit([[1,2,3],[1,4,5],[0,3,2],[5,3,5]],[0,1,0,1]).predict([[2,3,5]])
array([1])
>>> forest = RandomForestClassifier(n_estimators='10', oob_score=True, n_jobs=1)
>>> forest.fit([[1,2,3],[1,4,5],[0,3,2],[5,3,5]],[0,1,0,1]).predict([2,3,5])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/sklearn/ensemble/forest.py", line 247, in fit
    self._validate_estimator()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/sklearn/ensemble/base.py", line 58, in _validate_estimator
    if self.n_estimators <= 0:
TypeError: unorderable types: str() <= int()
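To make the request concrete, here is a minimal sketch of the kind of check I have in mind. The helper name `coerce_n_estimators` is hypothetical and not part of scikit-learn; the positivity check simply mirrors the `self.n_estimators <= 0` comparison visible in the traceback, and rejecting booleans and non-integral floats is my own assumption about what "has to be an integer" should mean.

```python
def coerce_n_estimators(value):
    """Hypothetical helper (not scikit-learn API): return a positive int.

    Accepts ints and strings such as '10'; raises a descriptive error
    for anything that cannot be interpreted as an integer.
    """
    message = "n_estimators has to be an integer."

    # bool is a subclass of int, so True would otherwise pass as 1.
    if isinstance(value, bool):
        raise TypeError(message)

    try:
        n = int(value)
    except (TypeError, ValueError, OverflowError):
        # e.g. int('ten'), int(None), int(float('inf'))
        raise TypeError(message) from None

    # int(10.7) silently truncates to 10, so reject non-integral floats.
    if isinstance(value, float) and value != n:
        raise TypeError(message)

    # Same condition as the comparison that fails in the traceback above.
    if n <= 0:
        raise ValueError(
            "n_estimators must be greater than zero, got {0}.".format(n)
        )
    return n
```

With a check like this, `coerce_n_estimators('10')` returns the integer 10, while `coerce_n_estimators('ten')` raises the TypeError proposed above instead of the "unorderable types" message.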

Labels: Easy (well-defined and straightforward way to resolve)
