Skip to content

Commit 6a10369

Browse files
Refactor STRtree implementation to store geoms/idx sequences instead of reverse mapping
1 parent 9a76173 commit 6a10369

1 file changed

Lines changed: 82 additions & 83 deletions

File tree

shapely/strtree.py

Lines changed: 82 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -104,44 +104,34 @@ def __init__(
104104
ShapelyDeprecationWarning,
105105
stacklevel=2,
106106
)
107-
self._tree = None
108107
self.node_capacity = node_capacity
109-
self._rev = {
110-
item: geom
111-
for geom, item in self._iterinitdata(geoms, items)
112-
if not geom.is_empty
113-
}
114-
if self._rev:
115-
self._init_tree(self._rev.items())
116-
117-
def _iterinitdata(
118-
self,
119-
geoms: Iterable[BaseGeometry], items: Optional[Iterable[BaseGeometry]],
120-
) -> Iterator[Tuple[BaseGeometry, Any]]:
121-
if items is not None:
122-
for geom, item in zip(geoms, items):
123-
if isinstance(geom, BaseGeometry):
124-
yield (geom, item)
108+
109+
# Keep references to geoms
110+
self._geoms = list(geoms)
111+
# Default enumeration index to store in the tree
112+
self._idxs = list(range(len(self._geoms)))
113+
114+
# handle items
115+
self._has_custom_items = items is not None
116+
if not self._has_custom_items:
117+
items = self._idxs
118+
self._items = items
119+
120+
# initialize GEOS STRtree
121+
self._tree = lgeos.GEOSSTRtree_create(self.node_capacity)
122+
i = 0
123+
for idx, geom in zip(self._idxs, self._geoms):
124+
# filter empty geometries out of the input
125+
if geom is not None and not geom.is_empty:
126+
lgeos.GEOSSTRtree_insert(self._tree, geom._geom, ctypes.py_object(idx))
127+
i += 1
128+
self._n_geoms = i
129+
130+
def __reduce__(self):
131+
if self._has_custom_items:
132+
return STRtree, (self._geoms, self._items)
125133
else:
126-
for enum_idx, geom in enumerate(geoms):
127-
if isinstance(geom, BaseGeometry):
128-
yield (geom, enum_idx)
129-
130-
def _init_tree(self, rev_initdata: ItemsView[Any, BaseGeometry]):
131-
if rev_initdata:
132-
self._tree = lgeos.GEOSSTRtree_create(self.node_capacity)
133-
for item, geom in rev_initdata:
134-
lgeos.GEOSSTRtree_insert(self._tree, geom._geom, ctypes.py_object(item))
135-
136-
def __getstate__(self):
137-
state = self.__dict__.copy()
138-
del state["_tree"]
139-
return state
140-
141-
def __setstate__(self, state):
142-
self.__dict__.update(state)
143-
if self._rev:
144-
self._init_tree(self._rev.items())
134+
return STRtree, (self._geoms, )
145135

146136
def __del__(self):
147137
if self._tree is not None:
@@ -152,6 +142,19 @@ def __del__(self):
152142

153143
self._tree = None
154144

145+
def _query(self, geom):
146+
if self._n_geoms == 0:
147+
return []
148+
149+
result = []
150+
151+
def callback(item, userdata):
152+
idx = ctypes.cast(item, ctypes.py_object).value
153+
result.append(idx)
154+
155+
lgeos.GEOSSTRtree_query(self._tree, geom._geom, lgeos.GEOSQueryCallback(callback), None)
156+
return result
157+
155158
def query_items(self, geom: BaseGeometry) -> Sequence[Any]:
156159
"""Query for nodes which intersect the geom's envelope to get
157160
stored items.
@@ -197,19 +200,11 @@ def query_items(self, geom: BaseGeometry) -> Sequence[Any]:
197200
['POINT (2 2)']
198201
199202
"""
200-
if self._tree is None or not self._rev:
201-
return []
202-
203-
result = []
204-
205-
def callback(item, userdata):
206-
idx = ctypes.cast(item, ctypes.py_object).value
207-
result.append(idx)
208-
209-
lgeos.GEOSSTRtree_query(
210-
self._tree, geom._geom, lgeos.GEOSQueryCallback(callback), None
211-
)
212-
return result
203+
result = self._query(geom)
204+
if self._has_custom_items:
205+
return [self._items[i] for i in result]
206+
else:
207+
return result
213208

