Skip to content

Commit

Permalink
Add contextual value support in pg.Symbolic.
Browse files Browse the repository at this point in the history
Example:

```
class A(pg.Object):
  x: int
  y: int = pg.Contextual()

# Not okay: `x` is not contextual and is not specified.
A()

# Okay: both `x` and `y` are specified.
A(x=1, y=2)

# Okay: `y` is contextual, hence optional.
a = A(x=1)

# Raises: `y` is neither specified during __init__ nor provided from the context.
a.y
```

PiperOrigin-RevId: 533922482
  • Loading branch information
daiyip authored and pyglove authors committed May 22, 2023
1 parent af410fe commit 1b7386b
Show file tree
Hide file tree
Showing 12 changed files with 433 additions and 28 deletions.
3 changes: 3 additions & 0 deletions pyglove/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
ClassWrapper = symbolic.ClassWrapper
Functor = symbolic.Functor

# Contextual value marker.
Contextual = symbolic.Contextual

# Decorator for declaring symbolic. members for `pg.Object`.
members = symbolic.members

Expand Down
3 changes: 3 additions & 0 deletions pyglove/core/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@
from pyglove.core.symbolic.base import load
from pyglove.core.symbolic.base import save

# Marker for contextual values.
from pyglove.core.symbolic.contextual import Contextual

# Interfaces for pure symbolic objects.
from pyglove.core.symbolic.pure_symbolic import PureSymbolic
from pyglove.core.symbolic.pure_symbolic import NonDeterministic
Expand Down
93 changes: 87 additions & 6 deletions pyglove/core/symbolic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import json
import re
import sys
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, Tuple
import typing
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union

from pyglove.core import object_utils
from pyglove.core import typing as pg_typing
from pyglove.core.symbolic import flags
from pyglove.core.symbolic.contextual import Contextual
from pyglove.core.symbolic.origin import Origin
from pyglove.core.symbolic.pure_symbolic import NonDeterministic
from pyglove.core.symbolic.pure_symbolic import PureSymbolic
Expand Down Expand Up @@ -349,15 +351,15 @@ def sym_get(
Args:
path: A KeyPath object or equivalence.
default: Default value if path does not exists. If absent, `KeyError`
will be thrown.
default: Default value if path does not exists. If absent, `KeyError` will
be thrown.
Returns:
Value of symbolic attribute specified by path if found, otherwise the
default value if it's specified.
Raises:
KeyError if `path` does not exist and default value is `pg.MISSING_VALUE`.
KeyError if `path` does not exist and `default` is `pg.MISSING_VALUE`.
"""
path = object_utils.KeyPath.from_value(path)
if default == object_utils.MISSING_VALUE:
Expand All @@ -378,14 +380,14 @@ def sym_getattr(
Args:
key: Key of symbolic attribute.
default: Default value if attribute does not exist. If absent,
`AttributeError` will be thrown.
Returns:
Value of symbolic attribute if found, otherwise the default value
if it's specified.
Raises:
AttributeError if `key` does not exist.
AttributeError if `key` does not exist and `default` is
``pg.MISSING_VALUE``.
"""
if not self.sym_hasattr(key):
if default != object_utils.MISSING_VALUE:
Expand All @@ -395,6 +397,85 @@ def sym_getattr(
f'{self.__class__!r} object has no symbolic attribute {key!r}.'))
return self._sym_getattr(key)

def sym_contextual_hasattr(
self,
key: Union[str, int],
getter: Optional[Contextual] = None,
start: Union[
'Symbolic', object_utils.MissingValue
] = pg_typing.MISSING_VALUE,
) -> bool:
"""Returns True if an attribute exists from current object's context.
Args:
key: Key of symbolic attribute.
getter: An optional ``Contextual`` object as the value retriever.
start: An object from current object to the root of the composition as the
starting point of context lookup (upward). If ``pg.MISSING_VALUE``, it
will start with current node.
Returns:
True if the attribute exists. Otherwise False.
"""
v = self.sym_contextual_getattr(
key, default=(pg_typing.MISSING_VALUE,), getter=getter, start=start
)
return v != (pg_typing.MISSING_VALUE,)

