new tests for mean_shift algo#13179
Conversation
|
@ogrisel can help review this |
| def test_mean_shift_negative_bandwidth(): | ||
| bandwidth = -1 | ||
| ms = MeanShift(bandwidth=bandwidth) | ||
| msg = \ |
There was a problem hiding this comment.
Use parentheses to enclose expressions and split them over multiple lines rather than using \ for line continuation
There was a problem hiding this comment.
@jnothman this comment is not clear will following statement work?
msg = "bandwidth needs to be greater than zero or None,"
" got -1.000000"
There was a problem hiding this comment.
This will:
msg = ("bandwidth needs to be greater than zero or None,"
" got -1.000000")
|
|
||
| def test_seeds(): | ||
| ms = MeanShift(seeds=None) | ||
| _ = ms.fit(X).labels_ |
| assert_raise_message(ValueError, msg, ms.fit, X) | ||
|
|
||
|
|
||
| def test_seeds(): |
There was a problem hiding this comment.
I don't get what this is testing. Checking that parameters are maintained should usually be covered by common tests not tests for each specific estimator
| labels = ms.fit(X).labels_ | ||
| labels_unique = np.unique(labels) | ||
| n_clusters_ = len(labels_unique) | ||
| assert_equal(n_clusters_ > n_clusters, True) |
There was a problem hiding this comment.
Use bare assert as with seeds above
| n_clusters_ = len(labels_unique) | ||
| assert_equal(n_clusters_ > n_clusters, True) | ||
|
|
||
| cluster_centers, labels = mean_shift(X, bandwidth=bandwidth, |
There was a problem hiding this comment.
Rather than repeat the code, please use pytest.mark.parameterize to test multiple settings of bandwidth
There was a problem hiding this comment.
changed to use
pytest.mark.parameterize
b018e99 to
4cf6413
Compare
jnothman
left a comment
There was a problem hiding this comment.
I confirm this covers untested lines.
| bandwidth = -1 | ||
| ms = MeanShift(bandwidth=bandwidth) | ||
| msg = ("bandwidth needs to be greater than zero or None," | ||
| " got -1.000000") |
There was a problem hiding this comment.
This whitespace looks like an error in the code raising the message. Please change the code to have a single space between the comma and "got"
There was a problem hiding this comment.
This is unresolved. Please fix the error message in mean_shift_.py
| (1.2, True, 3), | ||
| (1.2, False, 4) | ||
| ]) | ||
| def test_eval(bandwidth, cluster_all, expected): |
There was a problem hiding this comment.
what do you mean by calling this "eval"? Can't we just paramertrize test_mean_shift above, rather than adding a new test?
There was a problem hiding this comment.
But ideally we should also test that cluster_all=False is actually effective at allowing some points to be left unclustered. Create a dataset where a point will be left with label -1 to test this properly.
4cf6413 to
dea8840
Compare
|
Please merge the current master |
| def test_mean_shift(): | ||
| @pytest.mark.parametrize("bandwidth, cluster_all, expected, " | ||
| "first_cluster_label", | ||
| [(1.2, True, 3, 0), (1.2, False, 4, -1)]) |
| bandwidth = -1 | ||
| ms = MeanShift(bandwidth=bandwidth) | ||
| msg = ("bandwidth needs to be greater than zero or None," | ||
| " got -1.000000") |
There was a problem hiding this comment.
This is unresolved. Please fix the error message in mean_shift_.py
bb1dd95 to
f40648d
Compare
|
@jnothman fixed the comments |
|
Thanks @rajdeepd |
|
@jnothman how do we get this pull request merged into master? |
|
4 days is not long to wait for a second review, @rajdeepd... hopefully one will come soon. |
| assert n_clusters_ == expected | ||
| assert labels_unique[0] == first_cluster_label | ||
|
|
||
| cluster_centers, labels = mean_shift(X, bandwidth=bandwidth) |
There was a problem hiding this comment.
Removing this means we are not testing the mean_shift function directly anymore.
There was a problem hiding this comment.
we are testing using
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
labels = ms.fit(X).labels_
There was a problem hiding this comment.
The testing of mean_shift should be independent of ms.fit. At the moment, ms.fit calls mean_shift, but we do not know how the code base will change.
There was a problem hiding this comment.
@thomasjpfan do we need another test for mean_shift?
There was a problem hiding this comment.
Leaving the original test here will sufficiently test mean_shift.
There was a problem hiding this comment.
@thomasjpfan added test for mean_shift as well
| ms = MeanShift(bandwidth=bandwidth) | ||
| msg = ("bandwidth needs to be greater than zero or None," | ||
| " got -1.000000") | ||
| assert_raise_message(ValueError, msg, ms.fit, X) |
There was a problem hiding this comment.
We are moving to using pytest.raises:
msg = (r"bandwidth needs to be greater than zero or None,"
r" got -1\.000000")
with pytest.raises(ValueError, match=msg):
ms.fit(X)71df239 to
1b9f928
Compare
| n_clusters_ = len(labels_unique) | ||
| assert_equal(n_clusters_, n_clusters) | ||
| cluster_centers, labels_mean_shift = mean_shift(X, cluster_all=cluster_all) | ||
| print(cluster_centers) |
| # n_neighbors is set to 1. | ||
| bandwidth = estimate_bandwidth(X, n_samples=1, quantile=0.3) | ||
| assert_array_almost_equal(bandwidth, 0., decimal=5) | ||
| assert_equal(bandwidth, 0.) |
There was a problem hiding this comment.
could just be assert a == b then
1b9f928 to
aa17ea1
Compare
|
Thanks @rajdeepd |
This reverts commit 67f53dc.
This reverts commit 67f53dc.
Reference Issues/PRs
none
What does this implement/fix? Explain your changes.
Add test cases to cover un-tested portions of mean_shift.py
Any other comments?
no other comments