
Add blockwise ensemble meta-estimators#657

Merged
TomAugspurger merged 18 commits into dask:master from TomAugspurger:blockwise-meta
May 6, 2020
Conversation

@TomAugspurger
Member

This adds BlockwiseVotingClassifier and BlockwiseVotingRegressor, which are meta-estimators for

  1. Blockwise training of the sub-estimator.
  2. Ensemble prediction on new data.

Given an input array split into k partitions, k clones of the sub-estimator are fit independently, one per partition. This is efficient since we don't need to move any data (assuming the partitions of X and y are co-located).

At prediction time we combine the predictions from each of the k fitted models. For classification we take the class with either the most votes (voting="hard") or the highest total probability (voting="soft"). See https://scikit-learn.org/stable/modules/ensemble.html#voting-classifier. For regression we take the average.

```python
In [1]: import sklearn.linear_model
   ...: import dask_ml.datasets
   ...: import dask_ml.ensemble
   ...:
   ...: X, y = dask_ml.datasets.make_classification(n_features=20, chunks=25)
   ...:
   ...: clf = dask_ml.ensemble.BlockwiseVotingClassifier(
   ...:     sklearn.linear_model.LogisticRegression(), voting="soft",
   ...:     classes=[0, 1]
   ...: )
   ...:
   ...: clf.fit(X, y)

In [2]: clf.estimators_
Out[2]:
[LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                    intercept_scaling=1, l1_ratio=None, max_iter=100,
                    multi_class='auto', n_jobs=None, penalty='l2',
                    random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                    warm_start=False),
 LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                    intercept_scaling=1, l1_ratio=None, max_iter=100,
                    multi_class='auto', n_jobs=None, penalty='l2',
                    random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                    warm_start=False),
 LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                    intercept_scaling=1, l1_ratio=None, max_iter=100,
                    multi_class='auto', n_jobs=None, penalty='l2',
                    random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                    warm_start=False),
 LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                    intercept_scaling=1, l1_ratio=None, max_iter=100,
                    multi_class='auto', n_jobs=None, penalty='l2',
                    random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                    warm_start=False)]
```
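The combination rules described above can be sketched in plain NumPy. The helper names here (`combine_hard`, `combine_soft`) are hypothetical, for illustration only, not dask-ml's API:

```python
import numpy as np

def combine_hard(preds, classes):
    # voting="hard": each fitted estimator casts one vote per sample;
    # preds has shape (n_estimators, n_samples) of predicted labels.
    votes = np.stack([(preds == c).sum(axis=0) for c in classes])
    return np.asarray(classes)[votes.argmax(axis=0)]

def combine_soft(probs, classes):
    # voting="soft": average the predicted probabilities across estimators,
    # then take the most probable class; probs has shape
    # (n_estimators, n_samples, n_classes).
    return np.asarray(classes)[probs.mean(axis=0).argmax(axis=1)]

# Three estimators, three samples: majority vote per column.
preds = np.array([[0, 1, 1],
                  [0, 1, 0],
                  [1, 1, 0]])
print(combine_hard(preds, classes=[0, 1]))  # [0 1 0]
```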

@TomAugspurger
Member Author

cc @js3711 from #135 and @nbren12.

One thing I'm not sure about is what the output of .transform should be. In scikit-learn's VotingClassifier, .transform returns the predictions of each sub-estimator; we can add that later if we want to match the behavior.

Comment on lines +53 to +57

```python
results = [
    X.map_blocks(_predict, dtype=dtype, estimator=estimator, drop_axis=1)
    for estimator in self.estimators_
]
combined = da.vstack(results).T.rechunk({1: -1})
```
Member Author

This is kinda slow for many estimators (many partitions): we end up with X.npartitions * len(self.estimators_) tasks. Looking into doing a "batch" predict where each task computes the predictions from every estimator on that partition, stacked into a single ndarray.
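The "batch" predict idea might look roughly like the sketch below, which uses a toy stand-in estimator rather than dask-ml's actual code. With one task per partition returning all estimators' predictions at once, the task count drops from `npartitions * n_estimators` to `npartitions`:

```python
import numpy as np
import dask.array as da

class MeanThreshold:
    # Toy stand-in for a fitted sub-estimator (hypothetical, for illustration).
    def __init__(self, t):
        self.t = t

    def predict(self, X):
        return (X.sum(axis=1) > self.t).astype(np.int64)

def _predict_stack(part, estimators):
    # One task per block: stack every estimator's predictions column-wise,
    # giving shape (n_samples_in_block, n_estimators).
    return np.stack([est.predict(part) for est in estimators], axis=1)

estimators = [MeanThreshold(t) for t in (-1.0, 0.0, 1.0)]
X = da.random.random((100, 4), chunks=(25, 4))
combined = X.map_blocks(
    _predict_stack, estimators=estimators,
    dtype=np.int64, chunks=(25, len(estimators)),
)
print(combined.shape, combined.npartitions)  # (100, 3) 4
```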

Member Author


Fixed for arrays. For dask DataFrames, map_partitions didn't especially like returning a 3D array. Looking into it a bit, but it's not a blocker right now.

@TomAugspurger TomAugspurger merged commit b812fe5 into dask:master May 6, 2020
@TomAugspurger TomAugspurger deleted the blockwise-meta branch May 6, 2020 11:36