214209
def query_geoms(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
215210
"""Query for nodes which intersect the geom's envelope to get
@@ -225,8 +220,8 @@ def query_geoms(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
225220
An array or list of geometry objects.
226221
227222
"""
228-
items = self.query_items(geom)
229-
return [self._rev[idx] for idx in items]
223+
result = self._query(geom)
224+
return [self._geoms[i] for i in result]
230225

231226
def query(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
232227
"""Query for nodes which intersect the geom's envelope to get
@@ -247,6 +242,34 @@ def query(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
247242
"""
248243
return self.query_geoms(geom)
249244

245+
def _nearest(self, geom, exclusive):
246+
envelope = geom.envelope
247+
248+
def callback(item1, item2, distance, userdata):
249+
try:
250+
callback_userdata = ctypes.cast(userdata, ctypes.py_object).value
251+
idx = ctypes.cast(item1, ctypes.py_object).value
252+
geom2 = ctypes.cast(item2, ctypes.py_object).value
253+
dist = ctypes.cast(distance, ctypes.POINTER(ctypes.c_double))
254+
if callback_userdata["exclusive"] and self._geoms[idx].equals(geom2):
255+
dist[0] = sys.float_info.max
256+
else:
257+
lgeos.GEOSDistance(self._geoms[idx]._geom, geom2._geom, dist)
258+
259+
return 1
260+
except Exception:
261+
log.exception("Caught exception")
262+
return 0
263+
264+
item = lgeos.GEOSSTRtree_nearest_generic(
265+
self._tree,
266+
ctypes.py_object(geom),
267+
envelope._geom,
268+
lgeos.GEOSDistanceCallback(callback),
269+
ctypes.py_object({"exclusive": exclusive}),
270+
)
271+
return ctypes.cast(item, ctypes.py_object).value
272+
250273
def nearest_item(
251274
self, geom: BaseGeometry, exclusive: bool = False
252275
) -> Union[Any, None]:
@@ -285,35 +308,14 @@ def nearest_item(
285308
'POINT (0 0)'
286309
287310
"""
288-
if self._tree is None or not self._rev:
311+
if self._n_geoms == 0:
289312
return None
290313

291-
envelope = geom.envelope
292-
293-
def callback(item1, item2, distance, userdata):
294-
try:
295-
callback_userdata = ctypes.cast(userdata, ctypes.py_object).value
296-
idx = ctypes.cast(item1, ctypes.py_object).value
297-
geom2 = ctypes.cast(item2, ctypes.py_object).value
298-
dist = ctypes.cast(distance, ctypes.POINTER(ctypes.c_double))
299-
if callback_userdata["exclusive"] and self._rev[idx].equals(geom2):
300-
dist[0] = sys.float_info.max
301-
else:
302-
lgeos.GEOSDistance(self._rev[idx]._geom, geom2._geom, dist)
303-
return 1
304-
except Exception:
305-
log.exception("Caught exception")
306-
return 0
307-
308-
item = lgeos.GEOSSTRtree_nearest_generic(
309-
self._tree,
310-
ctypes.py_object(geom),
311-
envelope._geom,
312-
lgeos.GEOSDistanceCallback(callback),
313-
ctypes.py_object({"exclusive": exclusive}),
314-
)
315-
result = ctypes.cast(item, ctypes.py_object).value
316-
return result
314+
result = self._nearest(geom, exclusive)
315+
if self._has_custom_items:
316+
return self._items[result]
317+
else:
318+
return result
317319

318320
def nearest_geom(
319321
self, geom: BaseGeometry, exclusive: bool = False
@@ -337,11 +339,8 @@ def nearest_geom(
337339
version 2.0.
338340
339341
"""
340-
item = self.nearest_item(geom, exclusive=exclusive)
341-
if item is None:
342-
return None
343-
else:
344-
return self._rev[item]
342+
result = self._nearest(geom, exclusive)
343+
return self._geoms[result]
345344

346345
def nearest(
347346
self, geom: BaseGeometry, exclusive: bool = False

0 commit comments

Comments
 (0)