
[WIP]: Sparse label_binarizer and OneVsRestClassifier#2458

Closed
rsivapr wants to merge 29 commits into scikit-learn:master from rsivapr:sparse_labelbinarizer

Conversation

@rsivapr
Contributor

@rsivapr rsivapr commented Sep 18, 2013

#2441

  • modify the function sklearn.utils.multiclass.type_of_target to recognise sparse binary matrix as a 'multilabel-indicator' format.
  • modify the function sklearn.utils.multiclass.unique_labels to work with this new format.
  • modify the label_binarize function to return a sparse label indicator.
  • modify OVR to work with sparse output matrices, for instance by extracting a dense column of the sparse matrix and fitting the classifier on it. Prediction would be made by building the sparse output matrix incrementally from each fitted classifier.
  • Write tests
  • Edit _fit_binary method to take in dense and sparse matrices.
  • Single label Multi-class broken.
  • import scipy.sparse as sp instead of from scipy.sparse import coo_matrix
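The first item above, teaching type_of_target to recognise a sparse binary matrix, might look roughly like this sketch (the helper name and the exact checks are illustrative, not scikit-learn's actual implementation):

```python
import numpy as np
import scipy.sparse as sp

def looks_like_sparse_indicator(y):
    """Hypothetical check: a 2-D sparse matrix whose stored values are
    all 0 or 1 could be reported as 'multilabel-indicator'."""
    if not sp.issparse(y):
        return False
    data = y.tocoo().data
    return bool(data.size == 0 or np.isin(data, (0, 1)).all())

Y = sp.csr_matrix(np.array([[0, 1, 1], [1, 0, 0]]))
print(looks_like_sparse_indicator(Y))  # True: a 0/1 sparse matrix qualifies
```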

@arjoly
Member

arjoly commented Sep 18, 2013

In #2441, you said

"I am working on a text classification problem. I have a largish dataset with about 5 million documents and close to 50000 classes. I have used the TfidfVectorizer to extract features (again about 1 million features) from the documents."

Can you give the level of sparsity of the input and the output?
Do you have a small benchmark script?

@arjoly
Member

arjoly commented Sep 18, 2013

Apparently, you have some conflict with the master branch. Can you rebase?

@rsivapr
Contributor Author

rsivapr commented Sep 18, 2013

Can you give the level of sparsity of the input and the output?

About 0.3% - 0.4% are non-zero values.

@rsivapr
Contributor Author

rsivapr commented Sep 18, 2013

Apparently, you have some conflict with the master branch. Can you rebase?

This is embarrassing. I'm getting a conflict error when I try to rebase. I have no idea what to do.

@rsivapr
Contributor Author

rsivapr commented Sep 18, 2013

@arjoly Done.

Member

When you call delayed on _fit_binary, you save a reference to all input arguments.
In your case, that means you densify Y.
To solve the issue, we should pass _fit_binary a reference to the sparse Y and the class index of the job.
So the signature of _fit_binary would become _fit_binary(estimator, X, Y, class_index, class_name).

I think that you should treat the sparse and dense case in _fit_binary.
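A minimal sketch of that proposal, using a toy stand-in estimator (CountFitter is invented for illustration; the real OVR code would dispatch a cloned scikit-learn estimator through joblib's delayed):

```python
import numpy as np
import scipy.sparse as sp

def _fit_binary(estimator, X, Y, class_index, class_name):
    # Only this one column is densified, inside the job itself, so the
    # closure created by delayed() never holds a dense copy of all of Y.
    y = Y[:, class_index].toarray().ravel()
    estimator.fit(X, y)
    return estimator

class CountFitter:
    """Toy estimator: records the number of positive labels it saw."""
    def fit(self, X, y):
        self.n_pos_ = int(y.sum())
        return self

X = np.arange(8).reshape(4, 2)
Y = sp.csc_matrix(np.array([[1, 0], [0, 1], [1, 0], [1, 1]]))
fitted = [_fit_binary(CountFitter(), X, Y, i, i) for i in range(Y.shape[1])]
print([est.n_pos_ for est in fitted])  # [3, 2]
```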

Contributor Author

I am not sure I completely understand this. So do we change the _fit_binary method to fit for each class individually instead of all the classes?

Contributor Author

What do you mean by class index of the job?

Member

When the delayed function is used to decorate a function, you save a reference
to the function that you call and its arguments:

In [1]: from joblib.parallel import delayed
In [2]: import numpy as np
In [3]: delayed(np.mean)(np.arange(5))
Out[3]: (<function numpy.core.fromnumeric.mean>, (array([0, 1, 2, 3, 4]),), {})
In [4]: list(delayed(np.mean)(np.arange(i)) for i in range(5))
Out[4]: 
[(<function numpy.core.fromnumeric.mean>, (array([], dtype=int64),), {}),
 (<function numpy.core.fromnumeric.mean>, (array([0]),), {}),
 (<function numpy.core.fromnumeric.mean>, (array([0, 1]),), {}),
 (<function numpy.core.fromnumeric.mean>, (array([0, 1, 2]),), {}),
 (<function numpy.core.fromnumeric.mean>, (array([0, 1, 2, 3]),), {})]

Later on with the Parallel class, you dispatch each function call with the appropriate argument.
In this case, it means that you "densify" the sparse column and it will blow up your memory.

What I suggest is to pass to the delayed function a reference to the input array,
the output array, the label index and the label name. In the _fit_binary,
you can densify the label column without blowing your memory.

Memory consumption can be further reduced using the latest version of joblib
and memmap, though I am not sure that works with sparse matrices (@ogrisel?).
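To make the memory point concrete, here is a small check (with invented values) that slicing a CSC column stays sparse until the final toarray call:

```python
import numpy as np
import scipy.sparse as sp

# Two stored values in column 1, one in column 0, in a 5x2 matrix.
Y = sp.csc_matrix((np.ones(3), ([0, 2, 4], [1, 1, 0])), shape=(5, 2))

col = Y[:, 1]              # still sparse: 2 stored values, not 5
y = col.toarray().ravel()  # dense only for this single column
print(col.nnz, y.tolist())  # 2 [1.0, 0.0, 1.0, 0.0, 0.0]
```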

Member

So do we change the _fit_binary method to fit for each class individually instead of all the classes?

I am not sure that I understand what you mean by individually.
In the one-vs-rest strategy, you fit one binary classifier per class
and later aggregate the output of each one.
Am I missing something?

Contributor Author

Am I missing something?

Sorry. I thought we passed the list of classes every time as an argument and later noticed we just pass [not i, i].

Contributor Author

What I suggest is to pass to the delayed function a reference to the input array,
the output array, the label index and the label name. In the _fit_binary,
you can densify the label column without blowing your memory.

Awesome. I will modify that and add it to this PR.

Memory consumption can be further reduced using the latest version of joblib
and memmap, though I am not sure that works with sparse matrices (@ogrisel?).

Yes. I'd be very interested to see if that works. Perhaps @ogrisel can point me toward the right direction.

Member

What I suggest is to pass the delayed function references to the input array,
the output array, the label index and the label name. In the _fit_binary,
you can densify the label column without blowing your memory.

That's probably a good idea. Another useful trick may be to rely on the
'pre_dispatch' argument of Parallel, which helps a bit in such a
situation.
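For instance (a toy call, with abs standing in for the real fitting function), pre_dispatch caps how many argument tuples joblib materialises ahead of the workers:

```python
from joblib import Parallel, delayed

# With pre_dispatch='2*n_jobs', joblib consumes the generator lazily and
# keeps only a couple of pending tasks queued, instead of building every
# (function, args) tuple -- and hence every densified column -- up front.
results = Parallel(n_jobs=1, pre_dispatch='2*n_jobs')(
    delayed(abs)(-i) for i in range(6))
print(results)  # [0, 1, 2, 3, 4, 5]
```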

Member

Thanks for the tip!

@arjoly
Member

arjoly commented Sep 19, 2013

You have some Travis failures.

@arjoly
Member

arjoly commented Sep 19, 2013

Thank you for working on this. :-) Can you put WIP (work in progress) at the beginning of the PR title? Later on, we can swap to MRG to indicate that we need some reviews.

@rsivapr
Contributor Author

rsivapr commented Sep 19, 2013

Thank you for guiding me! Apart from the failures, am I going in the right direction?

I think I need to write tests for label_binarizer returning sparse matrices. I might be wrong. I am pretty much new to this.

@arjoly
Member

arjoly commented Sep 19, 2013

I think I need to write tests for label_binarizer returning sparse matrices. I might be wrong. I am pretty much new to this.

Yes, new features should be tested. We use nose
as the unit testing package.

You can add the new tests in the test_XXX.py file and run them "locally" using
nosetests /path/to/test/test_XXX.py. The whole test suite can be run with make test.

You can find a lot of information in the contributing documentation.
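As a sketch of the nose conventions just mentioned, a test is simply a test_-prefixed function full of assertions (binarize_to_sparse below is a made-up stand-in, not the real label_binarize):

```python
import numpy as np
import scipy.sparse as sp

def test_label_binarize_sparse_output():
    # Hypothetical inner helper: one-hot encode labels straight into CSR.
    def binarize_to_sparse(y, classes):
        rows = np.arange(len(y))
        cols = np.searchsorted(classes, y)
        return sp.csr_matrix((np.ones(len(y)), (rows, cols)),
                             shape=(len(y), len(classes)))

    Y = binarize_to_sparse(np.array([0, 2, 1]), np.array([0, 1, 2]))
    assert sp.issparse(Y)
    assert Y.toarray().tolist() == [[1.0, 0.0, 0.0],
                                    [0.0, 0.0, 1.0],
                                    [0.0, 1.0, 0.0]]

test_label_binarize_sparse_output()
```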

Contributor Author

@arjoly I made the changes as per your recommendations. Does this look better?

Member

I would simply do

 y = Y[:, class_index]

Member

@arjoly I made the changes as per your recommendations. Does this look better?

👍 Are you able to make some runs on your problem?

Member

You can use the toarray method to perform this operation. All classifiers should accept a column as y.

@rsivapr
Contributor Author

rsivapr commented Dec 4, 2013

Hi Arnaud, will do it. Just busy with my thesis defense in a couple of weeks.

Is there anything other than writing tests for this PR?

Member

I would rather never use todense() but instead use toarray(), so that we never deal with np.matrix instances in the sklearn code base.
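The distinction is easy to check: toarray returns a plain ndarray, while todense returns an np.matrix, whose operators behave differently:

```python
import numpy as np
import scipy.sparse as sp

S = sp.csr_matrix(np.eye(2))
a = S.toarray()   # numpy.ndarray -- the type the code base expects
m = S.todense()   # numpy.matrix  -- '*' means matrix product, slices stay 2-D
print(type(a).__name__, type(m).__name__)  # ndarray matrix
```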

@arjoly
Member

arjoly commented Dec 4, 2013

I am checking that the PR is not dying. The release is likely to be at the end of December or mid-January. If you don't have time, maybe @mahendrakariya can help you finish (by making a PR to merge into your branch). If this PR is not ready for 0.15, it's better to merge it in 0.16.

Is there anything other than writing tests for this PR?

Yes, mainly testing and a bit of narrative documentation. We should also benchmark against the master version.

@jnothman
Member

jnothman commented Dec 4, 2013

Do we need to raise an error in metrics saying "not supported yet" for sparse multilabel indicators?

@jnothman
Member

jnothman commented Dec 5, 2013

Actually, I've got a solution for sparse metrics. I haven't yet tested it, and doing so would benefit from this sparse label_binarizer.

See https://gist.github.com/jnothman/7798757 (untested!) where I use class polymorphism to refactor the multilabel metric code paths. The idea is that each metric calls _multilabel_helper once and calls methods on its returned object to get marginals over comparisons between y_true and y_pred.

This approach gets around scipy.sparse's (former?) lack of support for !=, logical_and, logical_or and xor that would make writing the same code for sparse and dense cases easier. It also means we can get the marginals in each axis for the sparse case in a fast, CSR-specialised way.

@arjoly, @ogrisel
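The workaround described above can be sketched as follows (toy 0/1 matrices; the real helper classes live in the linked gist): elementwise multiplication of two 0/1 CSR matrices plays the role of the missing logical_and, and getnnz gives the marginals along each axis:

```python
import numpy as np
import scipy.sparse as sp

y_true = sp.csr_matrix(np.array([[1, 0, 1], [0, 1, 1]]))
y_pred = sp.csr_matrix(np.array([[1, 1, 0], [0, 1, 1]]))

# Intersection sizes per sample: multiply() is elementwise, so for 0/1
# data it counts positions where both matrices are 1.
tp_per_sample = np.asarray(y_true.multiply(y_pred).sum(axis=1)).ravel()
true_per_sample = y_true.getnnz(axis=1)  # |y_true| per row
pred_per_sample = y_pred.getnnz(axis=1)  # |y_pred| per row
print(tp_per_sample.tolist(), true_per_sample.tolist(), pred_per_sample.tolist())
```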

Member

Firstly, setxor1d is overkill. All you need to ensure there are at most 2 labels is len(y.data) == 0 or np.ptp(y.data) == 0.

Secondly, this doesn't work for sparse types without data (e.g. DOK), or lil_matrix which has a data attribute that is an array of lists. Instead you can use:

import numpy as np

def is_sparse_and_binary(y):
    # DOK has no `data` array and LIL's `data` holds lists (dtype object),
    # so fall back to the COO representation for those formats.
    if not hasattr(y, 'data') or y.data.dtype.kind == 'O':
        y = y.tocoo()
    return len(y.data) == 0 or np.ptp(y.data) == 0

(You could use isinstance(y, scipy.sparse.data._data_matrix), but that involves a private class.)
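A self-contained run of that helper (restated here so the snippet executes on its own) against the DOK and LIL formats it is meant to cover:

```python
import numpy as np
import scipy.sparse as sp

def is_sparse_and_binary(y):
    # DOK has no `data` array; LIL's `data` holds lists (dtype object);
    # convert those formats to COO first.
    if not hasattr(y, 'data') or y.data.dtype.kind == 'O':
        y = y.tocoo()
    # np.ptp raises on an empty array, hence the explicit length check.
    return len(y.data) == 0 or np.ptp(y.data) == 0

d = sp.dok_matrix((2, 2)); d[0, 1] = 1
l = sp.lil_matrix((2, 2)); l[1, 0] = 1
print(is_sparse_and_binary(d), is_sparse_and_binary(l))  # True True
```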

Member

Nice. The y.data.dtype.kind == 'O' check is non-trivial, so this code should be included with a comment explaining that it adds support for the DOK and LIL sparse matrix data structures.

@jnothman
Member

jnothman commented Dec 5, 2013

I have a tested sparse multilabel metrics implementation at https://github.com/jnothman/scikit-learn/tree/sparse_multi_metrics. I'll make a PR of it once this is accepted.

Member

Please use scipy.sparse.issparse.

@jnothman
Member

jnothman commented Dec 5, 2013

@arjoly, I've just taken a look through more of the PR than I had before. I think it still requires some non-trivial work apart from testing. I think it needs to either be held back until 0.16, or taken on quickly by a developer who is more familiar with the subtleties of scipy.sparse, who may need to cherry-pick from this branch and re-implement portions of it. What do you think?

@ogrisel
Member

ogrisel commented Dec 5, 2013

I have a tested sparse multilabel metrics implementation at https://github.com/jnothman/scikit-learn/tree/sparse_multi_metrics. I'll make a PR of it once this is accepted.

This looks very nice. Have you tried to bench it a bit?

@ogrisel
Member

ogrisel commented Dec 5, 2013

@arjoly, I've just taken a look through more of the PR than I had before. I think it still requires some non-trivial work apart from testing. I think it needs to either be held back until 0.16, or taken on quickly by a developer who is more familiar with the subtleties of scipy.sparse, who may need to cherry-pick from this branch and re-implement portions of it. What do you think?

Is this mostly about: #2458 (comment) or do you have other major road blocks in mind?

@jnothman
Member

jnothman commented Dec 5, 2013

@ogrisel wrote about my metrics implementation:

This looks very nice. Have you tried to bench it a bit?

I haven't. I expect the sequence of sequences will be slower because I yet again need to get the set of labels, and may need to iterate through the data more times than before. For the label indicator matrices, I think it's exactly the same vector operations, only with the small O(1) class construction and method calling overhead.

@jnothman
Member

jnothman commented Dec 5, 2013

Yes, it's mostly with regard to that issue, as well as some code duplication, ensuring that this works for lots of matrix types. I guess this also depends on what the exact release time-scale is for 0.15. Note that @rsivapr is unlikely to do much on this before his thesis defence in a couple of weeks and that there aren't tests for some of the functionality.

@arjoly
Member

arjoly commented Dec 9, 2013

I have a tested sparse multilabel metrics implementation at https://github.com/jnothman/scikit-learn/tree/sparse_multi_metrics. I'll make a PR of it once this is accepted.

Very nice !!!

@arjoly, I've just taken a look through more of the PR than I had before. I think it still requires some non-trivial work apart from testing. I think it needs to either be held back until 0.16, or taken on quickly by a developer who is more familiar with the subtleties of scipy.sparse, who may need to cherry-pick from this branch and re-implement portions of it. What do you think?

I think that this would be very nice if @rsivapr is able to finish his (first?) contribution. Anybody wanting to help is welcome.

@rsivapr
Contributor Author

rsivapr commented Dec 27, 2013

Hey @arjoly. I should be able to work on this starting mid-January. I would have plenty of time to perhaps work on more issues as well. Until then I'm afraid I would have to focus on my thesis.

Thanks for review @jnothman, @ogrisel, and @arjoly. I will attempt to fix all of the above suggestions when I get back.

@jnothman
Member

How's this going, Rohit?

@arjoly
Member

arjoly commented Feb 11, 2014

Hi @rsivapr,

I think it's about time to deprecate the sequence of sequences format. A lot of work has already been done in that direction and is waiting in different pull requests (thanks to @jnothman).

Since this pull request is blocking the way, I took some time to work on the sparse label binarizer (see this branch https://github.com/arjoly/scikit-learn/tree/sparse-label_binarizer). It's almost finished.

Is it OK with you if I finish the sparse label binarizer?
Can I help you finish what you have undertaken with OneVsRestClassifier?

@hamsal hamsal mentioned this pull request May 26, 2014
17 tasks
@rsivapr
Contributor Author

rsivapr commented May 28, 2014

Hi @arjoly
Sorry I did not notice this post. I've been busy with the new job. Since I'm not able to find time to work on this, it's completely okay if you or @hamsal work on this. Thank you and my apologies for not being able to complete this.

@rsivapr
Contributor Author

rsivapr commented May 28, 2014

@arjoly I can close this PR if you'd like.

@arjoly
Member

arjoly commented May 28, 2014

OK, then I am closing. Good luck with your job!

@arjoly arjoly closed this May 28, 2014
@jnothman
Member

Thanks for your effort here @rsivapr, and even if the code doesn't go in, your first attempt at an inevitable feature has highlighted the challenges in what needs to be done. Good luck with the job.

@hamsal hamsal mentioned this pull request Jun 13, 2014
6 tasks