def sym_contextual_getattr(
self,
key: Union[str, int],
default: Any = object_utils.MISSING_VALUE,
getter: Optional[Contextual] = None,
start: Union[
'Symbolic', object_utils.MissingValue
] = pg_typing.MISSING_VALUE,
) -> Any:
"""Gets a key from current object's context (symbolic parent chain).
Args:
key: Key of symbolic attribute.
default: Default value if attribute does not exist. If absent,
`AttributeError` will be thrown.
getter: An optional ``Contextual`` object as the value retriever.
start: An object from current object to the root of the composition as the
starting point of context lookup (upward). If ``pg.MISSING_VALUE``, it
will start with current node.
Returns:
Value of symbolic attribute if found, otherwise the default value
if it's specified.
Raises:
AttributeError if `key` does not exist along the parent chain and
default value is not ``pg.MISSING_VALUE``.
"""
getter = getter or Contextual()
if start == pg_typing.MISSING_VALUE:
current = self
else:
current = typing.cast(Symbolic, start)

while current is not None:
v = getter.value_from(key, current)
# NOTE(daiyip): when the contextual value from the parent returns
# another contextual object, we should follow the new return value's
# instruction instead of the original one.
if isinstance(v, Contextual):
getter = v
elif v != object_utils.MISSING_VALUE:
return v
current = current.sym_parent

if default != object_utils.MISSING_VALUE:
return default
raise AttributeError(
self._error_message(
f'`{key}` is not found under its context '
'(along its symbolic parent chain).'
)
)

