Skip to content

Commit be9dd5e

Browse files
committed
Refactor inheritable thread logic, and use it in codebase for pinned thread mode
1 parent fdf86fd commit be9dd5e

4 files changed

Lines changed: 79 additions & 39 deletions

File tree

python/pyspark/context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
11111111
--------
11121112
>>> import threading
11131113
>>> from time import sleep
1114+
>>> from pyspark import InheritableThread
11141115
>>> result = "Not Set"
11151116
>>> lock = threading.Lock()
11161117
>>> def map_func(x):
@@ -1128,8 +1129,8 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
11281129
... sleep(5)
11291130
... sc.cancelJobGroup("job_to_cancel")
11301131
>>> suppress = lock.acquire()
1131-
>>> suppress = threading.Thread(target=start_job, args=(10,)).start()
1132-
>>> suppress = threading.Thread(target=stop_job).start()
1132+
>>> suppress = InheritableThread(target=start_job, args=(10,)).start()
1133+
>>> suppress = InheritableThread(target=stop_job).start()
11331134
>>> suppress = lock.acquire()
11341135
>>> print(result)
11351136
Cancelled

python/pyspark/ml/classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from abc import ABCMeta, abstractmethod, abstractproperty
2424
from multiprocessing.pool import ThreadPool
2525

26-
from pyspark import keyword_only, since, SparkContext
26+
from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
2727
from pyspark.ml import Estimator, Predictor, PredictionModel, Model
2828
from pyspark.ml.param.shared import HasRawPredictionCol, HasProbabilityCol, HasThresholds, \
2929
HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, \
@@ -2921,7 +2921,7 @@ def trainSingleClass(index):
29212921

29222922
pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
29232923

2924-
models = pool.map(trainSingleClass, range(numClasses))
2924+
models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
29252925

29262926
if handlePersistence:
29272927
multiclassLabeled.unpersist()

python/pyspark/ml/tuning.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import numpy as np
2626

27-
from pyspark import keyword_only, since, SparkContext
27+
from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
2828
from pyspark.ml import Estimator, Transformer, Model
2929
from pyspark.ml.common import inherit_doc, _py2java, _java2py
3030
from pyspark.ml.evaluation import Evaluator
@@ -729,7 +729,9 @@ def _fit(self, dataset):
729729
validation = datasets[i][1].cache()
730730
train = datasets[i][0].cache()
731731

732-
tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
732+
tasks = map(
733+
inheritable_thread_target,
734+
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam))
733735
for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
734736
metrics[j] += (metric / nFolds)
735737
if collectSubModelsParam:
@@ -1261,7 +1263,9 @@ def _fit(self, dataset):
12611263
if collectSubModelsParam:
12621264
subModels = [None for i in range(numModels)]
12631265

1264-
tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
1266+
tasks = map(
1267+
inheritable_thread_target,
1268+
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam))
12651269
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
12661270
metrics = [None] * numModels
12671271
for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):

