Add hard delete function to multi_process_shared #32238

Merged: 2 commits on Aug 19, 2024
23 changes: 23 additions & 0 deletions sdks/python/apache_beam/utils/multi_process_shared.py
@@ -136,6 +136,12 @@ def release(self, proxy):
del self.obj
self.initialied = False

def unsafe_hard_delete(self):
with self.lock:
if self.initialied:
del self.obj
self.initialied = False


class _SingletonManager:
entries: Dict[Any, Any] = {}
@@ -153,6 +159,9 @@ def acquire_singleton(self, tag):
def release_singleton(self, tag, obj):
return self.entries[tag].release(obj)

def unsafe_hard_delete_singleton(self, tag):
return self.entries[tag].unsafe_hard_delete()


_process_level_singleton_manager = _SingletonManager()

@@ -169,6 +178,9 @@ class _SingletonRegistrar(multiprocessing.managers.BaseManager):
_SingletonRegistrar.register(
'release_singleton',
callable=_process_level_singleton_manager.release_singleton)
_SingletonRegistrar.register(
'unsafe_hard_delete_singleton',
callable=_process_level_singleton_manager.unsafe_hard_delete_singleton)


# By default, objects registered with BaseManager.register will have only
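For context, the registrations above follow the standard multiprocessing.managers.BaseManager pattern: a function is only callable through a manager proxy once it has been registered on the manager class, which is why the PR registers the new unsafe_hard_delete_singleton alongside the existing acquire/release calls. Below is a minimal standalone sketch of that pattern; the _Registry, _RegistryManager, put, and drop names are illustrative and not part of Beam.

import multiprocessing.managers


class _Registry:
  """Server-side state; lives only in the manager's serving process."""
  def __init__(self):
    self._items = {}

  def put(self, key, value):
    self._items[key] = value

  def drop(self, key):
    # Analogous to unsafe_hard_delete_singleton: the server-side object is
    # removed, so any client proxy still pointing at it will start failing.
    self._items.pop(key, None)


_registry = _Registry()


class _RegistryManager(multiprocessing.managers.BaseManager):
  pass


# Each callable a client may invoke has to be registered explicitly, which is
# why the new 'unsafe_hard_delete_singleton' registration above is required.
_RegistryManager.register('put', callable=_registry.put)
_RegistryManager.register('drop', callable=_registry.drop)

if __name__ == '__main__':
  server = _RegistryManager(address=('localhost', 0), authkey=b'demo')
  server.start()

  client = _RegistryManager(address=server.address, authkey=b'demo')
  client.connect()
  client.put('k', 1)   # executed in the server process
  client.drop('k')     # ditto; mirrors the new hard-delete path
  server.shutdown()
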
@@ -294,6 +306,17 @@ def acquire(self):
def release(self, obj):
self._manager.release_singleton(self._tag, obj.get_auto_proxy_object())

def unsafe_hard_delete(self):
"""Force deletes the underlying object

This function should be used with great care since any other references
to this object will now be invalid and may lead to strange errors. Only
call unsafe_hard_delete if either (a) you are sure no other references
to this object exist, or (b) you are ok with all existing references to
this object throwing strange errors when dereferenced.
"""
self._get_manager().unsafe_hard_delete_singleton(self._tag)

def _create_server(self, address_file):
# We need to be able to authenticate with both the manager and the process.
self._serving_manager = _SingletonRegistrar(
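Taken together with the docstring above, the new method would be used roughly as follows. This is an illustrative sketch rather than code from the PR: the Counter class and the 'my_counter' tag are placeholders, and the behavior mirrors what the added tests below exercise.

from apache_beam.utils import multi_process_shared


class Counter:
  """Placeholder shared object used only for illustration."""
  def __init__(self):
    self._value = 0

  def increment(self):
    self._value += 1
    return self._value


shared = multi_process_shared.MultiProcessShared(
    Counter, tag='my_counter', always_proxy=True)
counter = shared.acquire()
counter.increment()

# Force-delete the shared instance. Every other live proxy for this tag
# becomes invalid and will raise on its next call, so only do this when
# that is acceptable (see the docstring above).
shared.unsafe_hard_delete()

# Acquiring the same tag afterwards constructs a fresh object.
fresh = multi_process_shared.MultiProcessShared(
    Counter, tag='my_counter', always_proxy=True).acquire()
assert fresh.increment() == 1

As the docstring and tests indicate, the deletion happens in the serving process, so existing proxies for the tag only fail on their next call rather than at delete time.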
43 changes: 43 additions & 0 deletions sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -167,6 +167,49 @@ def test_release(self):
with self.assertRaisesRegex(Exception, 'released'):
counter1.get()

def test_unsafe_hard_delete(self):
shared1 = multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete', always_proxy=True)
shared2 = multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete', always_proxy=True)

counter1 = shared1.acquire()
counter2 = shared2.acquire()
self.assertEqual(counter1.increment(), 1)
self.assertEqual(counter2.increment(), 2)

multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()

with self.assertRaises(Exception):
counter1.get()
with self.assertRaises(Exception):
counter2.get()

shared3 = multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete', always_proxy=True)

counter3 = shared3.acquire()

self.assertEqual(counter3.increment(), 1)

def test_unsafe_hard_delete_no_op(self):
shared1 = multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True)
shared2 = multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True)

counter1 = shared1.acquire()
counter2 = shared2.acquire()
self.assertEqual(counter1.increment(), 1)
self.assertEqual(counter2.increment(), 2)

multi_process_shared.MultiProcessShared(
Counter, tag='no_tag_to_delete').unsafe_hard_delete()

self.assertEqual(counter1.increment(), 3)
self.assertEqual(counter2.increment(), 4)

def test_release_always_proxy(self):
shared1 = multi_process_shared.MultiProcessShared(
Counter, tag='test_release_always_proxy', always_proxy=True)