From a560ff7976648ef73b3996238351ddac6ca3bccc Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Wed, 30 Aug 2023 22:43:32 -0700 Subject: [PATCH] Fix bug in `pg.object_utils.thread_local_value_scope`. This CL deletes the value associated with key when the top level `thread_local_value_scope` exits, if the key was not present previously. This avoids unexpectedly reuse of the initial value across multiple call to `thread_local_value_scope`. PiperOrigin-RevId: 561550754 --- pyglove/core/object_utils/thread_local.py | 10 +++++++--- pyglove/core/object_utils/thread_local_test.py | 7 ++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pyglove/core/object_utils/thread_local.py b/pyglove/core/object_utils/thread_local.py index 76e8b0f..2dc76e6 100644 --- a/pyglove/core/object_utils/thread_local.py +++ b/pyglove/core/object_utils/thread_local.py @@ -28,12 +28,16 @@ def thread_local_value_scope( key: str, value_in_scope: Any, initial_value: Any ) -> Iterator[None]: """Context manager to set a thread local state within the scope.""" - previous_value = getattr(_thread_local_state, key, initial_value) + has_key = thread_local_has(key) + previous_value = thread_local_get(key, initial_value) try: - setattr(_thread_local_state, key, value_in_scope) + thread_local_set(key, value_in_scope) yield finally: - setattr(_thread_local_state, key, previous_value) + if has_key: + thread_local_set(key, previous_value) + else: + thread_local_del(key) def thread_local_has(key: str) -> bool: diff --git a/pyglove/core/object_utils/thread_local_test.py b/pyglove/core/object_utils/thread_local_test.py index 5a434f3..a427c15 100644 --- a/pyglove/core/object_utils/thread_local_test.py +++ b/pyglove/core/object_utils/thread_local_test.py @@ -52,7 +52,7 @@ def test_set_get_has_delete(self): thread_local.thread_local_del(k) self.assertFalse(thread_local.thread_local_has(k)) - self.assertIsNone(thread_local.thread_local_get('y', None)) + self.assertFalse(thread_local.thread_local_has('y')) with self.assertRaisesRegex( ValueError, 'Key .* does not exist in thread-local storage'): thread_local.thread_local_get('abc') @@ -70,16 +70,17 @@ def _fn(): self.assert_thread_func([thread_fun(i) for i in range(5)], 2) def test_thread_local_value_scope(self): + thread_local.thread_local_set('y', 2) with thread_local.thread_local_value_scope('y', 1, None): self.assertEqual(thread_local.thread_local_get('y'), 1) - self.assertIsNone(thread_local.thread_local_get('y')) + self.assertEqual(thread_local.thread_local_get('y'), 2) # Test thread locality. def thread_fun(i): def _fn(): with thread_local.thread_local_value_scope('y', i, None): self.assertEqual(thread_local.thread_local_get('y'), i) - self.assertIsNone(thread_local.thread_local_get('y')) + self.assertFalse(thread_local.thread_local_has('y')) return _fn self.assert_thread_func([thread_fun(i) for i in range(5)], 2)