Skip to content

Commit

Permalink
Fix bug in pg.object_utils.thread_local_value_scope.
Browse files Browse the repository at this point in the history
This CL deletes the value associated with key when the top level `thread_local_value_scope` exits, if the key was not present previously. This
avoids unexpectedly reuse of the initial value across multiple call to `thread_local_value_scope`.

PiperOrigin-RevId: 561550754
  • Loading branch information
daiyip authored and pyglove authors committed Aug 31, 2023
1 parent bdffd58 commit a560ff7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
10 changes: 7 additions & 3 deletions pyglove/core/object_utils/thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ def thread_local_value_scope(
key: str, value_in_scope: Any, initial_value: Any
) -> Iterator[None]:
"""Context manager to set a thread local state within the scope."""
previous_value = getattr(_thread_local_state, key, initial_value)
has_key = thread_local_has(key)
previous_value = thread_local_get(key, initial_value)
try:
setattr(_thread_local_state, key, value_in_scope)
thread_local_set(key, value_in_scope)
yield
finally:
setattr(_thread_local_state, key, previous_value)
if has_key:
thread_local_set(key, previous_value)
else:
thread_local_del(key)


def thread_local_has(key: str) -> bool:
Expand Down
7 changes: 4 additions & 3 deletions pyglove/core/object_utils/thread_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_set_get_has_delete(self):
thread_local.thread_local_del(k)
self.assertFalse(thread_local.thread_local_has(k))

self.assertIsNone(thread_local.thread_local_get('y', None))
self.assertFalse(thread_local.thread_local_has('y'))
with self.assertRaisesRegex(
ValueError, 'Key .* does not exist in thread-local storage'):
thread_local.thread_local_get('abc')
Expand All @@ -70,16 +70,17 @@ def _fn():
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)

def test_thread_local_value_scope(self):
thread_local.thread_local_set('y', 2)
with thread_local.thread_local_value_scope('y', 1, None):
self.assertEqual(thread_local.thread_local_get('y'), 1)
self.assertIsNone(thread_local.thread_local_get('y'))
self.assertEqual(thread_local.thread_local_get('y'), 2)

# Test thread locality.
def thread_fun(i):
def _fn():
with thread_local.thread_local_value_scope('y', i, None):
self.assertEqual(thread_local.thread_local_get('y'), i)
self.assertIsNone(thread_local.thread_local_get('y'))
self.assertFalse(thread_local.thread_local_has('y'))
return _fn
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)

Expand Down

0 comments on commit a560ff7

Please sign in to comment.