Skip to content

Commit

Permalink
Enhancement to pg.compound.
Browse files Browse the repository at this point in the history
- Avoids undesired side effect from calling `super()._onbound` in `pg.Compound._on_bound`.
- Supports using abstract class as the base class of a compound class.

PiperOrigin-RevId: 551410855
  • Loading branch information
daiyip authored and pyglove authors committed Jul 27, 2023
1 parent 3770087 commit 4f659ab
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
20 changes: 19 additions & 1 deletion pyglove/core/symbolic/compounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ class _Compound(Compound, base_class):
auto_schema = False

def _on_bound(self):
super()._on_bound()
# NOTE(daiyip): Do not call `super()._on_bound()` to avoid side effect.
# This is okay since all states are delegated to `self.decomposed`.
Compound._on_bound(self) # pylint: disable=protected-access

self._sym_decomposed = None

if not lazy_build:
Expand Down Expand Up @@ -188,6 +191,21 @@ def __getattribute__(self, name: str):
cls.auto_register = True
cls.apply_schema(schema)

# NOTE(daiyip): Override abstract methods as non-ops, so `cls` could have an
# abstract class as its base. We don't need to worry about the implementation
# of the abstract method, since it will be detoured to the decomposed object
# at runtime via `__getattribute__`.
for key in dir(cls):
attr = getattr(cls, key)
if getattr(attr, '__isabstractmethod__', False):
noop = lambda self, *args, **kwargs: None
if isinstance(attr, property):
noop = property(noop)
else:
assert inspect.isfunction(attr), (key, attr)
setattr(cls, key, noop)
abc.update_abstractmethods(cls)

if add_to_registry:
cls.register_for_deserialization(serialization_key, additional_keys)
return cls
Expand Down
49 changes: 49 additions & 0 deletions pyglove/core/symbolic/compounding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for pyglove.compounding."""

import abc
import dataclasses
import unittest

Expand Down Expand Up @@ -62,6 +63,54 @@ def foo(x) -> Foo:
def bar(unused_x):
pass

def test_on_bound_side_effect_free(self):

class Foo(Object):
x: int

def _on_bound(self):
# Side effect.
super()._on_bound()
assert type(self) is Foo # pylint: disable=unidiomatic-typecheck

def hello(self):
return self.x

@pg_compound(Foo)
def foo(x):
return Foo(x)

# This does not trigger assertion.
self.assertEqual(foo(1).hello(), 1)

def test_use_abstract_base(self):

class Foo(metaclass=abc.ABCMeta):
@abc.abstractmethod
def foo(self, x):
pass

@property
@abc.abstractmethod
def bar(self):
pass

class Bar(Foo):
def foo(self, x):
return x

@property
def bar(self):
return 1

@pg_compound(Foo)
def bar():
return Bar()

b = bar()
self.assertEqual(b.bar, 1)
self.assertEqual(b.foo(1), 1)

def test_lazy_build(self):
count = dict(x=0)

Expand Down

0 comments on commit 4f659ab

Please sign in to comment.