Skip to content

Commit 359fd1a

Browse files
author
Lowik CHANUSSOT
committed
#4475 : Add a safe_pairwise_distances function, dealing with zero variance samples when using correlation metric.
The best fix would be for the metric itself not to return NaN values, but since the correlation metric is actually computed by SciPy, we can't modify it directly. So, when metric=='correlation', we replace the rows and columns corresponding to zero-variance samples with the maximum distance (here 1.0).
1 parent 1c33a6f commit 359fd1a

3 files changed

Lines changed: 49 additions & 5 deletions

File tree

sklearn/manifold/t_sne.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ..utils import check_random_state
1818
from ..utils.extmath import _ravel
1919
from ..decomposition import RandomizedPCA
20-
from ..metrics.pairwise import pairwise_distances
20+
from ..metrics.pairwise import safe_pairwise_distances
2121
from . import _utils
2222

2323

@@ -269,8 +269,8 @@ def trustworthiness(X, X_embedded, n_neighbors=5, precomputed=False):
269269
if precomputed:
270270
dist_X = X
271271
else:
272-
dist_X = pairwise_distances(X, squared=True)
273-
dist_X_embedded = pairwise_distances(X_embedded, squared=True)
272+
dist_X = safe_pairwise_distances(X, squared=True)
273+
dist_X_embedded = safe_pairwise_distances(X_embedded, squared=True)
274274
ind_X = np.argsort(dist_X, axis=1)
275275
ind_X_embedded = np.argsort(dist_X_embedded, axis=1)[:, 1:n_neighbors + 1]
276276

@@ -438,9 +438,9 @@ def fit(self, X, y=None):
438438
print("[t-SNE] Computing pairwise distances...")
439439

440440
if self.metric == "euclidean":
441-
distances = pairwise_distances(X, metric=self.metric, squared=True)
441+
distances = safe_pairwise_distances(X, metric=self.metric, squared=True)
442442
else:
443-
distances = pairwise_distances(X, metric=self.metric)
443+
distances = safe_pairwise_distances(X, metric=self.metric)
444444

445445
# Degrees of freedom of the Student's t-distribution. The suggestion
446446
# alpha = n_components - 1 comes from "Learning a Parametric Embedding

sklearn/manifold/tests/test_t_sne.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,21 @@ def test_reduction_to_one_component():
271271
X = random_state.randn(5, 2)
272272
X_embedded = tsne.fit_transform(X)
273273
assert(np.all(np.isfinite(X_embedded)))
274+
275+
276+
def test_undefined_correlation():
    # Non-regression test for issue #4475: t-SNE used to raise an exception
    # when the correlation metric was undefined for a sample (a sample with
    # zero variance has no defined correlation with any other sample).
    from sklearn.manifold import TSNE
    import numpy as np

    # A local RandomState yields the same data as seeding the global
    # generator with 42, without leaking state into other tests.
    rng = np.random.RandomState(42)
    data = rng.rand(10, 3)
    # Make the last sample constant so that its correlation is undefined.
    data[-1, :] = 0

    # Any exception raised here fails the test with its full traceback, so
    # no try/except wrapper is needed.
    # NOTE(review): random_state=0 assumes TSNE accepts this argument; confirm.
    X_embedded = TSNE(metric="correlation", random_state=0).fit_transform(data)

    # The embedding must not be contaminated by NaN/inf distances.
    assert np.all(np.isfinite(X_embedded))
291+

sklearn/metrics/pairwise.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,32 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds):
11131113
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
11141114

11151115

1116+
def safe_pairwise_distances(X, Y=None, metric="euclidean", n_jobs=1, **kwds):
    """Compute a distance matrix that is safe for zero-variance samples.

    This is a thin wrapper around ``pairwise_distances``. For every metric
    except ``'correlation'`` the result of ``pairwise_distances`` is returned
    unchanged.

    When ``metric == 'correlation'``, the underlying SciPy routine returns
    NaN for any pair involving a sample with zero variance (a constant row),
    because the correlation is undefined there. Those entries are replaced
    by 1.0, the correlation distance of two uncorrelated samples. When X is
    compared with itself, the distance of a constant sample to itself is
    kept at 0.0 so that the diagonal of the matrix stays zero.

    Parameters
    ----------
    X : array-like
        Samples, one per row. Forwarded to ``pairwise_distances``.

    Y : array-like, optional (default=None)
        Second set of samples, one per row. Forwarded to
        ``pairwise_distances``; if None, X is compared with itself.

    metric : string or callable, optional (default="euclidean")
        Forwarded to ``pairwise_distances``. Only the string
        ``'correlation'`` triggers the zero-variance handling.

    n_jobs : int, optional (default=1)
        Forwarded to ``pairwise_distances``.

    **kwds : optional keyword parameters
        Forwarded to ``pairwise_distances``.

    Returns
    -------
    distances : array
        The matrix returned by ``pairwise_distances``, with the entries of
        zero-variance samples patched when ``metric == 'correlation'``.

    See also
    --------
    pairwise_distances
    """
    distances = pairwise_distances(X, Y, metric, n_jobs, **kwds)
    if metric != 'correlation':
        return distances

    # Correlation distance is 1 - r with r in [-1, 1]; 1.0 corresponds to
    # r == 0, i.e. "uncorrelated", which is used for undefined correlations.
    undefined_correlation_distance = 1.0

    # A sample has zero variance exactly when its peak-to-peak range is 0.
    constant_in_X = np.ptp(X, axis=1) == 0.0
    distances[constant_in_X, :] = undefined_correlation_distance

    if Y is None or Y is X:
        distances[:, constant_in_X] = undefined_correlation_distance
        # X is compared with itself: a constant sample is still at distance
        # zero from itself, so restore the diagonal entries overwritten above.
        constant_idx = np.flatnonzero(constant_in_X)
        distances[constant_idx, constant_idx] = 0.0
    else:
        constant_in_Y = np.ptp(Y, axis=1) == 0.0
        distances[:, constant_in_Y] = undefined_correlation_distance

    return distances
1140+
1141+
11161142
# Helper functions - distance
11171143
PAIRWISE_KERNEL_FUNCTIONS = {
11181144
# If updating this dictionary, update the doc in both distance_metrics()

0 commit comments

Comments
 (0)