Skip to content

Commit 85cf315

Browse files
Directly construct CSR matrix
1 parent a1c0982 commit 85cf315

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

sklearn/preprocessing/data.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2761,16 +2761,16 @@ def transform(self, X):
2761 2761
mask = X_mask.ravel()
2762 2762
n_values = [cats.shape[0] for cats in self.categories_]
2763 2763
n_values = np.array([0] + n_values)
2764-
indices = np.cumsum(n_values)
2764+
feature_indices = np.cumsum(n_values)
2765 2765

2766-
column_indices = (X_int + indices[:-1]).ravel()[mask]
2767-
row_indices = np.repeat(np.arange(n_samples, dtype=np.int32),
2768-
n_features)[mask]
2766+
indices = (X_int + feature_indices[:-1]).ravel()[mask]
2767+
indptr = X_mask.sum(axis=1).cumsum()
2768+
indptr = np.insert(indptr, 0, 0)
2769 2769
data = np.ones(n_samples * n_features)[mask]
2770 2770

2771-
out = sparse.csc_matrix((data, (row_indices, column_indices)),
2772-
shape=(n_samples, indices[-1]),
2773-
dtype=self.dtype).tocsr()
2771+
out = sparse.csr_matrix((data, indices, indptr),
2772+
shape=(n_samples, feature_indices[-1]),
2773+
dtype=self.dtype)
2774 2774
if self.encoding == 'onehot-dense':
2775 2775
return out.toarray()
2776 2776
else:

0 commit comments

Comments
 (0)