Skip to content

Commit 9153fc0

Browse files
author
Hamzeh Alsalhi
committed
Add a classes parameter to LabelEncoder.transform
1 parent 235d681 commit 9153fc0

2 files changed

Lines changed: 27 additions & 8 deletions

File tree

sklearn/preprocessing/label.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,28 +145,35 @@ def fit_transform(self, y):
145145
self.classes_, y = np.unique(y, return_inverse=True)
146146
return y
147147

148-
def transform(self, y):
148+
def transform(self, y, classes=None):
149149
"""Transform labels to normalized encoding.
150150
151151
Parameters
152152
----------
153153
y : array-like of shape [n_samples]
154154
Target values.
155155
156+
classes : array-like, optional (default: None)
157+
List of unique sorted labels to encode the target data against.
158+
If None, the LabelEncoder must have already been fit and the unique
159+
labels from the fit will be used.
160+
156161
Returns
157162
-------
158163
y : array-like of shape [n_samples]
159164
"""
160-
self._check_fitted()
165+
if classes is None:
166+
self._check_fitted()
167+
classes = self.classes_
161168

162-
classes = np.unique(y)
163-
_check_numpy_unicode_bug(classes)
164-
if len(np.intersect1d(classes, self.classes_)) < len(classes):
169+
y_classes = np.unique(y)
170+
_check_numpy_unicode_bug(y_classes)
171+
if len(np.intersect1d(y_classes, classes)) < len(y_classes):
165172
# Get the new labels
166-
unseen = np.setdiff1d(classes, self.classes_)
173+
unseen = np.setdiff1d(y_classes, classes)
167174

168175
if type(self.new_labels) is int:
169-
ret = np.searchsorted(self.classes_, y)
176+
ret = np.searchsorted(classes, y)
170177
ret[np.in1d(y, unseen)] = self.new_labels
171178
return ret
172179
elif self.new_labels is None:
@@ -176,7 +183,7 @@ def transform(self, y):
176183
raise ValueError("Value of argument `new_labels`={0} is "
177184
"unknown.".format(self.new_labels))
178185

179-
return np.searchsorted(self.classes_, y)
186+
return np.searchsorted(classes, y)
180187

181188
def inverse_transform(self, y):
182189
"""Transform labels back to original encoding.

sklearn/preprocessing/tests/test_label.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,18 @@ def test_label_encoder_new_label_replace():
220220
assert_array_equal(le.transform(["b", "c", "d"]), [1, 2, -99])
221221

222222

223+
def test_label_encoder_transform_classes_parameter():
224+
"""Test LabelEncoder's transform using the classes parameter"""
225+
le = LabelEncoder(new_labels=None)
226+
le.fit(["a", "b", "b", "c"])
227+
assert_array_equal(le.classes_, ["a", "b", "c"])
228+
assert_array_equal(le.transform(["d", "f", "e", "e"],
229+
classes=["d", "e", "f"]),
230+
[0, 2, 1, 1])
231+
assert_array_equal(le.inverse_transform([2, 1, 0]), ["c", "b", "a"])
232+
assert_raises(ValueError, le.transform, ["b", "c", "d"])
233+
234+
223235
def test_label_encoder_fit_transform():
224236
"""Test fit_transform"""
225237
le = LabelEncoder()

0 commit comments

Comments
 (0)