diff --git a/pyglove/core/symbolic/compounding.py b/pyglove/core/symbolic/compounding.py index 9be498a..bea009d 100644 --- a/pyglove/core/symbolic/compounding.py +++ b/pyglove/core/symbolic/compounding.py @@ -154,7 +154,10 @@ class _Compound(Compound, base_class): auto_schema = False def _on_bound(self): - super()._on_bound() + # NOTE(daiyip): Do not call `super()._on_bound()` to avoid side effect. + # This is okay since all states are delegated to `self.decomposed`. + Compound._on_bound(self) # pylint: disable=protected-access + self._sym_decomposed = None if not lazy_build: @@ -188,6 +191,21 @@ def __getattribute__(self, name: str): cls.auto_register = True cls.apply_schema(schema) + # NOTE(daiyip): Override abstract methods as non-ops, so `cls` could have an + # abstract class as its base. We don't need to worry about the implementation + # of the abstract method, since it will be detoured to the decomposed object + # at runtime via `__getattribute__`. + for key in dir(cls): + attr = getattr(cls, key) + if getattr(attr, '__isabstractmethod__', False): + noop = lambda self, *args, **kwargs: None + if isinstance(attr, property): + noop = property(noop) + else: + assert inspect.isfunction(attr), (key, attr) + setattr(cls, key, noop) + abc.update_abstractmethods(cls) + if add_to_registry: cls.register_for_deserialization(serialization_key, additional_keys) return cls diff --git a/pyglove/core/symbolic/compounding_test.py b/pyglove/core/symbolic/compounding_test.py index ec8aff6..b9770fc 100644 --- a/pyglove/core/symbolic/compounding_test.py +++ b/pyglove/core/symbolic/compounding_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for pyglove.compounding.""" +import abc import dataclasses import unittest @@ -62,6 +63,54 @@ def foo(x) -> Foo: def bar(unused_x): pass + def test_on_bound_side_effect_free(self): + + class Foo(Object): + x: int + + def _on_bound(self): + # Side effect. + super()._on_bound() + assert type(self) is Foo # pylint: disable=unidiomatic-typecheck + + def hello(self): + return self.x + + @pg_compound(Foo) + def foo(x): + return Foo(x) + + # This does not trigger assertion. + self.assertEqual(foo(1).hello(), 1) + + def test_use_abstract_base(self): + + class Foo(metaclass=abc.ABCMeta): + @abc.abstractmethod + def foo(self, x): + pass + + @property + @abc.abstractmethod + def bar(self): + pass + + class Bar(Foo): + def foo(self, x): + return x + + @property + def bar(self): + return 1 + + @pg_compound(Foo) + def bar(): + return Bar() + + b = bar() + self.assertEqual(b.bar, 1) + self.assertEqual(b.foo(1), 1) + def test_lazy_build(self): count = dict(x=0)