Skip to content

Commit

Permalink
Revise pg.Contextual into pg.ContextualValue and introduce functo…
Browse files Browse the repository at this point in the history
…r-based `pg.ContextualGetter`.

This CL allows users to create complex contextual value easily on top of functor. Example::

```
@pg.contextual_getter
def static_value(self, name, context, value):
  return value

class A(pg.Object):
  x: pg.ContextualValue = static_value(value=1)
```

PiperOrigin-RevId: 534158856
  • Loading branch information
daiyip authored and pyglove authors committed May 22, 2023
1 parent 1b7386b commit f50831a
Show file tree
Hide file tree
Showing 12 changed files with 278 additions and 161 deletions.
11 changes: 8 additions & 3 deletions pyglove/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@
ClassWrapper = symbolic.ClassWrapper
Functor = symbolic.Functor

# Contextual value marker.
Contextual = symbolic.Contextual

# Decorator for declaring symbolic. members for `pg.Object`.
members = symbolic.members

Expand Down Expand Up @@ -117,6 +114,7 @@
# Method for declaring a boilerplated class from a symbolic instance.
boilerplate_class = symbolic.boilerplate_class


#
# Context manager for swapping wrapped class with their wrappers.
#
Expand Down Expand Up @@ -164,6 +162,13 @@
Insertion = symbolic.Insertion
WritePermissionError = symbolic.WritePermissionError

# Contextual value marker.
ContextualValue = symbolic.ContextualValue
ContextualGetter = symbolic.ContextualGetter

# Decorator for making contextual getters.
contextual_getter = symbolic.contextual_getter


#
# Symbols from 'typing.py'
Expand Down
8 changes: 5 additions & 3 deletions pyglove/core/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@
from pyglove.core.symbolic.flags import auto_call_functors
from pyglove.core.symbolic.flags import should_call_functors_during_init

# Marker for contextual values.
from pyglove.core.symbolic.base import ContextualValue
from pyglove.core.symbolic.contextual import ContextualGetter
from pyglove.core.symbolic.contextual import contextual_getter

# Symbolic types and their definition helpers.
from pyglove.core.symbolic.base import Symbolic
from pyglove.core.symbolic.list import List
Expand Down Expand Up @@ -109,9 +114,6 @@
from pyglove.core.symbolic.base import load
from pyglove.core.symbolic.base import save

# Marker for contextual values.
from pyglove.core.symbolic.contextual import Contextual

# Interfaces for pure symbolic objects.
from pyglove.core.symbolic.pure_symbolic import PureSymbolic
from pyglove.core.symbolic.pure_symbolic import NonDeterministic
Expand Down
106 changes: 96 additions & 10 deletions pyglove/core/symbolic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pyglove.core import object_utils
from pyglove.core import typing as pg_typing
from pyglove.core.symbolic import flags
from pyglove.core.symbolic.contextual import Contextual
from pyglove.core.symbolic.origin import Origin
from pyglove.core.symbolic.pure_symbolic import NonDeterministic
from pyglove.core.symbolic.pure_symbolic import PureSymbolic
Expand Down Expand Up @@ -114,6 +113,92 @@ class DescendantQueryOption(enum.Enum):
LEAF = 2


class ContextualValue(
pg_typing.CustomTyping,
object_utils.JSONConvertible,
object_utils.Formattable,
):
"""Base class for contextual value markers.
Contextual value markers allows a symbolic attribute to be late bound
based on the entire symbolic tree where current symbolic value is part of.
As a result, users could access the late bound values directly through
symbolic attributes.
For example::
class A(pg.Object):
x: int
y: int = pg.ContextualValue()
# Not okay: `x` is not contextual and is not specified.
A()
# Okay: both `x` and `y` are specified.
A(x=1, y=2)
# Okay: `y` is contextual, hence optional.
a = A(x=1)
# Raises: `y` is neither specified during __init__
# nor provided from the context.
a.y
d = pg.Dict(y=2, z=pg.Dict(a=a))
# `a.y` now refers to `d.a` since `d` is in its symbolic parent chain,
# aka. context.
assert a.y == 2
"""

