Skip to content

Commit

Permalink
Add sym_ancestor and sym_descendants for easing topological query.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 516964164
  • Loading branch information
daiyip authored and pyglove authors committed Mar 15, 2023
1 parent 05d47c4 commit 40bc014
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 10 deletions.
48 changes: 38 additions & 10 deletions docs/learn/soop/som/operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,24 @@ nodes in the tree. To do so, PyGlove maintains bi-directional links between
containing/contained symbolic objects, allowing update notifications to
propagate through them.

Root
====

Users can access the root of current symbolic tree via property
:attr:`sym_root <pyglove.symbolic.Symbolic.sym_root>`. For example::

assert zoo.sym_root is zoo
assert zoo.exhibits[0].sym_root is zoo
assert zoo.exhibits[0].animal.sym_root is zoo


Parent Node
===========

Property :attr:`sym_parent <pyglove.symbolic.Symbolic.sym_parent>` is the API for all
symbolic types to access their containing node (parent) in the tree. For example, the ``Cage``
object in the `zoo` has ``exhibits`` (a symbolic list) as its parent::
Similarly, users can access the immediate parent (the containing node) of a
symbolic value via property
:attr:`sym_parent <pyglove.symbolic.Symbolic.sym_parent>`. For example, the
``Cage`` object in the `zoo` has ``exhibits`` (a symbolic list) as its parent::

assert zoo.exhibits[0].sym_parent is zoo.exhibits[0]

Expand All @@ -57,16 +69,17 @@ object in the `zoo` has ``exhibits`` (a symbolic list) as its parent::
assert shark.sym_parent is pool
assert shark.sym_path == 'x'

Root
====

Similarly, users can access the root of current symbolic tree via
property :attr:`sym_root <pyglove.symbolic.Symbolic.sym_root>`. For example::
Ancestor
========

assert zoo.sym_root is zoo
assert zoo.exhibits[0].sym_root is zoo
assert zoo.exhibits[0].animal.sym_root is zoo
:meth:`sym_ancestor <pyglove.symbolic.Symbolic.sym_ancestor>` can be useful when
users require an ancestor in the containing chain that meets specific criteria
instead of the root or immediate parent. For instance, the following code
illustrates how to retrieve the nearest ``Zoo`` object from an ``Animal`` object
located in a zoo::

assert zoo.exhibts[0].animal.sym_ancestor(lambda x: isinstance(x, Zoo)) is zoo

Child Nodes
===========
Expand Down Expand Up @@ -163,6 +176,21 @@ For example::
zoo.sym_hasattr('name') == True
zoo.sym_getattr('name') == 'San Diego Zoo'


Descendants
===========

In addition to accessing immediate child nodes,
:meth:`sym_descendants <pyglove.symbolic.Symbolic.sym_descendants>`
is a handy tool to retrieve all nodes in the sub-tree. Users can also specify a filter
function (using the argument "where") and choose whether to include
intermediate nodes, leaves, or both in the returned nodes (using the argument
"option"). For instance, consider the following code, which demonstrates how to
select all animals from a zoo::

assert zoo.sym_descendants(lambda x: isinstance(x, Animal)) == [
Python('Bob', color='black'), Shark('Jack')]

Location
========

Expand Down
1 change: 1 addition & 0 deletions pyglove/core/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@

# Symbolic helper classes.
from pyglove.core.symbolic.base import FieldUpdate
from pyglove.core.symbolic.base import DescendantQueryOption
from pyglove.core.symbolic.base import TraverseAction
from pyglove.core.symbolic.list import Insertion
from pyglove.core.symbolic.diff import Diff
Expand Down
69 changes: 69 additions & 0 deletions pyglove/core/symbolic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)


class DescendantQueryOption(enum.Enum):
"""Options for querying descendants through `sym_descendant`."""

# Returning all matched descendants.
ALL = 0

# Returning only the immediate matched descendants.
IMMEDIATE = 1

# Returning only the leaf matched descendants.
LEAF = 2


class Symbolic(object_utils.JSONConvertible,
object_utils.MaybePartial,
object_utils.Formattable):
Expand Down Expand Up @@ -246,6 +259,62 @@ def sym_root(self) -> 'Symbolic':
root = root.sym_parent
return root

def sym_ancestor(
self,
where: Optional[Callable[[Any], bool]] = None,
) -> Optional['Symbolic']:
"""Returns the nearest ancestor of specific classes."""
ancestor = self.sym_parent
where = where or (lambda x: True)
while ancestor is not None and not where(ancestor):
ancestor = ancestor.sym_parent
return ancestor

