Skip to content

Commit 44a17cb

Browse files
Allow multi_process_shared objects to be called (#26202)
* Allow multi_process_shared objects to be called * Allow multi_process_shared objects to be called (fixed, test passing) * formatting * Update sdks/python/apache_beam/utils/multi_process_shared.py Co-authored-by: Anand Inguva <34158215+AnandInguva@users.noreply.github.com> * Type hint * Type hint --------- Co-authored-by: Anand Inguva <34158215+AnandInguva@users.noreply.github.com>
1 parent fbc7df4 commit 44a17cb

2 files changed

Lines changed: 55 additions & 2 deletions

File tree

sdks/python/apache_beam/utils/multi_process_shared.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def __init__(self, entry):
5050
self._SingletonProxy_entry = entry
5151
self._SingletonProxy_valid = True
5252

53+
# Used to make the shared object callable (see _AutoProxyWrapper below)
54+
def singletonProxy_call__(self, *args, **kwargs):
55+
if not self._SingletonProxy_valid:
56+
raise RuntimeError('Entry was released.')
57+
return self._SingletonProxy_entry.obj.__call__(*args, **kwargs)
58+
5359
def _SingletonProxy_release(self):
5460
assert self._SingletonProxy_valid
5561
self._SingletonProxy_valid = False
@@ -61,7 +67,9 @@ def __getattr__(self, name):
6167

6268
def __dir__(self):
6369
# Needed for multiprocessing.managers's proxying.
64-
return self._SingletonProxy_entry.obj.__dir__()
70+
dir = self._SingletonProxy_entry.obj.__dir__()
71+
dir.append('singletonProxy_call__')
72+
return dir
6573

6674

6775
class _SingletonEntry:
@@ -127,6 +135,24 @@ class _SingletonRegistrar(multiprocessing.managers.BaseManager):
127135
callable=_process_level_singleton_manager.release_singleton)
128136

129137

138+
# By default, objects registered with BaseManager.register will have only
139+
# public methods available (excluding __call__). If you know the functions
140+
# you would like to expose, you can do so at register time with the `exposed`
141+
# attribute. Since we don't, we will add a wrapper around the returned AutoProxy
142+
# object to handle __call__ function calls and turn them into
143+
# singletonProxy_call__ calls (which is a wrapper around the underlying
144+
# object's __call__ function)
145+
class _AutoProxyWrapper:
146+
def __init__(self, proxyObject: multiprocessing.managers.BaseProxy):
147+
self._proxyObject = proxyObject
148+
149+
def __call__(self, *args, **kwargs):
150+
return self._proxyObject.singletonProxy_call__(*args, **kwargs)
151+
152+
def __getattr__(self, name):
153+
return getattr(self._proxyObject, name)
154+
155+
130156
class MultiProcessShared(Generic[T]):
131157
"""MultiProcessShared is used to share a single object across processes.
132158
@@ -223,7 +249,8 @@ def acquire(self):
223249
# inputs)
224250
# Caveat: They must always agree, as they will be ignored if the object
225251
# is already constructed.
226-
return self._get_manager().acquire_singleton(self._tag)
252+
singleton = self._get_manager().acquire_singleton(self._tag)
253+
return _AutoProxyWrapper(singleton)
227254

228255
def release(self, obj):
229256
self._manager.release_singleton(self._tag, obj)

sdks/python/apache_beam/utils/multi_process_shared_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@
2323
from apache_beam.utils import multi_process_shared
2424

2525

26+
class CallableCounter(object):
27+
def __init__(self, start=0):
28+
self.running = start
29+
self.lock = threading.Lock()
30+
31+
def __call__(self):
32+
return self.running
33+
34+
def increment(self, value=1):
35+
with self.lock:
36+
self.running += value
37+
return self.running
38+
39+
def error(self, msg):
40+
raise RuntimeError(msg)
41+
42+
2643
class Counter(object):
2744
def __init__(self, start=0):
2845
self.running = start
@@ -45,6 +62,8 @@ class MultiProcessSharedTest(unittest.TestCase):
4562
def setUpClass(cls):
4663
cls.shared = multi_process_shared.MultiProcessShared(
4764
Counter, always_proxy=True).acquire()
65+
cls.sharedCallable = multi_process_shared.MultiProcessShared(
66+
CallableCounter, always_proxy=True).acquire()
4867

4968
def test_call(self):
5069
self.assertEqual(self.shared.get(), 0)
@@ -53,6 +72,13 @@ def test_call(self):
5372
self.assertEqual(self.shared.increment(value=10), 21)
5473
self.assertEqual(self.shared.get(), 21)
5574

75+
def test_call_callable(self):
76+
self.assertEqual(self.sharedCallable(), 0)
77+
self.assertEqual(self.sharedCallable.increment(), 1)
78+
self.assertEqual(self.sharedCallable.increment(10), 11)
79+
self.assertEqual(self.sharedCallable.increment(value=10), 21)
80+
self.assertEqual(self.sharedCallable(), 21)
81+
5682
def test_error(self):
5783
with self.assertRaisesRegex(Exception, 'something bad'):
5884
self.shared.error('something bad')

0 commit comments

Comments
 (0)