Skip to content

Commit 799ac4e

Browse files
authored
Clean up the leftover connection for finished threads in pinned thread mode (#471)
## What is the problem? Correctly, there is resource leak when using the pinned thread mode (see also apache/spark#24898). For example, if you repeat the codes below multiple times to create Py4J connections in multiple threads, ```python # PySpark application import threading def print_prop(): # Py4J connection is used under the hood. print(spark.sparkContext.getLocalProperty("a")) threading.Thread(target=print_prop).start() ``` the number of leftover connections grows: ```python spark._jvm._gateway_client.deque deque([<py4j.clientserver.ClientServerConnection object at 0x7fdc60170940>, <py4j.clientserver.ClientServerConnection object at 0x7fdca011e760>, <py4j.clientserver.ClientServerConnection object at 0x7fdcb01acdc0>, <py4j.clientserver.ClientServerConnection object at 0x7fdc60170100>, <py4j.clientserver.ClientServerConnection object at 0x7fdcb0232d30>]) ``` In the environment where multiple threads are used without a pool, it easily causes "Too many files open" due to the lack of file descriptors (as they are all occupied by unclosed sockets). ## How do you fix? This PR adds another variable to thread local that cleans up the connection right before the thread is finished. We need it as a separate thread local because `connection` is NOT cleaned because the reference is being held at `JavaClient.deque`. See also 50fe45e for more details.
1 parent 9702c0f commit 799ac4e

2 files changed

Lines changed: 83 additions & 2 deletions

File tree

py4j-python/src/py4j/clientserver.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from __future__ import unicode_literals, absolute_import
1111

12-
from collections import deque
12+
from collections import deque, Callable
1313
import logging
1414
import socket
1515
from threading import local, Thread
@@ -244,7 +244,10 @@ def set_thread_connection(self, connection):
244244
:param connection: The ClientServerConnection to associate with the
245245
current thread.
246246
"""
247-
self.thread_connection.connection = weakref.ref(connection)
247+
conn = weakref.ref(connection)
248+
self.thread_connection._cleaner = (
249+
ThreadLocalConnectionFinalizer(conn, self.deque))
250+
self.thread_connection.connection = conn
248251

249252
def shutdown_gateway(self):
250253
try:
@@ -300,6 +303,39 @@ def _create_connection_guard(self, connection):
300303
return ClientServerConnectionGuard(self, connection)
301304

302305

306+
class ThreadLocalConnectionFinalizer(object):
307+
"""Cleans :class:`ClientServerConnection` held by a thread local by
308+
closing it properly and removing it from the :class:`JavaClient`
309+
deque. Right before the Python thread is terminated, this
310+
instance will be garbage-collected, which triggers a call
311+
to __del__ that contains the cleanup logic.
312+
"""
313+
def __init__(self, connection, dequeue):
314+
assert (
315+
isinstance(connection, Callable) and
316+
connection() is not None and
317+
isinstance(connection(), ClientServerConnection))
318+
self.connection = connection
319+
self.deque = dequeue
320+
321+
def __del__(self):
322+
"""Removes the connection associated with the current thread
323+
from the deque.
324+
325+
Expected to be called when the thread that started the
326+
connection is garbage-collected.
327+
"""
328+
conn = self.connection()
329+
if conn is not None:
330+
try:
331+
# This dequeue is thread-safe, and shared across other
332+
# threads.
333+
self.deque.remove(conn)
334+
except ValueError:
335+
# Should never reach this point
336+
pass
337+
338+
303339
class ClientServerConnectionGuard(GatewayConnectionGuard):
304340
"""Connection guard that does nothing on exit because there is no need to
305341
close or give back a connection.
@@ -603,6 +639,12 @@ def _get_params(self, input):
603639
temp = smart_decode(input.readline())[:-1]
604640
return params
605641

642+
def __del__(self):
643+
# In case new connection is set via
644+
# `JavaClient.set_thread_connection`, this connection will be
645+
# garbage-collected with closing the underlying socket properly.
646+
self.close()
647+
606648

607649
class ClientServer(JavaGateway):
608650
"""Subclass of JavaGateway that implements a different threading model: a

py4j-python/src/py4j/tests/client_server_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from py4j.tests.java_callback_test import IHelloImpl, IHelloFailingImpl
1515
from py4j.tests.java_gateway_test import (
1616
PY4J_JAVA_PATH, check_connection, sleep, WaitOperator)
17+
from py4j.tests.memory_leak_test import python_gc
1718
from py4j.tests.py4j_callback_recursive_example import (
1819
PythonPing, HelloState)
1920

@@ -147,6 +148,44 @@ def testSendObjects(self):
147148
client_server.shutdown()
148149
self.assertEquals(1000, hello.calls)
149150

151+
def testCleanConnections(self):
152+
"""This test intentionally create multiple connections in multiple
153+
threads so each connection is in a thread local of each thread.
154+
After that, it verifies that if the connection is cleaned and closed
155+
properly without a resource leak.
156+
"""
157+
with clientserver_example_app_process():
158+
client_server = ClientServer(
159+
JavaParameters(), PythonParameters())
160+
connections = client_server._gateway_client.deque
161+
conditions = []
162+
163+
def assert_connection():
164+
# Creates a connection.
165+
client_server.jvm.System.currentTimeMillis()
166+
# Should at least create one connection.
167+
conditions.append(0 < len(connections))
168+
169+
threads = [
170+
threading.Thread(target=assert_connection),
171+
threading.Thread(target=assert_connection),
172+
threading.Thread(target=assert_connection),
173+
]
174+
for t in threads:
175+
t.start()
176+
for t in threads:
177+
t.join()
178+
179+
# Here we explicitly call garbage collection to clean
180+
# `ClientServerConnection`s by
181+
# `JavaClient.ThreadLocalConnectionCleaner`
182+
python_gc()
183+
184+
# Should have zero connections left.
185+
self.assertEqual(len(client_server._gateway_client.deque), 0)
186+
self.assertTrue(all(connections))
187+
client_server.shutdown()
188+
150189

151190
class RetryTest(unittest.TestCase):
152191

0 commit comments

Comments
 (0)