@@ -287,11 +287,13 @@ def inheritable_thread_target(f):
287287 -----
288288 This API is experimental.
289289
290- It captures the local properties when you decorate it. Therefore, it is encouraged
291- to decorate it when you want to capture the local properties.
290+ It is important to know that it captures the local properties when you decorate it
291+ whereas :class:`InheritableThread` captures when the thread is started.
292+ Therefore, it is encouraged to decorate it when you want to capture the local
293+ properties.
292294
293295 For example, the local properties from the current Spark context are captured
294- when you define a function here:
296+ when you define a function here instead of the invocation:
295297
296298 >>> @inheritable_thread_target
297299 ... def target_func():
@@ -305,35 +307,22 @@ def inheritable_thread_target(f):
305307 >>> Thread(target=inheritable_thread_target(target_func)).start() # doctest: +SKIP
306308 """
307309 from pyspark import SparkContext
308- if os .environ .get ("PYSPARK_PIN_THREAD" , "false" ).lower () == "true" :
309- # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
310- sc = SparkContext ._active_spark_context
311310
312- # Get local properties from main thread
313- properties = sc ._jsc .sc ().getLocalProperties ().clone ()
311+ if os .environ .get ("PYSPARK_PIN_THREAD" , "true" ).lower () == "true" :
312+ # NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
313+ # copies local properties when the thread starts but `inheritable_thread_target`
314+ # copies when the function is wrapped.
315+ properties = SparkContext ._active_spark_context ._jsc .sc ().getLocalProperties ().clone ()
314316
315317 @functools .wraps (f )
316- def wrapped_f (* args , ** kwargs ):
318+ def wrapped (* args , ** kwargs ):
317319 try :
318320 # Set local properties in child thread.
319- sc ._jsc .sc ().setLocalProperties (properties )
321+ SparkContext . _active_spark_context ._jsc .sc ().setLocalProperties (properties )
320322 return f (* args , ** kwargs )
321323 finally :
322- thread_connection = sc ._jvm ._gateway_client .thread_connection .connection ()
323- if thread_connection is not None :
324- connections = sc ._jvm ._gateway_client .deque
325- # Reuse the lock for Py4J in PySpark
326- with SparkContext ._lock :
327- for i in range (len (connections )):
328- if connections [i ] is thread_connection :
329- connections [i ].close ()
330- del connections [i ]
331- break
332- else :
333- # Just in case the connection was not closed but removed from the
334- # queue.
335- thread_connection .close ()
336- return wrapped_f
324+ InheritableThread ._clean_py4j_conn_for_current_thread ()
325+ return wrapped
337326 else :
338327 return f
339328
@@ -354,21 +343,67 @@ class InheritableThread(threading.Thread):
354343
355344 .. versionadded:: 3.1.0
356345
357-
358346 Notes
359347 -----
360348 This API is experimental.
361349 """
362350 def __init__ (self , target , * args , ** kwargs ):
363- super (InheritableThread , self ).__init__ (
364- target = inheritable_thread_target (target ), * args , ** kwargs
365- )
351+ from pyspark import SparkContext
352+
353+ if os .environ .get ("PYSPARK_PIN_THREAD" , "true" ).lower () == "true" :
354+ def copy_local_properties (* a , ** k ):
355+ # self._props is set before starting the thread to match the behavior with JVM.
356+ assert hasattr (self , "_props" )
357+ SparkContext ._active_spark_context ._jsc .sc ().setLocalProperties (self ._props )
358+ try :
359+ return target (* a , ** k )
360+ finally :
361+ InheritableThread ._clean_py4j_conn_for_current_thread ()
362+
363+ super (InheritableThread , self ).__init__ (
364+ target = copy_local_properties , * args , ** kwargs )
365+ else :
366+ super (InheritableThread , self ).__init__ (target = target , * args , ** kwargs )
366367
368+ def start (self , * args , ** kwargs ):
369+ from pyspark import SparkContext
367370
368- if __name__ == "__main__" :
369- import doctest
371+ if os .environ .get ("PYSPARK_PIN_THREAD" , "true" ).lower () == "true" :
372+ # Local property copy should happen in Thread.start to mimic JVM's behavior.
373+ self ._props = SparkContext ._active_spark_context ._jsc .sc ().getLocalProperties ().clone ()
374+ return super (InheritableThread , self ).start (* args , ** kwargs )
375+
376+ @staticmethod
377+ def _clean_py4j_conn_for_current_thread ():
378+ from pyspark import SparkContext
379+
380+ jvm = SparkContext ._jvm
381+ thread_connection = jvm ._gateway_client .thread_connection .connection ()
382+ if thread_connection is not None :
383+ connections = jvm ._gateway_client .deque
384+ # Reuse the lock for Py4J in PySpark
385+ with SparkContext ._lock :
386+ for i in range (len (connections )):
387+ if connections [i ] is thread_connection :
388+ connections [i ].close ()
389+ del connections [i ]
390+ break
391+ else :
392+ # Just in case the connection was not closed but removed from the
393+ # queue.
394+ thread_connection .close ()
370395
396+
if __name__ == "__main__":
    # Doctests need a real SparkContext and CPython 3.7+; skip them entirely
    # on PyPy and on older interpreters.
    if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7):
        import doctest
        import pyspark.util
        from pyspark.context import SparkContext

        globs = pyspark.util.__dict__.copy()
        globs['sc'] = SparkContext('local[4]', 'PythonTest')
        (failure_count, test_count) = doctest.testmod(pyspark.util, globs=globs)
        globs['sc'].stop()

        # Keep this check inside the guarded branch: failure_count is unbound
        # when the doctests are skipped, so hoisting it out would raise
        # NameError on PyPy / pre-3.7 interpreters.
        if failure_count:
            sys.exit(-1)
0 commit comments