def sym_descendants(
self,
where: Optional[Callable[[Any], bool]] = None,
option: DescendantQueryOption = DescendantQueryOption.ALL,
include_self: bool = False) -> List[Any]:
"""Returns all descendants of specific classes.
Args:
where: Optional callable object as the filter of descendants to return.
option: Descendant query options, indicating whether all matched,
immediate matched or only the matched leaf nodes will be returned.
include_self: If True, `self` will be included in the query, otherwise
only strict descendants are included.
Returns:
A list of objects that match the descendant_cls.
"""
descendants = []
where = where or (lambda x: True)

def visit(k, v, p):
del k, p
if not where(v):
return TraverseAction.ENTER

if not include_self and self is v:
return TraverseAction.ENTER

if option == DescendantQueryOption.IMMEDIATE:
descendants.append(v)
return TraverseAction.CONTINUE

# Dealing with option = ALL or LEAF.
leaf_descendants = []
if isinstance(v, Symbolic):
leaf_descendants = v.sym_descendants(where, option)

if option is DescendantQueryOption.ALL or not leaf_descendants:
descendants.append(v)
descendants.extend(leaf_descendants)
return TraverseAction.CONTINUE

traverse(self, visit)
return descendants

@abc.abstractmethod
def sym_attr_field(self, key: Union[str, int]) -> Optional[pg_typing.Field]:
"""Returns the field definition for a symbolic attribute."""
Expand Down
118 changes: 118 additions & 0 deletions pyglove/core/symbolic/object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,24 @@ class A(Object):
self.assertIs(a.x.x.x.sym_root, pa)
self.assertIs(a.x.x.x[0].sym_root, pa)

def test_sym_ancestor(self):

@pg_members([
('x', pg_typing.Any()),
])
class A(Object):
pass

a = A(dict(x=A([A(1)])))
self.assertIs(a.x.sym_ancestor(), a)
self.assertIs(a.x.sym_ancestor(lambda x: isinstance(x, A)), a)
self.assertIsNone(a.x.sym_ancestor(lambda x: isinstance(x, int)), a)
self.assertIs(a.x.x.sym_ancestor(lambda x: isinstance(x, A)), a)
self.assertIs(a.x.x.x[0].sym_ancestor(lambda x: isinstance(x, A)), a.x.x)
self.assertIs(a.x.x.x[0].sym_ancestor(lambda x: isinstance(x, list)),
a.x.x.x)
self.assertIs(a.x.x.x[0].sym_ancestor(lambda x: isinstance(x, dict)), a.x)

def test_sym_path(self):

@pg_members([
Expand Down Expand Up @@ -2268,6 +2286,106 @@ def test_bad_query(self):
pg_query(self._v, path_regex=r'x', custom_selector=lambda: True)


class SymDescendantsTests(unittest.TestCase):
"""Tests for `sym_descendants`."""

def setUp(self):
super().setUp()

@pg_members([
('x', pg_typing.Any()),
])
class A(Object):
pass

self._a = A(dict(x=A([A(1)]), y=[A(2)]))

def test_descendants_with_no_filter(self):
a = self._a
self.assertEqual(
a.sym_descendants(),
[
a.x,
a.x.x,
a.x.x.x,
a.x.x.x[0],
a.x.x.x[0].x,
a.x.y,
a.x.y[0],
a.x.y[0].x,
])

self.assertEqual(
a.sym_descendants(option=base.DescendantQueryOption.IMMEDIATE),
[a.x])

self.assertEqual(
a.sym_descendants(option=base.DescendantQueryOption.LEAF),
[
a.x.x.x[0].x,
a.x.y[0].x,
])

def test_descendants_with_filter(self):
a = self._a
where = lambda x: isinstance(x, a.__class__)
self.assertEqual(
a.sym_descendants(where),
[
a.x.x,
a.x.x.x[0],
a.x.y[0],
])

self.assertEqual(
a.sym_descendants(where, base.DescendantQueryOption.IMMEDIATE),
[
a.x.x,
a.x.y[0],
])

self.assertEqual(
a.sym_descendants(where, base.DescendantQueryOption.LEAF),
[
a.x.x.x[0],
a.x.y[0],
])

self.assertEqual(
a.sym_descendants(
where, base.DescendantQueryOption.IMMEDIATE, include_self=True),
[a])

def test_descendants_with_including_self(self):
a = self._a
self.assertEqual(
a.sym_descendants(include_self=True),
[
a,
a.x,
a.x.x,
a.x.x.x,
a.x.x.x[0],
a.x.x.x[0].x,
a.x.y,
a.x.y[0],
a.x.y[0].x,
])

self.assertEqual(
a.sym_descendants(
option=base.DescendantQueryOption.IMMEDIATE, include_self=True),
[a])

self.assertEqual(
a.sym_descendants(
option=base.DescendantQueryOption.LEAF, include_self=True),
[
a.x.x.x[0].x,
a.x.y[0].x,
])


class SerializationTest(unittest.TestCase):
"""Dedicated tests for `pg.Object` serialization."""

Expand Down

0 comments on commit 40bc014

Please sign in to comment.