def get(self, name: str, context: 'Symbolic') -> Any:
"""Try get the contextual value for a symbolic attribute from a parent.
Args:
name: The name of the request symbolic attribute.
context: A symbolic parent which represents the current context.
Returns:
The value for the requested symbolic attribute from the current context.
If ``pg.MISSING_VALUE``, it means that current symbolic parent cannot
provide a contextual value for the attribute, so the current context
is moved upward, until it reaches to the root of the symbolic tree.
If a ``pg.ContextualValue`` object, it will use the new contextual
marker returned to resolve its value from its symbolic parent chains.
"""
return self.value_from(name, context)

def value_from(self, name: str, context: 'Symbolic') -> Any:
"""Try get the contextual value for a symbolic attribute from a parent."""
return getattr(context, name, pg_typing.MISSING_VALUE)

def custom_apply(self, *args, **kwargs: Any) -> Tuple[bool, Any]:
# This is to make a ``ContextualValue`` object assignable
# to any symbolic attribute.
return (False, self)

def format(
self,
compact: bool = False,
verbose: bool = True,
root_indent: int = 0,
**kwargs: Any,
) -> str:
del compact, verbose, root_indent, kwargs
return 'ContextualValue()'

def to_json(self, **kwargs: Any) -> Dict[str, Any]:
return self.to_json_dict({})

def __eq__(self, other: Any) -> bool:
# NOTE(daiyip): We do strict type match here since subclasses might
# have their own __eq__ logic.
return type(other) is ContextualValue # pylint: disable=unidiomatic-typecheck

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)


class Symbolic(object_utils.JSONConvertible,
object_utils.MaybePartial,
object_utils.Formattable):
Expand Down Expand Up @@ -400,7 +485,7 @@ def sym_getattr(
def sym_contextual_hasattr(
self,
key: Union[str, int],
getter: Optional[Contextual] = None,
getter: Optional[ContextualValue] = None,
start: Union[
'Symbolic', object_utils.MissingValue
] = pg_typing.MISSING_VALUE,
Expand All @@ -409,14 +494,15 @@ def sym_contextual_hasattr(
Args:
key: Key of symbolic attribute.
getter: An optional ``Contextual`` object as the value retriever.
getter: An optional ``ContextualValue`` object as the value retriever.
start: An object from current object to the root of the composition as the
starting point of context lookup (upward). If ``pg.MISSING_VALUE``, it
will start with current node.
Returns:
True if the attribute exists. Otherwise False.
"""
getter = getter or ContextualValue()
v = self.sym_contextual_getattr(
key, default=(pg_typing.MISSING_VALUE,), getter=getter, start=start
)
Expand All @@ -426,7 +512,7 @@ def sym_contextual_getattr(
self,
key: Union[str, int],
default: Any = object_utils.MISSING_VALUE,
getter: Optional[Contextual] = None,
getter: Optional[ContextualValue] = None,
start: Union[
'Symbolic', object_utils.MissingValue
] = pg_typing.MISSING_VALUE,
Expand All @@ -437,7 +523,7 @@ def sym_contextual_getattr(
key: Key of symbolic attribute.
default: Default value if attribute does not exist. If absent,
`AttributeError` will be thrown.
getter: An optional ``Contextual`` object as the value retriever.
getter: An optional ``ContextualValue`` object as the value retriever.
start: An object from current object to the root of the composition as the
starting point of context lookup (upward). If ``pg.MISSING_VALUE``, it
will start with current node.
Expand All @@ -450,18 +536,18 @@ def sym_contextual_getattr(
AttributeError if `key` does not exist along the parent chain and
default value is not ``pg.MISSING_VALUE``.
"""
getter = getter or Contextual()
getter = getter or ContextualValue()
if start == pg_typing.MISSING_VALUE:
current = self
else:
current = typing.cast(Symbolic, start)

while current is not None:
v = getter.value_from(key, current)
# NOTE(daiyip): when the contextual value from the parent returns
# another contextual object, we should follow the new return value's
v = getter.get(key, current)
# NOTE(daiyip): when the ContextualValue value from the parent returns
# another ContextualValue object, we should follow the new return value's
# instruction instead of the original one.
if isinstance(v, Contextual):
if isinstance(v, ContextualValue):
getter = v
elif v != object_utils.MISSING_VALUE:
return v
Expand Down
89 changes: 68 additions & 21 deletions pyglove/core/symbolic/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
"""Tests for pyglove.symbolic.base."""

import copy
import dataclasses
import unittest

from pyglove.core import object_utils
from pyglove.core import typing as pg_typing
from pyglove.core.symbolic.base import FieldUpdate
from pyglove.core.symbolic import base
from pyglove.core.symbolic.dict import Dict


Expand All @@ -28,7 +29,7 @@ class FieldUpdateTest(unittest.TestCase):
def test_basics(self):
x = Dict(x=1)
f = pg_typing.Field('x', pg_typing.Int())
update = FieldUpdate(object_utils.KeyPath('x'), x, f, 1, 2)
update = base.FieldUpdate(object_utils.KeyPath('x'), x, f, 1, 2)
self.assertEqual(update.path, 'x')
self.assertIs(update.target, x)
self.assertIs(update.field, f)
Expand All @@ -37,44 +38,90 @@ def test_basics(self):

def test_format(self):
self.assertEqual(
FieldUpdate(
object_utils.KeyPath('x'),
Dict(x=1), None, 1, 2).format(compact=True),
'FieldUpdate(parent_path=, path=x, old_value=1, new_value=2)')
base.FieldUpdate(
object_utils.KeyPath('x'), Dict(x=1), None, 1, 2
).format(compact=True),
'FieldUpdate(parent_path=, path=x, old_value=1, new_value=2)',
)

self.assertEqual(
FieldUpdate(
object_utils.KeyPath('a'),
Dict(x=Dict(a=1)).x, None, 1, 2).format(compact=True),
'FieldUpdate(parent_path=x, path=a, old_value=1, new_value=2)')
base.FieldUpdate(
object_utils.KeyPath('a'), Dict(x=Dict(a=1)).x, None, 1, 2
).format(compact=True),
'FieldUpdate(parent_path=x, path=a, old_value=1, new_value=2)',
)

def test_eq_ne(self):
x = Dict()
f = pg_typing.Field('x', pg_typing.Int())
self.assertEqual(
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2))
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
)