python/pyspark/util.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,13 @@ def inheritable_thread_target(f):
287287
-----
288288
This API is experimental.
289289
290-
It captures the local properties when you decorate it. Therefore, it is encouraged
291-
to decorate it when you want to capture the local properties.
290+
It is important to know that it captures the local properties when you decorate it
291+
whereas :class:`InheritableThread` captures when the thread is started.
292+
Therefore, it is encouraged to decorate it when you want to capture the local
293+
properties.
292294
293295
For example, the local properties from the current Spark context are captured
294-
when you define a function here:
296+
when you define a function here rather than when it is invoked:
295297
296298
>>> @inheritable_thread_target
297299
... def target_func():
@@ -305,35 +307,22 @@ def inheritable_thread_target(f):
305307
>>> Thread(target=inheritable_thread_target(target_func)).start() # doctest: +SKIP
306308
"""
307309
from pyspark import SparkContext
308-
if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
309-
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
310-
sc = SparkContext._active_spark_context
311310

312-
# Get local properties from main thread
313-
properties = sc._jsc.sc().getLocalProperties().clone()
311+
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
312+
# NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
313+
# copies local properties when the thread starts but `inheritable_thread_target`
314+
# copies when the function is wrapped.
315+
properties = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
314316

315317
@functools.wraps(f)
316-
def wrapped_f(*args, **kwargs):
318+
def wrapped(*args, **kwargs):
317319
try:
318320
# Set local properties in child thread.
319-
sc._jsc.sc().setLocalProperties(properties)
321+
SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
320322
return f(*args, **kwargs)
321323
finally:
322-
thread_connection = sc._jvm._gateway_client.thread_connection.connection()
323-
if thread_connection is not None:
324-
connections = sc._jvm._gateway_client.deque
325-
# Reuse the lock for Py4J in PySpark
326-
with SparkContext._lock:
327-
for i in range(len(connections)):
328-
if connections[i] is thread_connection:
329-
connections[i].close()
330-
del connections[i]
331-
break
332-
else:
333-
# Just in case the connection was not closed but removed from the
334-
# queue.
335-
thread_connection.close()
336-
return wrapped_f
324+
InheritableThread._clean_py4j_conn_for_current_thread()
325+
return wrapped
337326
else:
338327
return f
339328

@@ -354,21 +343,67 @@ class InheritableThread(threading.Thread):
354343
355344
.. versionadded:: 3.1.0
356345
357-
358346
Notes
359347
-----
360348
This API is experimental.
361349
"""
362350
def __init__(self, target, *args, **kwargs):
363-
super(InheritableThread, self).__init__(
364-
target=inheritable_thread_target(target), *args, **kwargs
365-
)
351+
from pyspark import SparkContext
352+
353+
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
354+
def copy_local_properties(*a, **k):
355+
# self._props is set before starting the thread to match the behavior with JVM.
356+
assert hasattr(self, "_props")
357+
SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
358+
try:
359+
return target(*a, **k)
360+
finally:
361+
InheritableThread._clean_py4j_conn_for_current_thread()
362+
363+
super(InheritableThread, self).__init__(
364+
target=copy_local_properties, *args, **kwargs)
365+
else:
366+
super(InheritableThread, self).__init__(target=target, *args, **kwargs)
366367

368+
def start(self, *args, **kwargs):
369+
from pyspark import SparkContext
367370

368-
if __name__ == "__main__":
369-
import doctest
371+
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
372+
# Local property copy should happen in Thread.start to mimic JVM's behavior.
373+
self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
374+
return super(InheritableThread, self).start(*args, **kwargs)
375+
376+
@staticmethod
377+
def _clean_py4j_conn_for_current_thread():
378+
from pyspark import SparkContext
379+
380+
jvm = SparkContext._jvm
381+
thread_connection = jvm._gateway_client.thread_connection.connection()
382+
if thread_connection is not None:
383+
connections = jvm._gateway_client.deque
384+
# Reuse the lock for Py4J in PySpark
385+
with SparkContext._lock:
386+
for i in range(len(connections)):
387+
if connections[i] is thread_connection:
388+
connections[i].close()
389+
del connections[i]
390+
break
391+
else:
392+
# Just in case the connection was not closed but removed from the
393+
# queue.
394+
thread_connection.close()
370395

396+
397+
if __name__ == "__main__":
371398
if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7):
372-
(failure_count, test_count) = doctest.testmod()
399+
import doctest
400+
import pyspark.util
401+
from pyspark.context import SparkContext
402+
403+
globs = pyspark.util.__dict__.copy()
404+
globs['sc'] = SparkContext('local[4]', 'PythonTest')
405+
(failure_count, test_count) = doctest.testmod(pyspark.util, globs=globs)
406+
globs['sc'].stop()
407+
373408
if failure_count:
374409
sys.exit(-1)

0 commit comments

Comments
 (0)