AffinityPropagation creates 3d array of cluster centers on rare occasions #9612

@jsamoocha

Description

I stumbled upon a rare combination of training data and preference value that causes the model to store its cluster centers as a 3d ndarray instead of the expected 2d ndarray.

Steps/Code to Reproduce

import numpy as np
from sklearn.cluster import AffinityPropagation

train_data = np.array([[-1.,  1.], [1., -1.]])
model = AffinityPropagation(preference=-10).fit(train_data)
model.cluster_centers_

yields

array([[[-1.,  1.], [ 1., -1.]]])  # 3d!!
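
One possible explanation, offered as an unverified assumption rather than a confirmed diagnosis: the reported array has exactly the shape NumPy produces when a 2d array is indexed with None instead of an integer index array. The sketch below uses only NumPy; `indices` is a hypothetical stand-in, and whether scikit-learn ends up with an unset index internally in this case is not established here.

```python
import numpy as np

train_data = np.array([[-1., 1.], [1., -1.]])

# Hypothetical stand-in for an exemplar-index variable that was left unset.
indices = None

# Indexing with None behaves like np.newaxis: it prepends an axis of length 1.
centers = train_data[indices]
print(centers.shape)  # (1, 2, 2), the same shape as the reported array

# Indexing with an integer array keeps the result 2d.
valid = train_data[np.array([0, 1])]
print(valid.shape)  # (2, 2)
```

If this is what happens, the 3d result would simply be the whole training set wrapped in an extra leading axis, which matches the values shown above.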

and

model.predict(train_data)

leads to

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Users/jsamoocha/.virtualenvs/coach/lib/python2.7/site-packages/sklearn/cluster/affinity_propagation_.py", line 324, in predict
    return pairwise_distances_argmin(X, self.cluster_centers_)
  File "/Users/jsamoocha/.virtualenvs/coach/lib/python2.7/site-packages/sklearn/metrics/pairwise.py", line 464, in pairwise_distances_argmin
    metric_kwargs)[0]
  File "/Users/jsamoocha/.virtualenvs/coach/lib/python2.7/site-packages/sklearn/metrics/pairwise.py", line 339, in pairwise_distances_argmin_min
    X, Y = check_pairwise_arrays(X, Y)
  File "/Users/jsamoocha/.virtualenvs/coach/lib/python2.7/site-packages/sklearn/metrics/pairwise.py", line 111, in check_pairwise_arrays
    warn_on_dtype=warn_on_dtype, estimator=estimator)
  File "/Users/jsamoocha/.virtualenvs/coach/lib/python2.7/site-packages/sklearn/utils/validation.py", line 405, in check_array
    % (array.ndim, estimator_name))
ValueError: Found array with dim 3. check_pairwise_arrays expected <= 2.
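
Until the cause is fixed, calling code could guard against the malformed attribute before predicting. This is only a minimal sketch: `centers_are_2d` is a hypothetical helper, not part of scikit-learn, and it checks the shape only, not whether the clustering itself is meaningful.

```python
import numpy as np


def centers_are_2d(centers):
    """Return True if centers looks like an (n_clusters, n_features) array."""
    return np.asarray(centers).ndim == 2


# Stand-ins for model.cluster_centers_ in the normal and the reported case.
good = np.array([[-1., 1.], [1., -1.]])
bad = good[np.newaxis]  # shape (1, 2, 2), as in the report

print(centers_are_2d(good))  # True
print(centers_are_2d(bad))   # False
```

A caller would check centers_are_2d(model.cluster_centers_) and skip or handle model.predict when it returns False, instead of hitting the ValueError above.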

With slightly different values for preference (e.g. 0 or -20), or slightly different training data (e.g. [[-1, 1], [1, -0.9]]), the cluster centers are stored correctly as a 2d ndarray.

Expected Results

Cluster centers should be stored as a 2d ndarray, as in normal cases.

Versions

Darwin-15.6.0-x86_64-i386-64bit
('Python', '2.7.13 (default, Jul 18 2017, 09:16:53) \n[GCC 4.2.1 Compatible Apple LLVM 8.0.0 (clang-800.0.42.1)]')
('NumPy', '1.13.1')
('SciPy', '0.19.1')
('Scikit-Learn', '0.18.2')
