Skip to content

Commit

Permalink
Expose pg.object_utils.thread_local_xxx as public APIs.
Browse files Browse the repository at this point in the history
This is because many PyGlove downstream projects require to manage thread-local values in similar manner. And PyGlove's implementation is pretty general.

PiperOrigin-RevId: 539101017
  • Loading branch information
daiyip authored and pyglove authors committed Jun 9, 2023
1 parent c4ad652 commit 4ee6572
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 45 deletions.
5 changes: 2 additions & 3 deletions pyglove/core/hyper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from pyglove.core import object_utils
from pyglove.core import symbolic
from pyglove.core import typing as pg_typing
from pyglove.core.object_utils import thread_local


class HyperValue(symbolic.NonDeterministic): # pytype: disable=ignored-metaclass
Expand Down Expand Up @@ -187,12 +186,12 @@ def set_dynamic_evaluate_fn(
global _global_dynamic_evaluate_fn
if per_thread:
assert _global_dynamic_evaluate_fn is None, _global_dynamic_evaluate_fn
thread_local.set_value(_TLS_KEY_DYNAMIC_EVALUATE_FN, fn)
object_utils.thread_local_set_value(_TLS_KEY_DYNAMIC_EVALUATE_FN, fn)
else:
_global_dynamic_evaluate_fn = fn


def get_dynamic_evaluate_fn() -> Optional[Callable[[HyperValue], Any]]:
"""Gets current dynamic evaluate function."""
return thread_local.get_value(
return object_utils.thread_local_get_value(
_TLS_KEY_DYNAMIC_EVALUATE_FN, _global_dynamic_evaluate_fn)
6 changes: 3 additions & 3 deletions pyglove/core/hyper/dynamic_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

from pyglove.core import geno
from pyglove.core import object_utils
from pyglove.core import symbolic
from pyglove.core import typing as pg_typing
from pyglove.core.hyper import base
from pyglove.core.hyper import categorical
from pyglove.core.hyper import custom
from pyglove.core.hyper import numerical
from pyglove.core.hyper import object_template
from pyglove.core.object_utils import thread_local


@contextlib.contextmanager
Expand Down Expand Up @@ -520,10 +520,10 @@ def ensure_thread_safety(self, context: DynamicEvaluationContext):
@property
def _local_stack(self):
"""Returns thread-local stack."""
stack = thread_local.get_value(self._TLS_KEY, None)
stack = object_utils.thread_local_get_value(self._TLS_KEY, None)
if stack is None:
stack = []
thread_local.set_value(self._TLS_KEY, stack)
object_utils.thread_local_set_value(self._TLS_KEY, stack)
return stack

def push(self, context: DynamicEvaluationContext):
Expand Down
5 changes: 5 additions & 0 deletions pyglove/core/object_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@
# Handling code generation.
from pyglove.core.object_utils.codegen import make_function

# Handling thread local values.
from pyglove.core.object_utils.thread_local import thread_local_value_scope
from pyglove.core.object_utils.thread_local import thread_local_set_value
from pyglove.core.object_utils.thread_local import thread_local_get_value

# Handling docstrings.
from pyglove.core.object_utils.docstr_utils import DocStr
from pyglove.core.object_utils.docstr_utils import DocStrStyle
Expand Down
17 changes: 7 additions & 10 deletions pyglove/core/object_utils/thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@
import threading
from typing import Any, Iterator

from pyglove.core import object_utils
from pyglove.core.object_utils.missing import MISSING_VALUE


_thread_local_state = threading.local()


@contextlib.contextmanager
def value_scope(
key: str,
value_in_scope: Any,
initial_value: Any) -> Iterator[None]:
def thread_local_value_scope(
key: str, value_in_scope: Any, initial_value: Any
) -> Iterator[None]:
"""Context manager to set a thread local state within the scope."""
previous_value = getattr(_thread_local_state, key, initial_value)
try:
Expand All @@ -37,16 +36,14 @@ def value_scope(
setattr(_thread_local_state, key, previous_value)


def set_value(key: str, value: Any) -> None:
def thread_local_set_value(key: str, value: Any) -> None:
"""Sets thread-local value by key."""
setattr(_thread_local_state, key, value)


def get_value(
key: str,
default_value: Any = object_utils.MISSING_VALUE) -> Any:
def thread_local_get_value(key: str, default_value: Any = MISSING_VALUE) -> Any:
"""Gets thread-local value."""
value = getattr(_thread_local_state, key, default_value)
if value == object_utils.MISSING_VALUE:
if value == MISSING_VALUE:
raise ValueError(f'Key {key!r} does not exist in thread-local storage.')
return value
26 changes: 13 additions & 13 deletions pyglove/core/object_utils/thread_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,32 @@ def _fn():

def test_set_get(self):
k, v = 'x', 1
thread_local.set_value(k, v)
self.assertEqual(thread_local.get_value(k), v)
self.assertIsNone(thread_local.get_value('y', None))
thread_local.thread_local_set_value(k, v)
self.assertEqual(thread_local.thread_local_get_value(k), v)
self.assertIsNone(thread_local.thread_local_get_value('y', None))
with self.assertRaisesRegex(
ValueError, 'Key .* does not exist in thread-local storage'):
thread_local.get_value('abc')
thread_local.thread_local_get_value('abc')

# Test thread locality.
def thread_fun(i):
def _fn():
thread_local.set_value('x', i)
self.assertEqual(thread_local.get_value('x'), i)
thread_local.thread_local_set_value('x', i)
self.assertEqual(thread_local.thread_local_get_value('x'), i)
return _fn
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)

def test_value_scope(self):
with thread_local.value_scope('y', 1, None):
self.assertEqual(thread_local.get_value('y'), 1)
self.assertIsNone(thread_local.get_value('y'))
def test_thread_local_value_scope(self):
with thread_local.thread_local_value_scope('y', 1, None):
self.assertEqual(thread_local.thread_local_get_value('y'), 1)
self.assertIsNone(thread_local.thread_local_get_value('y'))

# Test thread locality.
def thread_fun(i):
def _fn():
with thread_local.value_scope('y', i, None):
self.assertEqual(thread_local.get_value('y'), i)
self.assertIsNone(thread_local.get_value('y'))
with thread_local.thread_local_value_scope('y', i, None):
self.assertEqual(thread_local.thread_local_get_value('y'), i)
self.assertIsNone(thread_local.thread_local_get_value('y'))
return _fn
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)

Expand Down
42 changes: 26 additions & 16 deletions pyglove/core/symbolic/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,16 @@ def notify_on_change(enabled: bool = True) -> ContextManager[None]:
Returns:
A context manager for allowing/disallowing change notification in scope.
"""
return thread_local.value_scope(
_TLS_ENABLE_CHANGE_NOTIFICATION, enabled, True)
return thread_local.thread_local_value_scope(
_TLS_ENABLE_CHANGE_NOTIFICATION, enabled, True
)


def is_change_notification_enabled() -> bool:
"""Returns True if change notification is enabled."""
return thread_local.get_value(_TLS_ENABLE_CHANGE_NOTIFICATION, True)
return thread_local.thread_local_get_value(
_TLS_ENABLE_CHANGE_NOTIFICATION, True
)


def track_origin(enabled: bool = True) -> ContextManager[None]:
Expand All @@ -175,13 +178,14 @@ def track_origin(enabled: bool = True) -> ContextManager[None]:
Returns:
A context manager for enable or disable origin tracking.
"""
return thread_local.value_scope(
_TLS_ENABLE_ORIGIN_TRACKING, enabled, False)
return thread_local.thread_local_value_scope(
_TLS_ENABLE_ORIGIN_TRACKING, enabled, False
)


def is_tracking_origin() -> bool:
"""Returns if origin of symbolic object are being tracked."""
return thread_local.get_value(_TLS_ENABLE_ORIGIN_TRACKING, False)
return thread_local.thread_local_get_value(_TLS_ENABLE_ORIGIN_TRACKING, False)


def enable_type_check(enabled: bool = True) -> ContextManager[None]:
Expand All @@ -203,12 +207,14 @@ def enable_type_check(enabled: bool = True) -> ContextManager[None]:
Returns:
A context manager for allowing/disallowing runtime type check.
"""
return thread_local.value_scope(_TLS_ENABLE_TYPE_CHECK, enabled, True)
return thread_local.thread_local_value_scope(
_TLS_ENABLE_TYPE_CHECK, enabled, True
)


def is_type_check_enabled() -> bool:
"""Returns True if runtme type check is enabled."""
return thread_local.get_value(_TLS_ENABLE_TYPE_CHECK, True)
return thread_local.thread_local_get_value(_TLS_ENABLE_TYPE_CHECK, True)


def allow_writable_accessors(
Expand Down Expand Up @@ -244,12 +250,14 @@ def allow_writable_accessors(
symbolic values in scope. After leaving the scope, the
`accessor_writable` flag of individual objects will remain intact.
"""
return thread_local.value_scope(_TLS_ACCESSOR_WRITABLE, writable, None)
return thread_local.thread_local_value_scope(
_TLS_ACCESSOR_WRITABLE, writable, None
)


def is_under_accessor_writable_scope() -> Optional[bool]:
"""Return True if symbolic values are treated as sealed in current context."""
return thread_local.get_value(_TLS_ACCESSOR_WRITABLE, None)
return thread_local.thread_local_get_value(_TLS_ACCESSOR_WRITABLE, None)


def as_sealed(sealed: Optional[bool] = True) -> ContextManager[None]:
Expand Down Expand Up @@ -287,12 +295,12 @@ def as_sealed(sealed: Optional[bool] = True) -> ContextManager[None]:
in scope. After leaving the scope, the sealed state of individual objects
will remain intact.
"""
return thread_local.value_scope(_TLS_SEALED, sealed, None)
return thread_local.thread_local_value_scope(_TLS_SEALED, sealed, None)


def is_under_sealed_scope() -> Optional[bool]:
"""Return True if symbolic values are treated as sealed in current context."""
return thread_local.get_value(_TLS_SEALED, None)
return thread_local.thread_local_get_value(_TLS_SEALED, None)


def allow_partial(allow: Optional[bool] = True) -> ContextManager[None]:
Expand Down Expand Up @@ -327,12 +335,12 @@ class A(pg.Object):
After leaving the scope, the `allow_partial` state of individual objects
will remain intact.
"""
return thread_local.value_scope(_TLS_ALLOW_PARTIAL, allow, None)
return thread_local.thread_local_value_scope(_TLS_ALLOW_PARTIAL, allow, None)


def is_under_partial_scope() -> Optional[bool]:
"""Return True if partial value is allowed in current context."""
return thread_local.get_value(_TLS_ALLOW_PARTIAL, None)
return thread_local.thread_local_get_value(_TLS_ALLOW_PARTIAL, None)


def auto_call_functors(enabled: bool = True) -> ContextManager[None]:
Expand All @@ -358,9 +366,11 @@ def foo(x, y):
Returns:
A context manager for enabling/disabling auto call for functors.
"""
return thread_local.value_scope(_TLS_AUTO_CALL_FUNCTORS, enabled, False)
return thread_local.thread_local_value_scope(
_TLS_AUTO_CALL_FUNCTORS, enabled, False
)


def should_call_functors_during_init() -> Optional[bool]:
"""Return True functors should be automatically called during __init__."""
return thread_local.get_value(_TLS_AUTO_CALL_FUNCTORS, None)
return thread_local.thread_local_get_value(_TLS_AUTO_CALL_FUNCTORS, None)

0 comments on commit 4ee6572

Please sign in to comment.