FIX raise an error if user defined categories contain duplicate values #27328
betatim merged 14 commits into scikit-learn:main
glemaitre
left a comment
We should also acknowledge this change in the changelog.
sklearn/preprocessing/_encoders.py
Outdated
)
raise ValueError(msg)

if len(cats) != len(set(cats)):
Since we deal with a NumPy array, let's use numpy to solve this issue.
- if len(cats) != len(set(cats)):
+ _, n_unique_categories = np.unique(cats, return_counts=True)
+ if cats.size != n_unique_categories:
I think this causes the error. How about `if cats.size != np.unique(cats).size:`?
Yep, `n_unique_categories = array([1, 1, 1, 1, 1])` with the older NumPy version. I am not really sure why. The change that you propose is fine with me.
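For readers following the thread, the behaviour above can be reproduced with a minimal sketch. With `return_counts=True`, `np.unique` returns a pair of arrays (the unique values and how often each one occurs), so the second element is an array of counts rather than the number of unique values. The `cats` array here is an illustrative assumption, not data from the PR.

```python
import numpy as np

# Hypothetical categories without duplicates (not taken from the PR).
cats = np.array(["a", "b", "c", "d", "e"])

# return_counts=True yields (unique_values, counts): one count per unique value.
_, counts = np.unique(cats, return_counts=True)
print(counts)  # [1 1 1 1 1]

# Comparing the scalar size with the counts array is elementwise,
# so the result is an array of booleans, not a single bool.
elementwise = cats.size != counts
print(elementwise.shape)  # (5,)

# The proposed alternative compares two integers instead.
has_duplicates = cats.size != np.unique(cats).size
print(has_duplicates)  # False
```

Using the elementwise result directly in an `if` statement is ambiguous for arrays with more than one element, which is consistent with the failure reported in this thread.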
Unfortunately, this still causes an error due to None; see the failures at https://github.com/scikit-learn/scikit-learn/pull/27328/checks?check_run_id=18232777028.
For example, np.unique(np.array([None, 'a', 'z'], dtype=object)) will raise an error.
How about `if cats.size != len(set(cats)):`?
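A small sketch of the difference between the two approaches, assuming an object array that mixes `None` and strings as in the example above. `np.unique` sorts its input, and `None` cannot be ordered against `str`, whereas a `set` only needs hashing and equality. The variable names are illustrative, not taken from the PR.

```python
import numpy as np

cats = np.array([None, "a", "z"], dtype=object)

# np.unique sorts the values; None and str are not orderable,
# so this is expected to raise a TypeError.
try:
    np.unique(cats)
    print("np.unique succeeded")
except TypeError as exc:
    print("np.unique failed:", exc)

# A set relies on hashing and equality only, so mixed None/str is fine.
has_duplicates = cats.size != len(set(cats))
print(has_duplicates)  # False

# The same check flags a repeated category.
cats_with_dup = np.array([None, "a", "a"], dtype=object)
has_duplicates_dup = cats_with_dup.size != len(set(cats_with_dup))
print(has_duplicates_dup)  # True
```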
OK, so let's revert to the set then. However, we need a test where we specify nan several times in the categories as well, to check this corner case.
I am going to use the `_unique` function instead; see the code below.

```python
import numpy as np
from sklearn.utils._encode import _unique

print(set(np.array(['a', None, None])))         # {'a', None}
print(set(np.array(['a', np.nan, np.nan])))     # {'a', 'nan'}
print(set(np.array([1., np.nan, np.nan])))      # {nan, 1.0, nan}
print(_unique(np.array(['a', None, None])))     # ['a' None]
print(_unique(np.array(['a', np.nan, np.nan]))) # ['a' 'nan']
print(_unique(np.array([1., np.nan, np.nan])))  # [ 1. nan]
```

As for a test with several nan in the categories, I don't think it's necessary: we currently assume nan must be the last category, and PR #27309 will resolve this.
BTW, which order do you think is better: first check whether nan is the last category, or first check whether the categories contain duplicate values?
Hi @thomasjpfan, would you like to take a look at this PR and also #27309?
> BTW, which order do you think is better: first check whether nan is the last category, or first check whether the categories contain duplicate values?
Maybe the check for nan first since it is less expensive.
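The ordering suggested here can be sketched as follows. This is a hypothetical standalone validator, not scikit-learn's implementation: the function name `check_categories`, the helper `_is_nan`, and the error messages are all assumptions made for illustration. Running the nan-position check first also means that, by the time the duplicate check runs, at most one nan remains (at the end), so a plain `set` comparison is safe even though distinct nan objects do not compare equal.

```python
import math

import numpy as np


def _is_nan(value):
    # Only float NaN counts (np.float64 is a float subclass);
    # None and strings such as 'nan' are not treated as NaN.
    return isinstance(value, float) and math.isnan(value)


def check_categories(cats):
    """Hypothetical validator; names and messages are illustrative only."""
    values = np.asarray(cats).tolist()  # native Python objects
    nan_positions = [i for i, v in enumerate(values) if _is_nan(v)]

    # First check: nan, if present, may only appear once, as the last entry.
    if nan_positions and nan_positions != [len(values) - 1]:
        raise ValueError("nan should be the last category")

    # Second check: with at most one trailing nan left, a set is reliable.
    if len(set(values)) != len(values):
        raise ValueError("categories contain duplicate values")


for cats in (["a", "b", None], ["a", "b", "a"], [1.0, np.nan, 2.0]):
    try:
        check_categories(cats)
        print(cats, "-> ok")
    except ValueError as exc:
        print(cats, "->", exc)
```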
Sure, I will update the code after #27309 is merged.
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Looks good to me, I enabled auto merge.
Head branch was pushed to by a user without write access
I just resolved some conflicts, could you take a look again? @betatim
scikit-learn#27328) Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Tim Head <betatim@gmail.com>
Reference Issues/PRs
Follow up #27309
Mentioned in #27088
What does this implement/fix? Explain your changes.
In encoders, check user defined categories and raise an error if they have duplicate values.
Any other comments?