Preprocessing simplification #193

Merged
mblondel merged 35 commits into scikit-learn:master from ogrisel:preprocessing-simplification
Jun 6, 2011

Conversation

@ogrisel
Member

ogrisel commented Jun 2, 2011

Here is an early progress report on a refactoring of the preprocessing package to make it simpler by combining the dense and sparse variants into consistent classes.

As usual, early feedback is welcome.

TODO before merge:

  • more tests for pathological cases
  • narrative documentation and more usage examples

Sparse variant for the Scaler is left for another pull request.

@mblondel
Member

mblondel commented Jun 2, 2011

Cool, the code looks simpler than before.

We may want to keep the name Normalizer rather than SampleNormalizer and introduce an axis=0|1 argument to leave the door open for column normalization.

Member

What was your motivation for putting the copy argument in the constructor? The rest of the scikit usually puts it in transform. Also copy need not be grid-searched.

Member Author

I thought it was more logical to have it as a constructor parameter, since you do not need to know the dimensionality of the data to set it. I could re-add it though.

@ogrisel
Member Author

ogrisel commented Jun 2, 2011

As for the axis=0|1 flag, do you really see a use case for column L1 or L2 normalization? I am not strongly opposed to implementing it, but I would not want to introduce an option that would confuse users if there is no real use case for it.

The most common column normalization scheme I see is unit-variance normalization that will be addressed in the Scaler class.

@mblondel
Member

mblondel commented Jun 2, 2011

axis is a common argument in numpy so I don't think it will be confusing, but whether it will be useful is a good question. Let's hear other people's opinion on this, as well as on the copy parameter.

@vene
Member

vene commented Jun 2, 2011

I would expect an axis parameter. One use would be in Sparse PCA where U is constrained to have normalized columns, and when transforming data by projecting onto V, I had to do column normalization by hand. OTOH this is not really preprocessing.

@ogrisel
Member Author

ogrisel commented Jun 2, 2011

@vene I think I should expose lower level functions for your use case instead of using the fit / transform which does not really make sense for this case.

@vene
Member

vene commented Jun 2, 2011

agreed, it's sort of "internal use"


@pprett
Member

pprett commented Jun 2, 2011

First of all, thanks for working on this - IMHO that's an important issue.

Regarding the Normalizer and axis issue:
I too use normalize to denote both:

  • normalize vectors e.g. to have L2 norm equal to 1.
  • normalize vector components to lie in a specified range (e.g. [-1,1]).

I think the latter is popular with neural net folks [1].
It would be great if we could support both kinds of normalization (possibly in different classes though).
Personally, I don't have a use case for L1/L2 normalization of columns (i.e. features) - but if I had to, I'd just transpose the data matrix.

Regarding the Scaler class: do we really want Scaler to work on sparse matrices? As Olivier already pointed out in a thread on the mailing list, subtracting the mean will certainly break sparsity which might not be the intention of the user.
In this case it might be better to only scale the data to have unit variance (not zero mean). Thus, we could add the argument with_mean=True to the constructor of Scaler and allow fit on sparse matrices only if with_mean==False. What do you think?
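A rough sketch of what that fit-time guard could look like (SketchScaler is a hypothetical class written for illustration, not the actual scikit-learn implementation; it assumes scipy.sparse is available):

```python
import numpy as np
import scipy.sparse as sp


class SketchScaler:
    """Illustrative sketch of the proposed with_mean flag; not the
    actual scikit-learn Scaler."""

    def __init__(self, with_mean=True, with_std=True):
        self.with_mean = with_mean
        self.with_std = with_std

    def fit(self, X):
        if sp.issparse(X):
            if self.with_mean:
                # centering would densify the matrix, so refuse it
                raise ValueError(
                    "Cannot center sparse matrices: pass with_mean=False")
            self.mean_ = None
            if self.with_std:
                # variance without densifying: Var[x] = E[x^2] - E[x]^2
                mean = np.asarray(X.mean(axis=0)).ravel()
                mean_sq = np.asarray(X.multiply(X).mean(axis=0)).ravel()
                self.std_ = np.sqrt(mean_sq - mean ** 2)
            else:
                self.std_ = None
        else:
            X = np.asarray(X, dtype=float)
            self.mean_ = X.mean(axis=0) if self.with_mean else None
            self.std_ = X.std(axis=0) if self.with_std else None
        return self
```

The point of the sketch is that fitting dense data works with the default with_mean=True, while fitting sparse data raises unless the user explicitly opts out of centering.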

best,
Peter

[1] http://www.faqs.org/faqs/ai-faq/neural-nets/part2/section-16.html

@ogrisel
Member Author

ogrisel commented Jun 2, 2011

Regarding the Normalizer and axis issue:
I too use normalize to denote both:

  • normalize vectors e.g. to have L2 norm equal to 1.

Alright we agree.

  • normalize vector components to lie in a specified range (e.g. [-1,1]).
    I think the latter is popular with neural net folks [1].

I think the neural network folks (LeCun, Bottou, Hinton...) recommend scaling to unit variance (or even whitening the data using PCA(whiten=True), for instance).

We could rename Scaler to Standardizer to make it more explicit that this is about mean removal + feature-wise unit variance scaling.

We could also offer a MidrangeStandardizer (center the data so the mean is set to 0.0, then perform two linear scalings on the negative and positive parts to set the min to -1.0 and the max to 1.0).
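To make the proposal concrete, here is a rough sketch of that two-sided scaling (MidrangeStandardizer is only a proposed name; this helper is hypothetical and not part of scikit-learn):

```python
import numpy as np


def midrange_standardize(X):
    """Sketch of the proposed MidrangeStandardizer: center each column
    to zero mean, then scale the negative part so the column minimum
    maps to -1.0 and the positive part so the maximum maps to +1.0."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)
    pos_scale = Xc.max(axis=0)       # distance from 0 to the max
    neg_scale = -Xc.min(axis=0)      # distance from 0 to the min
    # avoid division by zero for columns with no positive / negative part
    pos_scale[pos_scale == 0.0] = 1.0
    neg_scale[neg_scale == 0.0] = 1.0
    return np.where(Xc >= 0, Xc / pos_scale, Xc / neg_scale)
```

For example, a column with values [0, 1, 5] has mean 2; after centering it is [-2, -1, 3], and the two scalings map it to [-1.0, -0.5, 1.0].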

Personally, I don't have a use case for L1/L2 normalization of columns (i.e. features) - but if I had to, I'd just transpose the data matrix.

I agree, I will factor out a utility function called normalize(X, norm='l2', axis=1, copy=True) for "internal use", e.g. use cases where we don't want to use it as part of a pipeline.
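For illustration, a dense-only sketch of such a helper with the signature mentioned above (the real function would also need to handle sparse matrices; this version is only an assumption of how it could behave):

```python
import numpy as np


def normalize(X, norm='l2', axis=1, copy=True):
    """Scale rows (axis=1) or columns (axis=0) of X to unit L1/L2 norm.

    Dense-only sketch of the proposed utility function."""
    X = np.array(X, dtype=float, copy=copy)
    if norm == 'l2':
        norms = np.sqrt((X ** 2).sum(axis=axis))
    elif norm == 'l1':
        norms = np.abs(X).sum(axis=axis)
    else:
        raise ValueError("norm must be 'l1' or 'l2'")
    norms[norms == 0.0] = 1.0  # leave all-zero vectors untouched
    if axis == 1:
        X /= norms[:, np.newaxis]
    else:
        X /= norms[np.newaxis, :]
    return X
```

With copy=False the input array is scaled in place when no dtype conversion is needed, which is the point of exposing the flag for large data.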

Regarding the Scaler class: do we really want Scaler to work on sparse matrices? As Olivier already pointed out in a thread on the mailing list, subtracting the mean will certainly break sparsity which might not be the intention of the user.
In this case it might be better to only scale the data to have unit variance (not zero mean). Thus, we could add the argument with_mean=True to the constructor of Scaler and allow fit on sparse matrices only if with_mean==False. What do you think?

I am +1, this is exactly what I had in mind.

@mblondel
Member

mblondel commented Jun 2, 2011


I think the neural network folks (LeCun, Bottou, Hinton...) recommend scaling to unit variance (or even whitening the data using PCA(whiten=True), for instance).

We could rename Scaler to Standardizer to make it more explicit that this is about mean removal + feature-wise unit variance scaling.

I'm rather +1 for keeping Scaler (I don't find Standardizer more explicit, and the distinction between Standardizer and Normalizer is a bit fuzzy).

Personally, I don't have a use case for L1/L2 normalization of columns (i.e. features) - but if I had to, I'd just transpose the data matrix.

I agree, I will factor out a utility function called normalize(X, norm='l2', axis=1, copy=True) for "internal use", e.g. use cases where we don't want to use it as part of a pipeline.

If you create such a utility function, I don't see the harm of adding axis to Normalizer too.

Regarding the Scaler class: do we really want Scaler to work on sparse matrices? As Olivier already pointed out in a thread on the mailing list, subtracting the mean will certainly break sparsity which might not be the intention of the user.
In this case it might be better to only scale the data to have unit variance (not zero mean). Thus, we could add the argument with_mean=True to the constructor of Scaler and allow fit on sparse matrices only if with_mean==False. What do you think?

I am +1, this is exactly what I had in mind.

Would applying the centering only on non-zero features work?

Mathieu

@ogrisel
Member Author

ogrisel commented Jun 2, 2011

If you create such a utility function, I don't see the harm of adding axis to Normalizer too.

Do you have any use case for column normalization in a pipeline setting? I don't see any really. The API should emphasize good / common practices IMHO.

Would applying the centering only on non-zero features work?

That sounds wrong to me. Do you know a case where this was proven useful?

@mblondel
Member

mblondel commented Jun 2, 2011


If you create such a utility function, I don't see the harm of adding axis to Normalizer too.

Do you have any use case for column normalization in a pipeline setting? I don't see any really. The API should emphasize good / common practices IMHO.

Ok so let's postpone to when someone feels the need for it then.

Regarding the name SampleNormalizer, I actually find it confusing. Yes, you are normalizing samples (rows) one by one. In other words, you normalize wrt samples. But what actually gets normalized are features. So FeatureNormalizer would actually make more sense to me. I guess it highly depends on how you see things. Plus, if you use SampleNormalizer, you should probably use SampleScaler too. So, how about we name them Normalizer and Scaler without an axis argument for now?

Would applying the centering only on non-zero features work?

That sounds wrong to me. Do you know a case where this was proven useful?

I'm just checking with you, I've never done it. I agree it can potentially mess up the feature distribution. The thing that worries me is that the center_mean option should default to True in the dense case and to False in the sparse case (the scikit should always make the most sensible choice by default).

Mathieu

@ogrisel
Member Author

ogrisel commented Jun 2, 2011

Alright, I will rename SampleNormalizer back to its original name Normalizer, which will be less confusing.

@ogrisel
Member Author

ogrisel commented Jun 2, 2011

As for the centering of the data: indeed, the default won't be the same for sparse and dense. I had in mind to set the default to centering=True in the constructor but raise an exception if the data is sparse and centering is True. That means that users will have to disable centering explicitly if they have sparse data, but I think it's better to be explicit rather than to implement "smart" / magic behavior.

@mblondel
Member

mblondel commented Jun 3, 2011


As for the centering of the data: indeed, the default won't be the same for sparse and dense. I had in mind to set the default to centering=True in the constructor but raise an exception if the data is sparse and centering is True. That means that users will have to disable centering explicitly if they have sparse data, but I think it's better to be explicit rather than to implement "smart" / magic behavior.

+1

@mblondel
Member

mblondel commented Jun 3, 2011

Once you have added the additional tests you were thinking about and the narrative doc, I think you can merge and keep the new features for another pull request.

@ogrisel
Member Author

ogrisel commented Jun 3, 2011

Alright. I'll try to finish that work tomorrow.

larsmans added a commit to larsmans/scikit-learn that referenced this pull request Jun 4, 2011
Requires ogrisel's preprocessing-simplification branch,
pull request scikit-learn#193.
@ogrisel
Member Author

ogrisel commented Jun 4, 2011

Alright I think this branch is now ready to be merged: sparse Scaler and binning will be implemented on other future branches.

@mblondel
Member

mblondel commented Jun 4, 2011

Huge job on the documentation Olivier! I'll try to write the missing parts (LabelBinarizer and KernelCenterer) next week. If it's ok that I do it in master, +1 for merge.

@ogrisel
Member Author

ogrisel commented Jun 4, 2011

Ok for writing the missing doc directly on master. I think Alex will have a look at the branch tomorrow. I'll let you guys merge it yourselves when you agree (I might be off-line all day tomorrow).

Member

here you say you normalize rows, but then you set axis=1 by default, and later you say: if 1, normalize the columns instead of the rows.

Member

i am not sure why axis is required for normalize.

Member Author

You normalize the samples (rows) over the features (columns, hence axis=1 in numpy parlance).
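A quick numpy illustration of that convention (reducing over axis=1 collapses the features and leaves one value per sample):

```python
import numpy as np

X = np.array([[3.0, 4.0],
              [1.0, 0.0]])

# summing the squares over axis=1 (the features) yields one L2 norm
# per sample (row): sqrt(9 + 16) = 5.0 and sqrt(1 + 0) = 1.0
row_norms = np.sqrt((X ** 2).sum(axis=1))
# row_norms is [5.0, 1.0]
```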

@agramfort
Member

Although I understand and like the idea behind the 2 objects (Scaler for features and Normalizer for samples), I have a few remarks:

  • normalize and scale seem to duplicate functionality, especially if you allow passing axis to normalize.
  • the use of Normalizer in a pipeline, and even in a grid search, seems weird. If you need to normalize your samples, do it as soon as you load your data, not in a pipeline, as you may run this preprocessing step many times for nothing.

@ogrisel
Member Author

ogrisel commented Jun 5, 2011

Normalize is scaling using the norms of the vectors, while Scaler uses the variances. You can use a pipeline outside of a grid search. I agree it should not be put inside a grid search, unless you want to measure the impact of with_mean or with_std as part of your grid search, for instance.

@mblondel
Member

mblondel commented Jun 5, 2011

  • Normalizer and Scaler have mutually exclusive parameters so they'd better be in a separate object
  • For me, the point of pipeline is that you define a chain of objects and the pipeline does everything for you

(BTW, normalizing with the L2 norm is in my experience what works best for text classification)

@agramfort
Member

Normalizer and Scaler have mutually exclusive parameters so they'd better be in a separate object

+1

my point was about the functions scale and normalize. Normalizer does not expose axis.

For me, the point of pipeline is that you define a chain of objects and the pipeline does everything for you

I agree. My point is that normalizing samples in a pipeline can cause a lot of duplicated computation if you do not pay attention.

@ogrisel
Member Author

ogrisel commented Jun 5, 2011

I agree. My point is that normalizing samples in a pipeline can cause a lot of duplicated computation if you do not pay attention.

This is a general issue with any feature extractor / transformer: if you use PCA, ICA or NMF as the input of a classifier, the issue will be even worse. I think we could extend the pipeline to make a variant that uses joblib to memoize the pipelined transformers, but we need to do it carefully to avoid eating all the available space on the hard drive.
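As a toy illustration of the memoization idea (using the standard library's functools.lru_cache as an in-memory stand-in; joblib's Memory would do the equivalent on disk, which is where the disk-space concern above comes from):

```python
from functools import lru_cache


# an expensive transform step, memoized so repeated calls with the
# same input are served from the cache instead of being recomputed
@lru_cache(maxsize=None)
def transform_step(X):
    # X is a tuple of tuples so it can serve as a hashable cache key;
    # the body stands in for an expensive transformer like PCA
    return tuple(tuple(v * 2 for v in row) for row in X)


X = ((1, 2), (3, 4))
first = transform_step(X)
second = transform_step(X)  # cache hit: the result is not recomputed
```

A disk-backed cache avoids recomputation across processes and grid-search folds, but unlike lru_cache it needs an eviction or size policy so it does not fill the drive.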

mblondel added a commit that referenced this pull request Jun 6, 2011
@mblondel mblondel merged commit 575c0fe into scikit-learn:master Jun 6, 2011
VirgileFritsch pushed a commit to VirgileFritsch/scikit-learn that referenced this pull request Jun 26, 2011
Requires ogrisel's preprocessing-simplification branch,
pull request scikit-learn#193.