
[MRG+1] fix #6101 GradientBoosting decision_function for sparse inputs#6116

Merged
jnothman merged 12 commits into scikit-learn:master from olologin:GradientBoostingFix
Oct 15, 2016

Conversation

@olologin
Contributor

@olologin olologin commented Jan 5, 2016

Fix for issue #6101
Please make suggestions.

@olologin olologin changed the title GradientBoosting decision_function GradientBoosting decision_function for sparse inputs Jan 5, 2016
@olologin olologin changed the title GradientBoosting decision_function for sparse inputs [MRG] fix #6101 GradientBoosting decision_function for sparse inputs Jan 8, 2016
@aflaxman
Contributor

This makes my example #6101 work. Thanks!

@jmschrei
Member

This looks mostly good to me. You should squash the commits as well. @glouppe can you take a look?

@amueller
Member

Have you checked the prediction speed for single samples? There is a benchmark in the benchmarks folder, I think.

@olologin
Contributor Author

@amueller Hmm, it shouldn't slow anything down, because this PR only adds prediction functionality for sparse matrices. Also, can you point me at that benchmark? I can't find anything related to GradientBoosting in the benchmarks folder.
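For context, here is a minimal sketch of the behaviour this PR enables (toy data and parameter values are my own, not from the PR): prediction on a sparse matrix, with results matching the dense path.

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.RandomState(0)
X = rng.rand(100, 20)
X[X < 0.8] = 0.0          # make ~80% of the entries zero
y = rng.randint(0, 2, 100)

clf = GradientBoostingClassifier(n_estimators=10, random_state=0)
clf.fit(X, y)

# With this fix, predict (and decision_function) accept sparse input
# directly instead of raising on it.
dense_pred = clf.predict(X)
sparse_pred = clf.predict(csr_matrix(X))
```

Since the sparse matrix holds the same values, both calls should traverse the same tree paths and return identical predictions.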

@olologin
Contributor Author

olologin commented Apr 24, 2016

Sorry for the late response, could someone review it? @amueller, @glouppe

On the same dataset, dense prediction takes ~958 ms and sparse ~1.2 s:

20 newsgroups
=============
X_train.shape = (11314, 130107)
X_train density = 0.001214353154362896
y_train (11314,)
X_test (3500, 130107)
X_test.format = csr
X_test.dtype = float32
y_test (3500,)

Classifier Training
===================
Training GradientBoostingClassifier_100_trees ...

1 loop, best of 3: 958 ms per loop
1 loop, best of 3: 1.2 s per loop

I made this benchmark based on bench_20_newsgroups.py:

from __future__ import print_function, division
import numpy as np

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.utils.validation import check_array

from sklearn.ensemble import GradientBoostingClassifier

data_train = fetch_20newsgroups_vectorized(subset="train")
data_test = fetch_20newsgroups_vectorized(subset="test")
X_train_sp = check_array(data_train.data, dtype=np.float32,
                      accept_sparse="csc")
checked_test = check_array(data_test.data, dtype=np.float32, accept_sparse="csr")
X_test_sp = checked_test[:3500, :]
y_train_sp = data_train.target
y_test_sp = data_test.target[:3500]

X_test_dense = X_test_sp.todense()

print("20 newsgroups")
print("=============")
print("X_train.shape = {0}".format(X_train_sp.shape))
print("X_train density = {0}"
      "".format(X_train_sp.nnz / np.product(X_train_sp.shape)))
print("y_train {0}".format(y_train_sp.shape))
print("X_test {0}".format(X_test_sp.shape))
print("X_test.format = {0}".format(X_test_sp.format))
print("X_test.dtype = {0}".format(X_test_sp.dtype))
print("y_test {0}".format(y_test_sp.shape))
print()

print("Classifier Training")
print("===================")
accuracy, test_time = {}, {}

name = "GradientBoostingClassifier_100_trees"
clf = GradientBoostingClassifier(n_estimators=100)
try:
    clf.set_params(random_state=0)
except (TypeError, ValueError):
    pass

print("Training %s ... " % name, end="")
clf.fit(X_train_sp, y_train_sp)

%timeit clf.predict(X_test_dense)
%timeit clf.predict(X_test_sp)
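The `%timeit` magics above only work inside IPython. Outside a notebook, the same comparison can be made with the standard `timeit` module; this is a self-contained sketch on small synthetic data of my own (not the 20 newsgroups setup), so the absolute numbers are not comparable to those above.

```python
import timeit
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.RandomState(0)
X = rng.rand(200, 50)
X[X < 0.9] = 0.0          # ~90% zeros
y = rng.randint(0, 2, 200)

clf = GradientBoostingClassifier(n_estimators=10, random_state=0).fit(X, y)
X_sp = csr_matrix(X)

# best-of-3, 10 predict calls per timing run
t_dense = min(timeit.repeat(lambda: clf.predict(X), number=10, repeat=3))
t_sparse = min(timeit.repeat(lambda: clf.predict(X_sp), number=10, repeat=3))
print("dense:  %.4f s" % t_dense)
print("sparse: %.4f s" % t_sparse)
```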

@l3link

l3link commented Jun 20, 2016

This fix is incredibly useful for very sparse data sets (>95% zero values). Converting a medium-sized data set (a 60k x 3k matrix) from dense to sparse reduces training time from hours to minutes (on a single c3.8xlarge AWS machine). Any chance we can get this merged into the next release, or at least into develop? I've been building this from source for the last 2 months.
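For readers wanting to reproduce this kind of speedup, the conversion is a one-liner; this sketch uses synthetic data of my own at roughly the sparsity mentioned above (sparse fit support itself predates this PR; this PR adds the prediction side).

```python
import numpy as np
from scipy.sparse import csc_matrix
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.RandomState(0)
X_dense = rng.rand(1000, 300)
X_dense[X_dense < 0.95] = 0.0   # >95% zero values
y = rng.randint(0, 2, 1000)

# CSC format suits the column-wise feature scans done during tree fitting
X_sparse = csc_matrix(X_dense)

clf = GradientBoostingClassifier(n_estimators=20, random_state=0)
clf.fit(X_sparse, y)
```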

@glouppe
Contributor

glouppe commented Jun 21, 2016

From a quick read, this looks good, besides some minor cosmit issues.

Py_ssize_t K,
Py_ssize_t n_samples,
Py_ssize_t n_features,
float64 *out):
Contributor


Should be indented with DTYPE_t, not with the opening parenthesis.

@jaquesgrobler
Member

I had a quick read-through. Apart from @glouppe's comments above, this looks great.
A very useful fix. I'm +1 for merging once the last points are addressed 👍

@jmschrei
Member

This seems extremely useful. Should be merged as soon as comments are addressed.

@jnothman
Member

@olologin or whoever does the merge should remember to add a what's new entry.

@olologin olologin force-pushed the GradientBoostingFix branch from 7c10b48 to 4ce106e Compare June 26, 2016 09:15
@olologin
Contributor Author

olologin commented Jun 26, 2016

@glouppe , @jaquesgrobler , @jnothman .

Fixed. The AppVeyor build fails, but it seems unrelated to my changes.

And sorry for the delay, I had to finish paperwork at my university.

@olologin
Contributor Author

ping @jnothman

@amueller
Member

please rebase.

By `Sebastian Säger`_ and `YenChen Lin`_.

- :class:`ensemble.GradientBoostingClassifier` and :class:`ensemble.GradientBoostingRegressor`
now support sparse input for ``predict`` method.
Member


*the predict method

@amueller amueller added this to the 0.18 milestone Jul 28, 2016
@amueller
Member

@glouppe @jnothman does this have your +1 and can be merged, or should I review?

@olologin olologin force-pushed the GradientBoostingFix branch from 4ce106e to 0816606 Compare July 28, 2016 18:29
@olologin olologin changed the title [MRG] fix #6101 GradientBoosting decision_function for sparse inputs [MRG+1] fix #6101 GradientBoosting decision_function for sparse inputs Jul 30, 2016
@jnothman
Member

I've not looked at this yet.

@jnothman
Member

jnothman commented Aug 2, 2016

Btw, at a skim this looks good, but I'd like to look through it more closely.

By `Sebastian Säger`_ and `YenChen Lin`_.

- :class:`ensemble.GradientBoostingClassifier` and :class:`ensemble.GradientBoostingRegressor`
now support sparse input for the ``predict`` method.
Member


might as well say for "prediction" to include other prediction methods.

@jmschrei
Member

Now that I think about it, maybe it would be worth adding a test ensuring that the dense and sparse versions run within a factor of 2 of each other? @amueller what is your position on timings in unit tests? I don't want this to degrade in the future.

@jnothman
Member

Timings in unit tests are problematic. Relative timings are going to be dependent on nonzero density in X, apart from architecture issues. I'm -1 for such tests, though it is worth benchmarking (on one architecture) at PR time to see that nothing crazy is happening.
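In place of a timing test, a deterministic equivalence check is robust across architectures; this is a sketch of the idea (my own data and parameter choices, not the exact test added in the PR): the decision function on several sparse formats must match the dense result.

```python
import numpy as np
from scipy.sparse import csr_matrix, csc_matrix, coo_matrix
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.RandomState(42)
X = rng.rand(60, 10)
X[X < 0.7] = 0.0
y = rng.randint(0, 3, 60)   # three classes

clf = GradientBoostingClassifier(n_estimators=10, random_state=42).fit(X, y)
dense_df = clf.decision_function(X)

# The same input in any sparse format must give the same decision values
for sparse_fmt in (csr_matrix, csc_matrix, coo_matrix):
    sparse_df = clf.decision_function(sparse_fmt(X))
    np.testing.assert_array_almost_equal(dense_df, sparse_df)
```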

@olologin olologin force-pushed the GradientBoostingFix branch from 056f780 to 8d21c0c Compare October 14, 2016 17:56
@jmschrei
Member

LGTM. This has my +1.

<https://github.com/scikit-learn/scikit-learn/pull/6178>`_) by `Bertrand
Thirion`_

- :class:`ensemble.GradientBoostingClassifier` and :class:`ensemble.GradientBoostingRegressor`
Member


I think this belongs under enhancements

@jnothman
Member

Move the what's new and we'll merge. Thanks!

@olologin
Contributor Author

Thanks for review 👍

@jnothman jnothman merged commit 78dbcb2 into scikit-learn:master Oct 15, 2016
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
@Sandy4321

As olologin wrote on Aug 27, 2016:

> I fixed the performance issue; now it works almost as fast as the dense version in the test provided by @ogrisel above: 2.773 s for dense and 3.104 s for sparse. Also, I found and fixed a mistake in the safe_realloc usage from tree.pyx and in the function for sparse prediction which I added here: it allocated more memory than needed.

Could somebody share test case code?

@rth
Member

rth commented May 18, 2018

@Sandy4321 see #6101 (comment)

@Sandy4321

I see, so your code looks like this
'''
from future import print_function, division
import numpy as np

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.utils.validation import check_array

from sklearn.ensemble import GradientBoostingClassifier

data_train = fetch_20newsgroups_vectorized(subset="train")
data_test = fetch_20newsgroups_vectorized(subset="test")
X_train_sp = check_array(data_train.data, dtype=np.float32,
accept_sparse="csc")
checked_test = check_array(data_test.data, dtype=np.float32, accept_sparse="csr")
X_test_sp = checked_test[:3500, :]
y_train_sp = data_train.target
y_test_sp = data_test.target[:3500]

X_test_dense = X_test_sp.todense()

print("20 newsgroups")
print("=============")
print("X_train.shape = {0}".format(X_train_sp.shape))
print("X_train density = {0}"
"".format(X_train_sp.nnz / np.product(X_train_sp.shape)))
print("y_train {0}".format(y_train_sp.shape))
print("X_test {0}".format(X_test_sp.shape))
print("X_test.format = {0}".format(X_test_sp.format))
print("X_test.dtype = {0}".format(X_test_sp.dtype))
print("y_test {0}".format(y_test_sp.shape))
print()

print("Classifier Training")
print("===================")
accuracy, test_time = {}, {}

name = "GradientBoostingClassifier_100_trees"
clf = GradientBoostingClassifier(n_estimators=100)
try:
clf.set_params(random_state=0)
except (TypeError, ValueError):
pass

print("Training %s ... " % name, end="")
clf.fit(X_train_sp, y_train_sp)

%timeit clf.predict(X_test_dense)
%timeit clf.predict(X_test_sp)
'''

@Sandy4321

???
I put code between ''' and '''

@rth
Member

rth commented May 23, 2018

Code needs to be between ``` not ''' :)

@Sandy4321


my code

like this?

@Sandy4321

Great, it works!
Thanks
