Skip to content

Commit 301afc9

Browse files
committed
Add tests for ball tree
1 parent 98d4f68 commit 301afc9

3 files changed

Lines changed: 45 additions & 8 deletions

File tree

sklearn/neighbors/_ball_tree.pyx.tp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ cdef int init_node{{name_suffix}}(
9696
cdef intp_t n_points = idx_end - idx_start
9797

9898
cdef intp_t i, j
99-
cdef float64_t radius
99+
cdef {{INPUT_DTYPE_t}} radius
100100
cdef {{INPUT_DTYPE_t}} *this_pt
101101

102102
cdef intp_t* idx_array = &tree.idx_array[0]
@@ -146,7 +146,7 @@ cdef int init_node{{name_suffix}}(
146146
return 0
147147

148148

149-
cdef inline float64_t min_dist{{name_suffix}}(
149+
cdef inline {{INPUT_DTYPE_t}} min_dist{{name_suffix}}(
150150
BinaryTree{{name_suffix}} tree,
151151
intp_t i_node,
152152
{{INPUT_DTYPE_t}}* pt,
@@ -157,7 +157,7 @@ cdef inline float64_t min_dist{{name_suffix}}(
157157
return fmax(0, dist_pt - tree.node_data[i_node].radius)
158158

159159

160-
cdef inline float64_t max_dist{{name_suffix}}(
160+
cdef inline {{INPUT_DTYPE_t}} max_dist{{name_suffix}}(
161161
BinaryTree{{name_suffix}} tree,
162162
intp_t i_node,
163163
{{INPUT_DTYPE_t}}* pt,
@@ -216,7 +216,7 @@ cdef inline float64_t max_rdist{{name_suffix}}(
216216
)
217217

218218

219-
cdef inline float64_t min_dist_dual{{name_suffix}}(
219+
cdef inline {{INPUT_DTYPE_t}} min_dist_dual{{name_suffix}}(
220220
BinaryTree{{name_suffix}} tree1,
221221
intp_t i_node1,
222222
BinaryTree{{name_suffix}} tree2,
@@ -230,7 +230,7 @@ cdef inline float64_t min_dist_dual{{name_suffix}}(
230230
- tree2.node_data[i_node2].radius))
231231

232232

233-
cdef inline float64_t max_dist_dual{{name_suffix}}(
233+
cdef inline {{INPUT_DTYPE_t}} max_dist_dual{{name_suffix}}(
234234
BinaryTree{{name_suffix}} tree1,
235235
intp_t i_node1,
236236
BinaryTree{{name_suffix}} tree2,

sklearn/neighbors/_binary_tree.pxi.tp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ cdef class BinaryTree{{name_suffix}}:
871871
self.euclidean = (self.dist_metric.__class__.__name__
872872
== 'EuclideanDistance')
873873

874-
metric = self.dist_metric.__class__.__name__
874+
metric = self.dist_metric.__class__.__name__.rstrip("32")
875875
if metric not in VALID_METRICS:
876876
raise ValueError('metric {metric} is not valid for '
877877
'{BinaryTree}'.format(metric=metric,

sklearn/neighbors/tests/test_ball_tree.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import numpy as np
44
import pytest
5-
from numpy.testing import assert_array_almost_equal
6-
from sklearn.neighbors._ball_tree import BallTree
5+
from numpy.testing import assert_array_almost_equal, assert_allclose, assert_equal
6+
from sklearn.neighbors._ball_tree import BallTree, BallTree32
77
from sklearn.utils import check_random_state
88
from sklearn.utils.validation import check_array
99
from sklearn.utils._testing import _convert_container
@@ -101,3 +101,40 @@ def one_arg_func(x):
101101
msg = "takes 1 positional argument but 2 were given"
102102
with pytest.raises(TypeError, match=msg):
103103
BallTree(X, metric=one_arg_func)
104+
105+
106+
@pytest.mark.parametrize("metric", itertools.chain(METRICS, BOOLEAN_METRICS))
107+
def test_ball_tree_numerical_consistency(metric):
108+
_X = rng.random_sample((40, 3)).round(0)
109+
_Y = rng.random_sample((10, 3)).round(0)
110+
111+
X_64 = _X.astype(dtype=np.float64)
112+
Y_64 = _Y.astype(dtype=np.float64)
113+
114+
X_32 = _X.astype(dtype=np.float32)
115+
Y_32 = _Y.astype(dtype=np.float32)
116+
117+
metric_params = METRICS.get(metric, {})
118+
bt_64 = BallTree(X_64, leaf_size=1, metric=metric, **metric_params)
119+
bt_32 = BallTree32(X_32, leaf_size=1, metric=metric, **metric_params)
120+
121+
# Test consistency with respect to the `query` method
122+
k = 5
123+
dist_64, ind_64 = bt_64.query(Y_64, k=k)
124+
dist_32, ind_32 = bt_32.query(Y_32, k=k)
125+
assert_allclose(dist_64, dist_32)
126+
assert_equal(ind_64, ind_32)
127+
128+
# Test consistency with respect to the `query_radius` method
129+
r = 0.3
130+
ind_64, neighbors_64 = bt_64.query_radius(Y_64[0:2, :], r=r)
131+
ind_32, neighbors_32 = bt_32.query_radius(Y_32[0:2, :], r=r)
132+
assert_equal(ind_64, ind_32)
133+
assert_allclose(neighbors_64, neighbors_32)
134+
135+
# Test consistency with respect to the `kernel_density` method
136+
kernel = "gaussian"
137+
h = 0.1
138+
density64 = bt_64.kernel_density(Y_64, h=h, kernel=kernel)
139+
density32 = bt_32.kernel_density(Y_32, h=h, kernel=kernel)
140+
assert_allclose(density64, density32)

0 commit comments

Comments
 (0)