@abc.abstractmethod
def sym_keys(self) -> Iterator[Union[str, int]]:
"""Iterates the keys of symbolic attributes."""
Expand Down
86 changes: 86 additions & 0 deletions pyglove/core/symbolic/contextual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2023 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contextual value marker."""
from typing import Any, Callable, Optional, Tuple
import pyglove.core.typing as pg_typing

# The default contextual getter.
_DEFAULT_GETTER = lambda name, x: getattr(x, name, pg_typing.MISSING_VALUE)


class Contextual(pg_typing.CustomTyping):
"""Marker for values to be read from current field's symbolic parents.
Example::
class A(pg.Object):
x: int
y: int = pg.Contextual()
# Not okay: `x` is not contextual and is not specified.
A()
# Okay: both `x` and `y` are specified.
A(x=1, y=2)
# Okay: `y` is contextual, hence optional.
a = A(x=1)
# Raises: `y` is neither specified during __init__
# nor provided from the context.
a.y
"""

def __init__(self, getter: Optional[Callable[[str, Any], Any]] = None):
"""Constructor.
Args:
getter: An optional callable object to get the value of the request
attribute name from a symbolic parent, with signature: (attribute_name,
symbolic_parent) -> attribute_value If the getter returns
``pg.MISSING_VALUE` or a ``pg.Contextual`` object, the context will be
moved unto the parent's parent. If None, the getter will be quering the
attribute of the same name from the the parent.
"""
super().__init__()
self._getter = getter or _DEFAULT_GETTER

def custom_apply(self, *args, **kwargs) -> Tuple[bool, Any]:
# This is to make a ``Contextual`` object assignable
# to any symbolic attribute.
return (False, self)

def value_from(self, name: str, parent) -> Any:
"""Returns the contextual attribute value from the parent object.
Args:
name: The name of request attribute.
parent: Current context (symbolic parent).
Returns:
The value for the contextual attribute.
"""
return self._getter(name, parent)

def __repr__(self):
return str(self)

def __str__(self):
return 'CONTEXTUAL'

def __eq__(self, other):
return isinstance(other, Contextual) and self._getter == other._getter

def __ne__(self, other):
return not self.__eq__(other)
58 changes: 58 additions & 0 deletions pyglove/core/symbolic/contextual_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pyglove.symbolic.Contextual."""

import dataclasses
import unittest

from pyglove.core import typing as pg_typing
from pyglove.core.symbolic.contextual import Contextual


class ContextualTest(unittest.TestCase):
"""Tests for `pg.symbolic.Contextual`."""

def test_str(self):
self.assertEqual(str(Contextual()), 'CONTEXTUAL')
self.assertEqual(str(Contextual(lambda k, p: 1)), 'CONTEXTUAL')

def test_repr(self):
self.assertEqual(repr(Contextual()), 'CONTEXTUAL')
self.assertEqual(repr(Contextual(lambda k, p: 1)), 'CONTEXTUAL')

def test_eq(self):
self.assertEqual(Contextual(), Contextual())
getter = lambda k, p: 1
self.assertEqual(Contextual(getter), Contextual(getter))

self.assertNotEqual(Contextual(), 1)
self.assertNotEqual(Contextual(getter), Contextual())

def test_value_from(self):
@dataclasses.dataclass
class A:
x: int = 1
y: int = 2

self.assertEqual(Contextual().value_from('x', A()), 1)
self.assertEqual(Contextual(lambda k, p: p.y).value_from('x', A()), 2)

def test_custom_typing(self):
v = Contextual()
self.assertIs(pg_typing.Int().apply(v), v)
self.assertIs(pg_typing.Str().apply(v), v)


if __name__ == '__main__':
unittest.main()
33 changes: 25 additions & 8 deletions pyglove/core/symbolic/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pyglove.core import object_utils
from pyglove.core import typing as pg_typing
from pyglove.core.symbolic import base
from pyglove.core.symbolic import contextual
from pyglove.core.symbolic import flags


Expand Down Expand Up @@ -197,8 +198,10 @@ def __init__(self,
# If True, the parent of dict items should be set to `self.sym_parent`,
# This is useful when Dict is used as the field container of
# pg.Object.
self._set_raw_attr('_pass_through_parent',
kwargs.pop('pass_through_parent', False))
self._set_raw_attr(
'_as_object_attributes_container',
kwargs.pop('as_object_attributes_container', False),
)

if dict_obj is not None:
dict_obj = dict(dict_obj)
Expand Down Expand Up @@ -436,20 +439,20 @@ def sym_keys(self) -> Iterator[str]:
def sym_values(self) -> Iterator[Any]:
"""Iterates the values of symbolic attributes."""
for k in self.sym_keys():
yield self[k]
yield self._sym_getattr(k)

def sym_items(self) -> Iterator[
Tuple[str, Any]]:
"""Iterates the (key, value) pairs of symbolic attributes."""
for k in self.sym_keys():
yield k, self[k]
yield k, self._sym_getattr(k)

def sym_setparent(self, parent: base.Symbolic):
"""Override set parent of Dict to handle the passing through scenario."""
super().sym_setparent(parent)
# NOTE(daiyip): when flag `pass_through_parent` is on, it sets the parent
# of child symbolic values using its parent.
if self._pass_through_parent:
# NOTE(daiyip): when flag `as_object_attributes_container` is on, it sets
# the parent of child symbolic values using its parent.
if self._as_object_attributes_container:
for v in self.values():
if isinstance(v, base.Symbolic):
v.sym_setparent(parent)
Expand All @@ -464,7 +467,7 @@ def sym_hash(self) -> int:
def _sym_getattr( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
self, key: str) -> Any:
"""Gets symbolic attribute by key."""
return self[key]
return super().__getitem__(key)

def _sym_clone(self, deep: bool, memo=None) -> 'Dict':
"""Override Symbolic._sym_clone."""
Expand Down Expand Up @@ -573,6 +576,20 @@ def _on_change(self, field_updates: typing.Dict[object_utils.KeyPath,
if self._onchange_callback:
self._onchange_callback(field_updates)

def __getitem__(self, key: str) -> Any:
"""Get item in this Dict."""
v = super().__getitem__(key)
if isinstance(v, contextual.Contextual):
start = self.sym_parent
# NOTE(daiyip): The parent of `pg.Object`'s attribute dict points to
# the `pg.Object` instance once it's set up. Here we let the ancester
# traversal to bypass `pg.Object` to avoid double entry, which causes
# dead loop.
if self._as_object_attributes_container and self.sym_parent:
start = start.sym_parent
v = self.sym_contextual_getattr(key, getter=v, start=start)
return v

def __setitem__(self, key: str, value: Any) -> None:
"""Set item in this Dict.
Expand Down
Loading

0 comments on commit 1b7386b

Please sign in to comment.