|
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import pytest |
5 | | -from numpy.testing import assert_array_almost_equal |
6 | | -from sklearn.neighbors._ball_tree import BallTree |
| 5 | +from numpy.testing import assert_array_almost_equal, assert_allclose, assert_equal |
| 6 | +from sklearn.neighbors._ball_tree import BallTree, BallTree32 |
7 | 7 | from sklearn.utils import check_random_state |
8 | 8 | from sklearn.utils.validation import check_array |
9 | 9 | from sklearn.utils._testing import _convert_container |
@@ -101,3 +101,40 @@ def one_arg_func(x): |
101 | 101 | msg = "takes 1 positional argument but 2 were given" |
102 | 102 | with pytest.raises(TypeError, match=msg): |
103 | 103 | BallTree(X, metric=one_arg_func) |
| 104 | + |
| 105 | + |
| 106 | +@pytest.mark.parametrize("metric", itertools.chain(METRICS, BOOLEAN_METRICS)) |
| 107 | +def test_ball_tree_numerical_consistency(metric): |
| 108 | + _X = rng.random_sample((40, 3)).round(0) |
| 109 | + _Y = rng.random_sample((10, 3)).round(0) |
| 110 | + |
| 111 | + X_64 = _X.astype(dtype=np.float64) |
| 112 | + Y_64 = _Y.astype(dtype=np.float64) |
| 113 | + |
| 114 | + X_32 = _X.astype(dtype=np.float32) |
| 115 | + Y_32 = _Y.astype(dtype=np.float32) |
| 116 | + |
| 117 | + metric_params = METRICS.get(metric, {}) |
| 118 | + bt_64 = BallTree(X_64, leaf_size=1, metric=metric, **metric_params) |
| 119 | + bt_32 = BallTree32(X_32, leaf_size=1, metric=metric, **metric_params) |
| 120 | + |
| 121 | + # Test consistency with respect to the `query` method |
| 122 | + k = 5 |
| 123 | + dist_64, ind_64 = bt_64.query(Y_64, k=k) |
| 124 | + dist_32, ind_32 = bt_32.query(Y_32, k=k) |
| 125 | + assert_allclose(dist_64, dist_32) |
| 126 | + assert_equal(ind_64, ind_32) |
| 127 | + |
| 128 | + # Test consistency with respect to the `query_radius` method |
| 129 | + r = 0.3 |
| 130 | + ind_64, neighbors_64 = bt_64.query_radius(Y_64[0:2, :], r=r) |
| 131 | + ind_32, neighbors_32 = bt_32.query_radius(Y_32[0:2, :], r=r) |
| 132 | + assert_equal(ind_64, ind_32) |
| 133 | + assert_allclose(neighbors_64, neighbors_32) |
| 134 | + |
| 135 | + # Test consistency with respect to the `kernel_density` method |
| 136 | + kernel = "gaussian" |
| 137 | + h = 0.1 |
| 138 | + density64 = bt_64.kernel_density(Y_64, h=h, kernel=kernel) |
| 139 | + density32 = bt_32.kernel_density(Y_32, h=h, kernel=kernel) |
| 140 | + assert_allclose(density64, density32) |
0 commit comments