@@ -104,44 +104,34 @@ def __init__(
104104 ShapelyDeprecationWarning ,
105105 stacklevel = 2 ,
106106 )
107- self ._tree = None
108107 self .node_capacity = node_capacity
109- self ._rev = {
110- item : geom
111- for geom , item in self ._iterinitdata (geoms , items )
112- if not geom .is_empty
113- }
114- if self ._rev :
115- self ._init_tree (self ._rev .items ())
116-
117- def _iterinitdata (
118- self ,
119- geoms : Iterable [BaseGeometry ], items : Optional [Iterable [BaseGeometry ]],
120- ) -> Iterator [Tuple [BaseGeometry , Any ]]:
121- if items is not None :
122- for geom , item in zip (geoms , items ):
123- if isinstance (geom , BaseGeometry ):
124- yield (geom , item )
108+
109+ # Keep references to geoms
110+ self ._geoms = list (geoms )
111+ # Default enumeration index to store in the tree
112+ self ._idxs = list (range (len (self ._geoms )))
113+
114+ # handle items
115+ self ._has_custom_items = items is not None
116+ if not self ._has_custom_items :
117+ items = self ._idxs
118+ self ._items = items
119+
120+ # initialize GEOS STRtree
121+ self ._tree = lgeos .GEOSSTRtree_create (self .node_capacity )
122+ i = 0
123+ for idx , geom in zip (self ._idxs , self ._geoms ):
124+ # filter empty geometries out of the input
125+ if geom is not None and not geom .is_empty :
126+ lgeos .GEOSSTRtree_insert (self ._tree , geom ._geom , ctypes .py_object (idx ))
127+ i += 1
128+ self ._n_geoms = i
129+
130+ def __reduce__ (self ):
131+ if self ._has_custom_items :
132+ return STRtree , (self ._geoms , self ._items )
125133 else :
126- for enum_idx , geom in enumerate (geoms ):
127- if isinstance (geom , BaseGeometry ):
128- yield (geom , enum_idx )
129-
130- def _init_tree (self , rev_initdata : ItemsView [Any , BaseGeometry ]):
131- if rev_initdata :
132- self ._tree = lgeos .GEOSSTRtree_create (self .node_capacity )
133- for item , geom in rev_initdata :
134- lgeos .GEOSSTRtree_insert (self ._tree , geom ._geom , ctypes .py_object (item ))
135-
136- def __getstate__ (self ):
137- state = self .__dict__ .copy ()
138- del state ["_tree" ]
139- return state
140-
141- def __setstate__ (self , state ):
142- self .__dict__ .update (state )
143- if self ._rev :
144- self ._init_tree (self ._rev .items ())
134+ return STRtree , (self ._geoms , )
145135
146136 def __del__ (self ):
147137 if self ._tree is not None :
@@ -152,6 +142,19 @@ def __del__(self):
152142
153143 self ._tree = None
154144
145+ def _query (self , geom ):
146+ if self ._n_geoms == 0 :
147+ return []
148+
149+ result = []
150+
151+ def callback (item , userdata ):
152+ idx = ctypes .cast (item , ctypes .py_object ).value
153+ result .append (idx )
154+
155+ lgeos .GEOSSTRtree_query (self ._tree , geom ._geom , lgeos .GEOSQueryCallback (callback ), None )
156+ return result
157+
155158 def query_items (self , geom : BaseGeometry ) -> Sequence [Any ]:
156159 """Query for nodes which intersect the geom's envelope to get
157160 stored items.
@@ -197,19 +200,11 @@ def query_items(self, geom: BaseGeometry) -> Sequence[Any]:
197200 ['POINT (2 2)']
198201
199202 """
200- if self ._tree is None or not self ._rev :
201- return []
202-
203- result = []
204-
205- def callback (item , userdata ):
206- idx = ctypes .cast (item , ctypes .py_object ).value
207- result .append (idx )
208-
209- lgeos .GEOSSTRtree_query (
210- self ._tree , geom ._geom , lgeos .GEOSQueryCallback (callback ), None
211- )
212- return result
203+ result = self ._query (geom )
204+ if self ._has_custom_items :
205+ return [self ._items [i ] for i in result ]
206+ else :
207+ return result
213208
214209 def query_geoms (self , geom : BaseGeometry ) -> Sequence [BaseGeometry ]:
215210 """Query for nodes which intersect the geom's envelope to get
@@ -225,8 +220,8 @@ def query_geoms(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
225220 An array or list of geometry objects.
226221
227222 """
228- items = self .query_items (geom )
229- return [self ._rev [ idx ] for idx in items ]
223+ result = self ._query (geom )
224+ return [self ._geoms [ i ] for i in result ]
230225
231226 def query (self , geom : BaseGeometry ) -> Sequence [BaseGeometry ]:
232227 """Query for nodes which intersect the geom's envelope to get
@@ -247,6 +242,34 @@ def query(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
247242 """
248243 return self .query_geoms (geom )
249244
245+ def _nearest (self , geom , exclusive ):
246+ envelope = geom .envelope
247+
248+ def callback (item1 , item2 , distance , userdata ):
249+ try :
250+ callback_userdata = ctypes .cast (userdata , ctypes .py_object ).value
251+ idx = ctypes .cast (item1 , ctypes .py_object ).value
252+ geom2 = ctypes .cast (item2 , ctypes .py_object ).value
253+ dist = ctypes .cast (distance , ctypes .POINTER (ctypes .c_double ))
254+ if callback_userdata ["exclusive" ] and self ._geoms [idx ].equals (geom2 ):
255+ dist [0 ] = sys .float_info .max
256+ else :
257+ lgeos .GEOSDistance (self ._geoms [idx ]._geom , geom2 ._geom , dist )
258+
259+ return 1
260+ except Exception :
261+ log .exception ("Caught exception" )
262+ return 0
263+
264+ item = lgeos .GEOSSTRtree_nearest_generic (
265+ self ._tree ,
266+ ctypes .py_object (geom ),
267+ envelope ._geom ,
268+ lgeos .GEOSDistanceCallback (callback ),
269+ ctypes .py_object ({"exclusive" : exclusive }),
270+ )
271+ return ctypes .cast (item , ctypes .py_object ).value
272+
250273 def nearest_item (
251274 self , geom : BaseGeometry , exclusive : bool = False
252275 ) -> Union [Any , None ]:
@@ -285,35 +308,14 @@ def nearest_item(
285308 'POINT (0 0)'
286309
287310 """
288- if self ._tree is None or not self . _rev :
311+ if self ._n_geoms == 0 :
289312 return None
290313
291- envelope = geom .envelope
292-
293- def callback (item1 , item2 , distance , userdata ):
294- try :
295- callback_userdata = ctypes .cast (userdata , ctypes .py_object ).value
296- idx = ctypes .cast (item1 , ctypes .py_object ).value
297- geom2 = ctypes .cast (item2 , ctypes .py_object ).value
298- dist = ctypes .cast (distance , ctypes .POINTER (ctypes .c_double ))
299- if callback_userdata ["exclusive" ] and self ._rev [idx ].equals (geom2 ):
300- dist [0 ] = sys .float_info .max
301- else :
302- lgeos .GEOSDistance (self ._rev [idx ]._geom , geom2 ._geom , dist )
303- return 1
304- except Exception :
305- log .exception ("Caught exception" )
306- return 0
307-
308- item = lgeos .GEOSSTRtree_nearest_generic (
309- self ._tree ,
310- ctypes .py_object (geom ),
311- envelope ._geom ,
312- lgeos .GEOSDistanceCallback (callback ),
313- ctypes .py_object ({"exclusive" : exclusive }),
314- )
315- result = ctypes .cast (item , ctypes .py_object ).value
316- return result
314+ result = self ._nearest (geom , exclusive )
315+ if self ._has_custom_items :
316+ return self ._items [result ]
317+ else :
318+ return result
317319
318320 def nearest_geom (
319321 self , geom : BaseGeometry , exclusive : bool = False
@@ -337,11 +339,8 @@ def nearest_geom(
337339 version 2.0.
338340
339341 """
340- item = self .nearest_item (geom , exclusive = exclusive )
341- if item is None :
342- return None
343- else :
344- return self ._rev [item ]
342+ result = self ._nearest (geom , exclusive )
343+ return self ._geoms [result ]
345344
346345 def nearest (
347346 self , geom : BaseGeometry , exclusive : bool = False
0 commit comments