Skip to content

Commit

Permalink
Allow multi_process_shared objects to be called (#26202)
Browse files Browse the repository at this point in the history
* Allow multi_process_shared objects to be called

* Allow multi_process_shared objects to be called (fixed, test passing)

* formatting

* Update sdks/python/apache_beam/utils/multi_process_shared.py

Co-authored-by: Anand Inguva <[email protected]>

* Type hint

* Type hint

---------

Co-authored-by: Anand Inguva <[email protected]>
  • Loading branch information
damccorm and AnandInguva authored Apr 24, 2023
1 parent fbc7df4 commit 44a17cb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
31 changes: 29 additions & 2 deletions sdks/python/apache_beam/utils/multi_process_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def __init__(self, entry):
self._SingletonProxy_entry = entry
self._SingletonProxy_valid = True

# Used to make the shared object callable (see _AutoProxyWrapper below)
def singletonProxy_call__(self, *args, **kwargs):
if not self._SingletonProxy_valid:
raise RuntimeError('Entry was released.')
return self._SingletonProxy_entry.obj.__call__(*args, **kwargs)

def _SingletonProxy_release(self):
assert self._SingletonProxy_valid
self._SingletonProxy_valid = False
Expand All @@ -61,7 +67,9 @@ def __getattr__(self, name):

def __dir__(self):
# Needed for multiprocessing.managers's proxying.
return self._SingletonProxy_entry.obj.__dir__()
dir = self._SingletonProxy_entry.obj.__dir__()
dir.append('singletonProxy_call__')
return dir


class _SingletonEntry:
Expand Down Expand Up @@ -127,6 +135,24 @@ class _SingletonRegistrar(multiprocessing.managers.BaseManager):
callable=_process_level_singleton_manager.release_singleton)


# By default, objects registered with BaseManager.register will have only
# public methods available (excluding __call__). If you know the functions
# you would like to expose, you can do so at register time with the `exposed`
# attribute. Since we don't, we will add a wrapper around the returned AutoProxy
# object to handle __call__ function calls and turn them into
# singletonProxy_call__ calls (which is a wrapper around the underlying
# object's __call__ function)
class _AutoProxyWrapper:
def __init__(self, proxyObject: multiprocessing.managers.BaseProxy):
self._proxyObject = proxyObject

def __call__(self, *args, **kwargs):
return self._proxyObject.singletonProxy_call__(*args, **kwargs)

def __getattr__(self, name):
return getattr(self._proxyObject, name)


class MultiProcessShared(Generic[T]):
"""MultiProcessShared is used to share a single object across processes.
Expand Down Expand Up @@ -223,7 +249,8 @@ def acquire(self):
# inputs)
# Caveat: They must always agree, as they will be ignored if the object
# is already constructed.
return self._get_manager().acquire_singleton(self._tag)
singleton = self._get_manager().acquire_singleton(self._tag)
return _AutoProxyWrapper(singleton)

def release(self, obj):
self._manager.release_singleton(self._tag, obj)
Expand Down
26 changes: 26 additions & 0 deletions sdks/python/apache_beam/utils/multi_process_shared_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@
from apache_beam.utils import multi_process_shared


class CallableCounter(object):
def __init__(self, start=0):
self.running = start
self.lock = threading.Lock()

def __call__(self):
return self.running

def increment(self, value=1):
with self.lock:
self.running += value
return self.running

def error(self, msg):
raise RuntimeError(msg)


class Counter(object):
def __init__(self, start=0):
self.running = start
Expand All @@ -45,6 +62,8 @@ class MultiProcessSharedTest(unittest.TestCase):
def setUpClass(cls):
cls.shared = multi_process_shared.MultiProcessShared(
Counter, always_proxy=True).acquire()
cls.sharedCallable = multi_process_shared.MultiProcessShared(
CallableCounter, always_proxy=True).acquire()

def test_call(self):
self.assertEqual(self.shared.get(), 0)
Expand All @@ -53,6 +72,13 @@ def test_call(self):
self.assertEqual(self.shared.increment(value=10), 21)
self.assertEqual(self.shared.get(), 21)

def test_call_callable(self):
self.assertEqual(self.sharedCallable(), 0)
self.assertEqual(self.sharedCallable.increment(), 1)
self.assertEqual(self.sharedCallable.increment(10), 11)
self.assertEqual(self.sharedCallable.increment(value=10), 21)
self.assertEqual(self.sharedCallable(), 21)

def test_error(self):
with self.assertRaisesRegex(Exception, 'something bad'):
self.shared.error('something bad')
Expand Down

0 comments on commit 44a17cb

Please sign in to comment.