Skip to content

Commit

Permalink
[Code restructure] symbolic sub-module.
Browse files Browse the repository at this point in the history
Break `symbolic.py` into smaller pieces for better readability and extensibility.
This CL also reorganizes tests for `symbolic` sub-module.

PiperOrigin-RevId: 486216330
  • Loading branch information
daiyip authored and pyglove authors committed Nov 4, 2022
1 parent 2561dc4 commit 67e8426
Show file tree
Hide file tree
Showing 31 changed files with 14,266 additions and 11,497 deletions.
6 changes: 3 additions & 3 deletions pyglove/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@
# Global flags.
allow_empty_field_description = symbolic.allow_empty_field_description
allow_repeated_class_registration = symbolic.allow_repeated_class_registration
set_stacktrace_limit = symbolic.set_stacktrace_limit
set_origin_stacktrace_limit = symbolic.set_origin_stacktrace_limit

# Context manager for scoped flags.
allow_partial = symbolic.allow_partial_values
allow_partial = symbolic.allow_partial
allow_writable_accessors = symbolic.allow_writable_accessors
notify_on_change = symbolic.notify_on_change
enable_type_check = symbolic.enable_type_check
Expand Down Expand Up @@ -103,7 +103,7 @@
ne = symbolic.ne
lt = symbolic.lt
gt = symbolic.gt
hash = symbolic.sym_hash # pylint: disable=redefined-builtin
hash = symbolic.hash # pylint: disable=redefined-builtin
clone = symbolic.clone

# Methods for querying symbolic types.
Expand Down
6 changes: 3 additions & 3 deletions pyglove/core/geno.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class DNASpec(symbolic.Object):
# This is helpful when we want to align decision points using DNASpec as
# dictionary key. Users can use `pg.eq`/`pg.ne` for symbolic comparisons
# and `pg.hash` for symbolic hashing.
allow_symbolic_comparison = False
use_symbolic_comparison = False

def _on_bound(self):
"""Event that is triggered when object is modified."""
Expand Down Expand Up @@ -2051,7 +2051,7 @@ def _on_bound(self):

# Automatically set the candidate index for template.
for i, c in enumerate(self.candidates):
c.rebind(index=i, skip_notification=True)
c.rebind(index=i, skip_notification=True, raise_on_no_change=False)

# Create sub choice specs and index decision points.
if self.num_choices > 1 and not self.is_subchoice:
Expand Down Expand Up @@ -3332,7 +3332,7 @@ def needs_feedback(self) -> bool:

def _setup(self):
self.generator.setup(self.dna_spec)
self._hash_fn = self.hash_fn or symbolic.sym_hash
self._hash_fn = self.hash_fn or symbolic.hash
self._cache = {}
self._enables_auto_reward = (
self.needs_feedback and self.auto_reward_fn is not None)
Expand Down
2 changes: 1 addition & 1 deletion pyglove/core/hyper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def testCustomApply(self):
self.assertIs(schema.Object(hyper.Template).apply(t), t)
self.assertIs(schema.Dict().apply(t), t)
with self.assertRaisesRegex(
ValueError, 'Dict cannot be applied to a different spec'):
ValueError, 'Dict .* cannot be assigned to an incompatible field'):
schema.Int().apply(t)


Expand Down
3 changes: 0 additions & 3 deletions pyglove/core/object_utils/common_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,6 @@ def missing_values(self, flatten: bool = True) -> Dict[str, Any]: # pylint: dis
class Functor(metaclass=abc.ABCMeta):
"""Interface for functor."""

# `schema.Signature` object for this functor class.
signature = None

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> Any:
"""Calls the functor.
Expand Down
52 changes: 52 additions & 0 deletions pyglove/core/object_utils/thread_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thread-local utilities."""

import contextlib
import threading
from typing import Any, Iterator

from pyglove.core import object_utils


_thread_local_state = threading.local()


@contextlib.contextmanager
def value_scope(
key: str,
value_in_scope: Any,
initial_value: Any) -> Iterator[None]:
"""Context manager to set a thread local state within the scope."""
previous_value = getattr(_thread_local_state, key, initial_value)
try:
setattr(_thread_local_state, key, value_in_scope)
yield
finally:
setattr(_thread_local_state, key, previous_value)


def set_value(key: str, value: Any) -> None:
"""Sets thread-local value by key."""
setattr(_thread_local_state, key, value)


def get_value(
key: str,
default_value: Any = object_utils.MISSING_VALUE) -> Any:
"""Gets thread-local value."""
value = getattr(_thread_local_state, key, default_value)
if value == object_utils.MISSING_VALUE:
raise ValueError(f'Key {key!r} does not exist in thread-local storage.')
return value
79 changes: 79 additions & 0 deletions pyglove/core/object_utils/thread_local_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pyglove.object_utils.thread_local."""

import threading
import time
import unittest

from pyglove.core.object_utils import thread_local


class ThreadLocalTest(unittest.TestCase):
"""Tests for `pg.symbolic.thread_local`."""

def assert_thread_func(self, funcs, period_in_second=1):
has_errors = [True] * len(funcs)
def repeat_for_period(func, i):
def _fn():
begin = time.time()
while True:
func()
if time.time() - begin > period_in_second:
break
has_errors[i] = False
return _fn

threads = [threading.Thread(target=repeat_for_period(f, i))
for i, f in enumerate(funcs)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertFalse(any(has_error for has_error in has_errors))

def test_set_get(self):
k, v = 'x', 1
thread_local.set_value(k, v)
self.assertEqual(thread_local.get_value(k), v)
self.assertIsNone(thread_local.get_value('y', None))
with self.assertRaisesRegex(
ValueError, 'Key .* does not exist in thread-local storage'):
thread_local.get_value('abc')

# Test thread locality.
def thread_fun(i):
def _fn():
thread_local.set_value('x', i)
self.assertEqual(thread_local.get_value('x'), i)
return _fn
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)

def test_value_scope(self):
with thread_local.value_scope('y', 1, None):
self.assertEqual(thread_local.get_value('y'), 1)
self.assertIsNone(thread_local.get_value('y'))

# Test thread locality.
def thread_fun(i):
def _fn():
with thread_local.value_scope('y', i, None):
self.assertEqual(thread_local.get_value('y'), i)
self.assertIsNone(thread_local.get_value('y'))
return _fn
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 67e8426

Please sign in to comment.