From f50831a8a9368f320e4093c995020e39310f943c Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Mon, 22 May 2023 13:09:29 -0700 Subject: [PATCH] Revise `pg.Contextual` into `pg.ContextualValue` and introduce functor-based `pg.ContextualGetter`. This CL allows users to create complex contextual value easily on top of functor. Example:: ``` @pg.contextual_getter def static_value(self, name, context, value): return value class A(pg.Object): x: pg.ContextualValue = static_value(value=1) ``` PiperOrigin-RevId: 534158856 --- pyglove/core/__init__.py | 11 ++- pyglove/core/symbolic/__init__.py | 8 +- pyglove/core/symbolic/base.py | 106 +++++++++++++++++++--- pyglove/core/symbolic/base_test.py | 89 ++++++++++++++----- pyglove/core/symbolic/contextual.py | 107 ++++++++++------------- pyglove/core/symbolic/contextual_test.py | 49 ++++------- pyglove/core/symbolic/dict.py | 3 +- pyglove/core/symbolic/dict_test.py | 3 +- pyglove/core/symbolic/functor.py | 6 +- pyglove/core/symbolic/list.py | 3 +- pyglove/core/symbolic/list_test.py | 9 +- pyglove/core/symbolic/object_test.py | 45 ++++++---- 12 files changed, 278 insertions(+), 161 deletions(-) diff --git a/pyglove/core/__init__.py b/pyglove/core/__init__.py index 25213c5..e029bfe 100644 --- a/pyglove/core/__init__.py +++ b/pyglove/core/__init__.py @@ -83,9 +83,6 @@ ClassWrapper = symbolic.ClassWrapper Functor = symbolic.Functor -# Contextual value marker. -Contextual = symbolic.Contextual - # Decorator for declaring symbolic. members for `pg.Object`. members = symbolic.members @@ -117,6 +114,7 @@ # Method for declaring a boilerplated class from a symbolic instance. boilerplate_class = symbolic.boilerplate_class + # # Context manager for swapping wrapped class with their wrappers. # @@ -164,6 +162,13 @@ Insertion = symbolic.Insertion WritePermissionError = symbolic.WritePermissionError +# Contextual value marker. +ContextualValue = symbolic.ContextualValue +ContextualGetter = symbolic.ContextualGetter + +# Decorator for making contextual getters. +contextual_getter = symbolic.contextual_getter + # # Symbols from 'typing.py' diff --git a/pyglove/core/symbolic/__init__.py b/pyglove/core/symbolic/__init__.py index e9d6b23..0db0572 100644 --- a/pyglove/core/symbolic/__init__.py +++ b/pyglove/core/symbolic/__init__.py @@ -60,6 +60,11 @@ from pyglove.core.symbolic.flags import auto_call_functors from pyglove.core.symbolic.flags import should_call_functors_during_init +# Marker for contextual values. +from pyglove.core.symbolic.base import ContextualValue +from pyglove.core.symbolic.contextual import ContextualGetter +from pyglove.core.symbolic.contextual import contextual_getter + # Symbolic types and their definition helpers. from pyglove.core.symbolic.base import Symbolic from pyglove.core.symbolic.list import List @@ -109,9 +114,6 @@ from pyglove.core.symbolic.base import load from pyglove.core.symbolic.base import save -# Marker for contextual values. -from pyglove.core.symbolic.contextual import Contextual - # Interfaces for pure symbolic objects. from pyglove.core.symbolic.pure_symbolic import PureSymbolic from pyglove.core.symbolic.pure_symbolic import NonDeterministic diff --git a/pyglove/core/symbolic/base.py b/pyglove/core/symbolic/base.py index c225f70..4b1c7f6 100644 --- a/pyglove/core/symbolic/base.py +++ b/pyglove/core/symbolic/base.py @@ -26,7 +26,6 @@ from pyglove.core import object_utils from pyglove.core import typing as pg_typing from pyglove.core.symbolic import flags -from pyglove.core.symbolic.contextual import Contextual from pyglove.core.symbolic.origin import Origin from pyglove.core.symbolic.pure_symbolic import NonDeterministic from pyglove.core.symbolic.pure_symbolic import PureSymbolic @@ -114,6 +113,92 @@ class DescendantQueryOption(enum.Enum): LEAF = 2 +class ContextualValue( + pg_typing.CustomTyping, + object_utils.JSONConvertible, + object_utils.Formattable, +): + """Base class for contextual value markers. + + Contextual value markers allows a symbolic attribute to be late bound + based on the entire symbolic tree where current symbolic value is part of. + As a result, users could access the late bound values directly through + symbolic attributes. + + For example:: + + class A(pg.Object): + x: int + y: int = pg.ContextualValue() + + # Not okay: `x` is not contextual and is not specified. + A() + + # Okay: both `x` and `y` are specified. + A(x=1, y=2) + + # Okay: `y` is contextual, hence optional. + a = A(x=1) + + # Raises: `y` is neither specified during __init__ + # nor provided from the context. + a.y + + d = pg.Dict(y=2, z=pg.Dict(a=a)) + + # `a.y` now refers to `d.a` since `d` is in its symbolic parent chain, + # aka. context. + assert a.y == 2 + """ + + def get(self, name: str, context: 'Symbolic') -> Any: + """Try get the contextual value for a symbolic attribute from a parent. + + Args: + name: The name of the request symbolic attribute. + context: A symbolic parent which represents the current context. + + Returns: + The value for the requested symbolic attribute from the current context. + If ``pg.MISSING_VALUE``, it means that current symbolic parent cannot + provide a contextual value for the attribute, so the current context + is moved upward, until it reaches to the root of the symbolic tree. + If a ``pg.ContextualValue`` object, it will use the new contextual + marker returned to resolve its value from its symbolic parent chains. + """ + return self.value_from(name, context) + + def value_from(self, name: str, context: 'Symbolic') -> Any: + """Try get the contextual value for a symbolic attribute from a parent.""" + return getattr(context, name, pg_typing.MISSING_VALUE) + + def custom_apply(self, *args, **kwargs: Any) -> Tuple[bool, Any]: + # This is to make a ``ContextualValue`` object assignable + # to any symbolic attribute. + return (False, self) + + def format( + self, + compact: bool = False, + verbose: bool = True, + root_indent: int = 0, + **kwargs: Any, + ) -> str: + del compact, verbose, root_indent, kwargs + return 'ContextualValue()' + + def to_json(self, **kwargs: Any) -> Dict[str, Any]: + return self.to_json_dict({}) + + def __eq__(self, other: Any) -> bool: + # NOTE(daiyip): We do strict type match here since subclasses might + # have their own __eq__ logic. + return type(other) is ContextualValue # pylint: disable=unidiomatic-typecheck + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + class Symbolic(object_utils.JSONConvertible, object_utils.MaybePartial, object_utils.Formattable): @@ -400,7 +485,7 @@ def sym_getattr( def sym_contextual_hasattr( self, key: Union[str, int], - getter: Optional[Contextual] = None, + getter: Optional[ContextualValue] = None, start: Union[ 'Symbolic', object_utils.MissingValue ] = pg_typing.MISSING_VALUE, @@ -409,7 +494,7 @@ def sym_contextual_hasattr( Args: key: Key of symbolic attribute. - getter: An optional ``Contextual`` object as the value retriever. + getter: An optional ``ContextualValue`` object as the value retriever. start: An object from current object to the root of the composition as the starting point of context lookup (upward). If ``pg.MISSING_VALUE``, it will start with current node. @@ -417,6 +502,7 @@ def sym_contextual_hasattr( Returns: True if the attribute exists. Otherwise False. """ + getter = getter or ContextualValue() v = self.sym_contextual_getattr( key, default=(pg_typing.MISSING_VALUE,), getter=getter, start=start ) @@ -426,7 +512,7 @@ def sym_contextual_getattr( self, key: Union[str, int], default: Any = object_utils.MISSING_VALUE, - getter: Optional[Contextual] = None, + getter: Optional[ContextualValue] = None, start: Union[ 'Symbolic', object_utils.MissingValue ] = pg_typing.MISSING_VALUE, @@ -437,7 +523,7 @@ def sym_contextual_getattr( key: Key of symbolic attribute. default: Default value if attribute does not exist. If absent, `AttributeError` will be thrown. - getter: An optional ``Contextual`` object as the value retriever. + getter: An optional ``ContextualValue`` object as the value retriever. start: An object from current object to the root of the composition as the starting point of context lookup (upward). If ``pg.MISSING_VALUE``, it will start with current node. @@ -450,18 +536,18 @@ def sym_contextual_getattr( AttributeError if `key` does not exist along the parent chain and default value is not ``pg.MISSING_VALUE``. """ - getter = getter or Contextual() + getter = getter or ContextualValue() if start == pg_typing.MISSING_VALUE: current = self else: current = typing.cast(Symbolic, start) while current is not None: - v = getter.value_from(key, current) - # NOTE(daiyip): when the contextual value from the parent returns - # another contextual object, we should follow the new return value's + v = getter.get(key, current) + # NOTE(daiyip): when the ContextualValue value from the parent returns + # another ContextualValue object, we should follow the new return value's # instruction instead of the original one. - if isinstance(v, Contextual): + if isinstance(v, ContextualValue): getter = v elif v != object_utils.MISSING_VALUE: return v diff --git a/pyglove/core/symbolic/base_test.py b/pyglove/core/symbolic/base_test.py index 678e230..359a815 100644 --- a/pyglove/core/symbolic/base_test.py +++ b/pyglove/core/symbolic/base_test.py @@ -14,11 +14,12 @@ """Tests for pyglove.symbolic.base.""" import copy +import dataclasses import unittest from pyglove.core import object_utils from pyglove.core import typing as pg_typing -from pyglove.core.symbolic.base import FieldUpdate +from pyglove.core.symbolic import base from pyglove.core.symbolic.dict import Dict @@ -28,7 +29,7 @@ class FieldUpdateTest(unittest.TestCase): def test_basics(self): x = Dict(x=1) f = pg_typing.Field('x', pg_typing.Int()) - update = FieldUpdate(object_utils.KeyPath('x'), x, f, 1, 2) + update = base.FieldUpdate(object_utils.KeyPath('x'), x, f, 1, 2) self.assertEqual(update.path, 'x') self.assertIs(update.target, x) self.assertIs(update.field, f) @@ -37,44 +38,90 @@ def test_basics(self): def test_format(self): self.assertEqual( - FieldUpdate( - object_utils.KeyPath('x'), - Dict(x=1), None, 1, 2).format(compact=True), - 'FieldUpdate(parent_path=, path=x, old_value=1, new_value=2)') + base.FieldUpdate( + object_utils.KeyPath('x'), Dict(x=1), None, 1, 2 + ).format(compact=True), + 'FieldUpdate(parent_path=, path=x, old_value=1, new_value=2)', + ) self.assertEqual( - FieldUpdate( - object_utils.KeyPath('a'), - Dict(x=Dict(a=1)).x, None, 1, 2).format(compact=True), - 'FieldUpdate(parent_path=x, path=a, old_value=1, new_value=2)') + base.FieldUpdate( + object_utils.KeyPath('a'), Dict(x=Dict(a=1)).x, None, 1, 2 + ).format(compact=True), + 'FieldUpdate(parent_path=x, path=a, old_value=1, new_value=2)', + ) def test_eq_ne(self): x = Dict() f = pg_typing.Field('x', pg_typing.Int()) self.assertEqual( - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2)) + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + ) # Targets are not the same instance. self.assertNotEqual( - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - FieldUpdate(object_utils.KeyPath('a'), Dict(), f, 1, 2)) + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(object_utils.KeyPath('a'), Dict(), f, 1, 2), + ) # Fields are not the same instance. self.assertNotEqual( - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - FieldUpdate(object_utils.KeyPath('b'), x, copy.copy(f), 1, 2)) + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(object_utils.KeyPath('b'), x, copy.copy(f), 1, 2), + ) self.assertNotEqual( - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - FieldUpdate(object_utils.KeyPath('a'), x, f, 0, 2)) + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 0, 2), + ) self.assertNotEqual( - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 1)) + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 1), + ) self.assertNotEqual( - FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), Dict()) + base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), Dict() + ) + + +class ContextualValueTest(unittest.TestCase): + """Tests for `pg.symbolic.ContextualValue`.""" + + def test_str(self): + self.assertEqual(str(base.ContextualValue()), 'ContextualValue()') + + def test_repr(self): + self.assertEqual(repr(base.ContextualValue()), 'ContextualValue()') + + def test_eq(self): + self.assertEqual(base.ContextualValue(), base.ContextualValue()) + self.assertNotEqual(base.ContextualValue(), 1) + + def test_call(self): + @dataclasses.dataclass + class A: + x: int = 1 + y: int = 2 + + self.assertEqual(base.ContextualValue().get('x', A()), 1) + + def test_custom_typing(self): + v = base.ContextualValue() + self.assertIs(pg_typing.Int().apply(v), v) + self.assertIs(pg_typing.Str().apply(v), v) + + def test_to_json(self): + self.assertEqual( + base.to_json(base.ContextualValue()), + {'_type': f'{base.ContextualValue.__module__}.ContextualValue'}, + ) + + def test_from_json(self): + self.assertEqual( + base.from_json(base.ContextualValue().to_json()), base.ContextualValue() + ) if __name__ == '__main__': diff --git a/pyglove/core/symbolic/contextual.py b/pyglove/core/symbolic/contextual.py index 7187fae..f656487 100644 --- a/pyglove/core/symbolic/contextual.py +++ b/pyglove/core/symbolic/contextual.py @@ -11,76 +11,57 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Contextual value marker.""" -from typing import Any, Callable, Optional, Tuple +"""Customizable contextual value markers.""" +from typing import Any, List, Optional, Tuple, Union +from pyglove.core.symbolic import base +from pyglove.core.symbolic import functor import pyglove.core.typing as pg_typing -# The default contextual getter. -_DEFAULT_GETTER = lambda name, x: getattr(x, name, pg_typing.MISSING_VALUE) +class ContextualGetter(functor.Functor, base.ContextualValue): + """Base for functor-based contextual getter.""" -class Contextual(pg_typing.CustomTyping): - """Marker for values to be read from current field's symbolic parents. + def value_from(self, name: str, context: base.Symbolic) -> Any: + return self.__call__(name, context) - Example:: - class A(pg.Object): - x: int - y: int = pg.Contextual() - - # Not okay: `x` is not contextual and is not specified. - A() +def contextual_getter( + args: Optional[ + List[ + Union[ + Tuple[Union[str, pg_typing.KeySpec], pg_typing.ValueSpec, str], + Tuple[ + Union[str, pg_typing.KeySpec], pg_typing.ValueSpec, str, Any + ], + ] + ] + ] = None, # pylint: disable=bad-continuation + returns: Optional[pg_typing.ValueSpec] = None, + **kwargs, +): + """Decorator that makes ContextualGetter class from function. - # Okay: both `x` and `y` are specified. - A(x=1, y=2) + Examples:: - # Okay: `y` is contextual, hence optional. - a = A(x=1) + @pg.contextual_getter + def static_value(self, name, context, value): + return value - # Raises: `y` is neither specified during __init__ - # nor provided from the context. - a.y + class A(pg.Object): + x: pg.ContextualValue = static_value(value=1) + + Args: + args: A list of tuples that defines the schema for function arguments. + Please see `functor_class` for detailed explanation of `args`. If None, it + will be inferenced from the function argument annotations. + returns: Optional value spec for return value. If None, it will be inferred + from the function return value annotation. + **kwargs: Additional keyword argments for controlling the behavior of + functor creation. Please refer to :func:`pg.symbolic.functor_class` for + more details. + + Returns: + A function that converts a regular function into a ``pg.ContextualGetter`` + subclass. """ - - def __init__(self, getter: Optional[Callable[[str, Any], Any]] = None): - """Constructor. - - Args: - getter: An optional callable object to get the value of the request - attribute name from a symbolic parent, with signature: (attribute_name, - symbolic_parent) -> attribute_value If the getter returns - ``pg.MISSING_VALUE` or a ``pg.Contextual`` object, the context will be - moved unto the parent's parent. If None, the getter will be quering the - attribute of the same name from the the parent. - """ - super().__init__() - self._getter = getter or _DEFAULT_GETTER - - def custom_apply(self, *args, **kwargs) -> Tuple[bool, Any]: - # This is to make a ``Contextual`` object assignable - # to any symbolic attribute. - return (False, self) - - def value_from(self, name: str, parent) -> Any: - """Returns the contextual attribute value from the parent object. - - Args: - name: The name of request attribute. - parent: Current context (symbolic parent). - - Returns: - The value for the contextual attribute. - """ - return self._getter(name, parent) - - def __repr__(self): - return str(self) - - def __str__(self): - return 'CONTEXTUAL' - - def __eq__(self, other): - return isinstance(other, Contextual) and self._getter == other._getter - - def __ne__(self, other): - return not self.__eq__(other) + return functor.functor(args, returns, base_class=ContextualGetter, **kwargs) diff --git a/pyglove/core/symbolic/contextual_test.py b/pyglove/core/symbolic/contextual_test.py index b9c66b7..f44cdb3 100644 --- a/pyglove/core/symbolic/contextual_test.py +++ b/pyglove/core/symbolic/contextual_test.py @@ -11,47 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.symbolic.Contextual.""" +"""Tests for pyglove.symbolic.ContextualGetter.""" -import dataclasses import unittest -from pyglove.core import typing as pg_typing -from pyglove.core.symbolic.contextual import Contextual +from pyglove.core.symbolic import base +from pyglove.core.symbolic import contextual +from pyglove.core.symbolic.dict import Dict -class ContextualTest(unittest.TestCase): - """Tests for `pg.symbolic.Contextual`.""" +class ContextualGetterTest(unittest.TestCase): + """Tests for `pg.symbolic.ContextualGetter`.""" - def test_str(self): - self.assertEqual(str(Contextual()), 'CONTEXTUAL') - self.assertEqual(str(Contextual(lambda k, p: 1)), 'CONTEXTUAL') + def test_basics(self): + @contextual.contextual_getter + def static_value(k, p, v): + del k, p + return v - def test_repr(self): - self.assertEqual(repr(Contextual()), 'CONTEXTUAL') - self.assertEqual(repr(Contextual(lambda k, p: 1)), 'CONTEXTUAL') - - def test_eq(self): - self.assertEqual(Contextual(), Contextual()) - getter = lambda k, p: 1 - self.assertEqual(Contextual(getter), Contextual(getter)) - - self.assertNotEqual(Contextual(), 1) - self.assertNotEqual(Contextual(getter), Contextual()) - - def test_value_from(self): - @dataclasses.dataclass - class A: - x: int = 1 - y: int = 2 - - self.assertEqual(Contextual().value_from('x', A()), 1) - self.assertEqual(Contextual(lambda k, p: p.y).value_from('x', A()), 2) - - def test_custom_typing(self): - v = Contextual() - self.assertIs(pg_typing.Int().apply(v), v) - self.assertIs(pg_typing.Str().apply(v), v) + getter = static_value(v=1) # pylint: disable=no-value-for-parameter + self.assertIsInstance(getter, base.ContextualValue) + self.assertEqual(getter.get('x', Dict()), 1) + self.assertEqual(base.from_json(base.to_json(getter)), getter) if __name__ == '__main__': diff --git a/pyglove/core/symbolic/dict.py b/pyglove/core/symbolic/dict.py index 51fb816..e287706 100644 --- a/pyglove/core/symbolic/dict.py +++ b/pyglove/core/symbolic/dict.py @@ -19,7 +19,6 @@ from pyglove.core import object_utils from pyglove.core import typing as pg_typing from pyglove.core.symbolic import base -from pyglove.core.symbolic import contextual from pyglove.core.symbolic import flags @@ -579,7 +578,7 @@ def _on_change(self, field_updates: typing.Dict[object_utils.KeyPath, def __getitem__(self, key: str) -> Any: """Get item in this Dict.""" v = super().__getitem__(key) - if isinstance(v, contextual.Contextual): + if isinstance(v, base.ContextualValue): start = self.sym_parent # NOTE(daiyip): The parent of `pg.Object`'s attribute dict points to # the `pg.Object` instance once it's set up. Here we let the ancester diff --git a/pyglove/core/symbolic/dict_test.py b/pyglove/core/symbolic/dict_test.py index 5be225f..38d7dc0 100644 --- a/pyglove/core/symbolic/dict_test.py +++ b/pyglove/core/symbolic/dict_test.py @@ -23,7 +23,6 @@ from pyglove.core.symbolic import base from pyglove.core.symbolic import flags from pyglove.core.symbolic import object as pg_object -from pyglove.core.symbolic.contextual import Contextual from pyglove.core.symbolic.dict import Dict from pyglove.core.symbolic.list import List from pyglove.core.symbolic.pure_symbolic import NonDeterministic @@ -1353,7 +1352,7 @@ def test_seal(self): self.assertTrue(sd.a.is_sealed) def test_contextual(self): - sd = Dict(x=Contextual()) + sd = Dict(x=base.ContextualValue()) with self.assertRaisesRegex( AttributeError, '`x` is not found under its context' ): diff --git a/pyglove/core/symbolic/functor.py b/pyglove/core/symbolic/functor.py index 96b3fc4..7aca5a2 100644 --- a/pyglove/core/symbolic/functor.py +++ b/pyglove/core/symbolic/functor.py @@ -475,10 +475,12 @@ def bar(a, b, c, *args, **kwargs): """ if inspect.isfunction(args): assert returns is None - assert base_class is None return functor_class( typing.cast(Callable[..., Any], args), - add_to_registry=True, **kwargs) + base_class=base_class, + add_to_registry=True, + **kwargs, + ) return lambda fn: functor_class( # pylint: disable=g-long-lambda # pytype: disable=wrong-arg-types fn, args, returns, base_class=base_class, diff --git a/pyglove/core/symbolic/list.py b/pyglove/core/symbolic/list.py index 599ed58..772965d 100644 --- a/pyglove/core/symbolic/list.py +++ b/pyglove/core/symbolic/list.py @@ -20,7 +20,6 @@ from pyglove.core import object_utils from pyglove.core import typing as pg_typing from pyglove.core.symbolic import base -from pyglove.core.symbolic import contextual from pyglove.core.symbolic import flags @@ -464,7 +463,7 @@ def _on_change(self, def __getitem__(self, index) -> Any: """Gets the item at a given position.""" v = super().__getitem__(index) - if isinstance(v, contextual.Contextual): + if isinstance(v, base.ContextualValue): v = self.sym_contextual_getattr(index, getter=v, start=self.sym_parent) return v diff --git a/pyglove/core/symbolic/list_test.py b/pyglove/core/symbolic/list_test.py index 2cf06b5..30a7d1a 100644 --- a/pyglove/core/symbolic/list_test.py +++ b/pyglove/core/symbolic/list_test.py @@ -21,9 +21,9 @@ from pyglove.core import object_utils from pyglove.core import typing as pg_typing from pyglove.core.symbolic import base +from pyglove.core.symbolic import contextual from pyglove.core.symbolic import flags from pyglove.core.symbolic import object as pg_object -from pyglove.core.symbolic.contextual import Contextual from pyglove.core.symbolic.dict import Dict from pyglove.core.symbolic.list import Insertion from pyglove.core.symbolic.list import List @@ -1156,7 +1156,12 @@ def test_seal(self): def test_contextual(self): # Test contextual access for schemaless list. # Okay: sl[1] is contextual. - sl = List([0, Contextual(lambda i, x: x.a)]) + @contextual.contextual_getter + def redirectd_value(k, p, key): + del k + return getattr(p, key) + + sl = List([0, redirectd_value(key='a')]) # pylint: disable=no-value-for-parameter self.assertEqual(sl[0], 0) with self.assertRaisesRegex( diff --git a/pyglove/core/symbolic/object_test.py b/pyglove/core/symbolic/object_test.py index 507fa57..6be4e9b 100644 --- a/pyglove/core/symbolic/object_test.py +++ b/pyglove/core/symbolic/object_test.py @@ -24,10 +24,10 @@ from pyglove.core import object_utils from pyglove.core import typing as pg_typing from pyglove.core.symbolic import base +from pyglove.core.symbolic import contextual from pyglove.core.symbolic import flags from pyglove.core.symbolic.base import query as pg_query from pyglove.core.symbolic.base import traverse as pg_traverse -from pyglove.core.symbolic.contextual import Contextual from pyglove.core.symbolic.dict import Dict from pyglove.core.symbolic.functor import functor as pg_functor from pyglove.core.symbolic.list import List @@ -720,7 +720,7 @@ def test_sym_contextual_hasattr(self): class A(Object): x: int y: int = 1 - z: int = Contextual() + z: int = base.ContextualValue() a = A(0) _ = Dict(p=Dict(a=a, b=3), z=2) @@ -734,7 +734,12 @@ class A(Object): self.assertTrue(a.sym_contextual_hasattr('z', start=a.sym_parent)) # Custom getter. - getter = Contextual(lambda k, p: getattr(p, 'b')) + @contextual.contextual_getter + def redirected_value(k, p, key): + del k + return getattr(p, key) + + getter = redirected_value(key='b') # pylint: disable=no-value-for-parameter self.assertTrue(a.sym_contextual_hasattr('x', getter, start=a.sym_parent)) self.assertTrue(a.sym_contextual_hasattr('y', getter, start=a.sym_parent)) self.assertTrue(a.sym_contextual_hasattr('z', getter, start=a.sym_parent)) @@ -743,7 +748,7 @@ def test_sym_contextual_getattr(self): class A(Object): x: int y: int = 1 - z: int = Contextual() + z: int = base.ContextualValue() a = A(0) @@ -771,7 +776,12 @@ class A(Object): self.assertEqual(a.sym_contextual_getattr('z', start=a.sym_parent), 2) # Custom getter. - getter = Contextual(lambda k, p: getattr(p, 'b')) + @contextual.contextual_getter + def redirected_value(k, p, key): + del k + return getattr(p, key) + + getter = redirected_value(key='b') # pylint: disable=no-value-for-parameter self.assertEqual( a.sym_contextual_getattr('x', getter=getter, start=a.sym_parent), 3 ) @@ -1854,7 +1864,7 @@ class E(B): # pylint: disable=unused-variable def test_contextual(self): class A(Object): x: int - y: str = Contextual() + y: str = base.ContextualValue() # Okay: `A.y` is contextual. a = A(1) @@ -1876,24 +1886,25 @@ class A(Object): _ = a.y # Test parent contextual value with custom getter. - sd = Dict( - a='bar', - b=Dict( - x=a, - y=Contextual(lambda k, p: getattr(p, 'a')))) + @contextual.contextual_getter + def redirected_value(k, p, key): + del k + return getattr(p, key) + + sd = Dict(a='bar', b=Dict(x=a, y=redirected_value(key='a'))) # pylint: disable=no-value-for-parameter # a.y is redirected to sd.a. self.assertEqual(a.y, 'bar') + @contextual.contextual_getter + def immediate_attr(k, p): + return p.sym_getattr(k) + class B(Object): - x: int = Contextual(lambda k, p: p.sym_getattr(k)) + x: int = immediate_attr() # pylint: disable=no-value-for-parameter b = B() - sd = Dict( - a='bar', - b=Dict( - b=b, - x=Contextual(lambda k, p: getattr(p, 'a')))) + sd = Dict(a='bar', b=Dict(b=b, x=redirected_value(key='a'))) # pylint: disable=no-value-for-parameter # a.y is redirected to sd.a. self.assertEqual(b.x, 'bar')