-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Code restructure]
symbolic
sub-module.
Break `symbolic.py` into smaller pieces for better readability and extensibility. This CL also reorganizes tests for `symbolic` sub-module. PiperOrigin-RevId: 486216330
- Loading branch information
Showing
31 changed files
with
14,266 additions
and
11,497 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2022 The PyGlove Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Thread-local utilities.""" | ||
|
||
import contextlib | ||
import threading | ||
from typing import Any, Iterator | ||
|
||
from pyglove.core import object_utils | ||
|
||
|
||
_thread_local_state = threading.local() | ||
|
||
|
||
@contextlib.contextmanager | ||
def value_scope( | ||
key: str, | ||
value_in_scope: Any, | ||
initial_value: Any) -> Iterator[None]: | ||
"""Context manager to set a thread local state within the scope.""" | ||
previous_value = getattr(_thread_local_state, key, initial_value) | ||
try: | ||
setattr(_thread_local_state, key, value_in_scope) | ||
yield | ||
finally: | ||
setattr(_thread_local_state, key, previous_value) | ||
|
||
|
||
def set_value(key: str, value: Any) -> None: | ||
"""Sets thread-local value by key.""" | ||
setattr(_thread_local_state, key, value) | ||
|
||
|
||
def get_value( | ||
key: str, | ||
default_value: Any = object_utils.MISSING_VALUE) -> Any: | ||
"""Gets thread-local value.""" | ||
value = getattr(_thread_local_state, key, default_value) | ||
if value == object_utils.MISSING_VALUE: | ||
raise ValueError(f'Key {key!r} does not exist in thread-local storage.') | ||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2022 The PyGlove Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for pyglove.object_utils.thread_local.""" | ||
|
||
import threading | ||
import time | ||
import unittest | ||
|
||
from pyglove.core.object_utils import thread_local | ||
|
||
|
||
class ThreadLocalTest(unittest.TestCase): | ||
"""Tests for `pg.symbolic.thread_local`.""" | ||
|
||
def assert_thread_func(self, funcs, period_in_second=1): | ||
has_errors = [True] * len(funcs) | ||
def repeat_for_period(func, i): | ||
def _fn(): | ||
begin = time.time() | ||
while True: | ||
func() | ||
if time.time() - begin > period_in_second: | ||
break | ||
has_errors[i] = False | ||
return _fn | ||
|
||
threads = [threading.Thread(target=repeat_for_period(f, i)) | ||
for i, f in enumerate(funcs)] | ||
for t in threads: | ||
t.start() | ||
for t in threads: | ||
t.join() | ||
self.assertFalse(any(has_error for has_error in has_errors)) | ||
|
||
def test_set_get(self): | ||
k, v = 'x', 1 | ||
thread_local.set_value(k, v) | ||
self.assertEqual(thread_local.get_value(k), v) | ||
self.assertIsNone(thread_local.get_value('y', None)) | ||
with self.assertRaisesRegex( | ||
ValueError, 'Key .* does not exist in thread-local storage'): | ||
thread_local.get_value('abc') | ||
|
||
# Test thread locality. | ||
def thread_fun(i): | ||
def _fn(): | ||
thread_local.set_value('x', i) | ||
self.assertEqual(thread_local.get_value('x'), i) | ||
return _fn | ||
self.assert_thread_func([thread_fun(i) for i in range(5)], 2) | ||
|
||
def test_value_scope(self): | ||
with thread_local.value_scope('y', 1, None): | ||
self.assertEqual(thread_local.get_value('y'), 1) | ||
self.assertIsNone(thread_local.get_value('y')) | ||
|
||
# Test thread locality. | ||
def thread_fun(i): | ||
def _fn(): | ||
with thread_local.value_scope('y', i, None): | ||
self.assertEqual(thread_local.get_value('y'), i) | ||
self.assertIsNone(thread_local.get_value('y')) | ||
return _fn | ||
self.assert_thread_func([thread_fun(i) for i in range(5)], 2) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Oops, something went wrong.