# Targets are not the same instance.
self.assertNotEqual(
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
FieldUpdate(object_utils.KeyPath('a'), Dict(), f, 1, 2))
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
base.FieldUpdate(object_utils.KeyPath('a'), Dict(), f, 1, 2),
)

# Fields are not the same instance.
self.assertNotEqual(
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
FieldUpdate(object_utils.KeyPath('b'), x, copy.copy(f), 1, 2))
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
base.FieldUpdate(object_utils.KeyPath('b'), x, copy.copy(f), 1, 2),
)

self.assertNotEqual(
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
FieldUpdate(object_utils.KeyPath('a'), x, f, 0, 2))
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 0, 2),
)

self.assertNotEqual(
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 1))
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2),
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 1),
)

self.assertNotEqual(
FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), Dict())
base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), Dict()
)


class ContextualValueTest(unittest.TestCase):
"""Tests for `pg.symbolic.ContextualValue`."""

def test_str(self):
self.assertEqual(str(base.ContextualValue()), 'ContextualValue()')

def test_repr(self):
self.assertEqual(repr(base.ContextualValue()), 'ContextualValue()')

def test_eq(self):
self.assertEqual(base.ContextualValue(), base.ContextualValue())
self.assertNotEqual(base.ContextualValue(), 1)

def test_call(self):
@dataclasses.dataclass
class A:
x: int = 1
y: int = 2

self.assertEqual(base.ContextualValue().get('x', A()), 1)

def test_custom_typing(self):
v = base.ContextualValue()
self.assertIs(pg_typing.Int().apply(v), v)
self.assertIs(pg_typing.Str().apply(v), v)

def test_to_json(self):
self.assertEqual(
base.to_json(base.ContextualValue()),
{'_type': f'{base.ContextualValue.__module__}.ContextualValue'},
)

def test_from_json(self):
self.assertEqual(
base.from_json(base.ContextualValue().to_json()), base.ContextualValue()
)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit f50831a

Please sign in to comment.