diff --git a/pyglove/core/__init__.py b/pyglove/core/__init__.py index 45bcf8c..4026e51 100644 --- a/pyglove/core/__init__.py +++ b/pyglove/core/__init__.py @@ -63,6 +63,7 @@ # Context manager for scoped flags. allow_partial = symbolic.allow_partial allow_writable_accessors = symbolic.allow_writable_accessors +auto_call_functors = symbolic.auto_call_functors notify_on_change = symbolic.notify_on_change enable_type_check = symbolic.enable_type_check track_origin = symbolic.track_origin diff --git a/pyglove/core/symbolic/__init__.py b/pyglove/core/symbolic/__init__.py index f9ce789..951e901 100644 --- a/pyglove/core/symbolic/__init__.py +++ b/pyglove/core/symbolic/__init__.py @@ -57,6 +57,8 @@ from pyglove.core.symbolic.flags import notify_on_change from pyglove.core.symbolic.flags import is_change_notification_enabled +from pyglove.core.symbolic.flags import auto_call_functors +from pyglove.core.symbolic.flags import should_call_functors_during_init # Symbolic types and their definition helpers. from pyglove.core.symbolic.base import Symbolic diff --git a/pyglove/core/symbolic/flags.py b/pyglove/core/symbolic/flags.py index 3857bc2..4ead4d7 100644 --- a/pyglove/core/symbolic/flags.py +++ b/pyglove/core/symbolic/flags.py @@ -122,6 +122,7 @@ def get_save_handler() -> Optional[Callable[..., Any]]: _TLS_ACCESSOR_WRITABLE = '_accessor_writable' _TLS_ALLOW_PARTIAL = '_allow_partial' _TLS_SEALED = '_sealed' +_TLS_AUTO_CALL_FUNCTORS = '_allow_auto_call_functors' def notify_on_change(enabled: bool = True) -> ContextManager[None]: @@ -332,3 +333,34 @@ class A(pg.Object): def is_under_partial_scope() -> Optional[bool]: """Return True if partial value is allowed in current context.""" return thread_local.get_value(_TLS_ALLOW_PARTIAL, None) + + +def auto_call_functors(enabled: bool = True) -> ContextManager[None]: + """Returns a context manager to enable or disable auto call for functors. + + `auto_call_functors` is thread-safe and can be nested. For example:: + + @pg.symbolize + def foo(x, y): + return x + y + + with pg.auto_call_functors(True): + a = foo(1, 2) + assert a == 3 + with pg.auto_call_functors(False): + b = foo(1, 2) + assert isinstance(b, foo) + + Args: + enabled: If True, enable auto call for functors. + Otherwise, auto call will be disabled. + + Returns: + A context manager for enabling/disabling auto call for functors. + """ + return thread_local.value_scope(_TLS_AUTO_CALL_FUNCTORS, enabled, False) + + +def should_call_functors_during_init() -> Optional[bool]: + """Return True functors should be automatically called during __init__.""" + return thread_local.get_value(_TLS_AUTO_CALL_FUNCTORS, None) diff --git a/pyglove/core/symbolic/flags_test.py b/pyglove/core/symbolic/flags_test.py index a559b81..4299ffd 100644 --- a/pyglove/core/symbolic/flags_test.py +++ b/pyglove/core/symbolic/flags_test.py @@ -125,6 +125,15 @@ def test_allow_partial(self): self.assertTrue(flags.is_under_partial_scope()) self.assertFalse(flags.is_under_partial_scope()) + def test_auto_call_functors(self): + self.assertFalse(flags.should_call_functors_during_init()) + with flags.auto_call_functors(True): + self.assertTrue(flags.should_call_functors_during_init()) + with flags.auto_call_functors(False): + self.assertFalse(flags.should_call_functors_during_init()) + self.assertTrue(flags.should_call_functors_during_init()) + self.assertFalse(flags.should_call_functors_during_init()) + if __name__ == '__main__': unittest.main() diff --git a/pyglove/core/symbolic/functor.py b/pyglove/core/symbolic/functor.py index 673d297..be985ae 100644 --- a/pyglove/core/symbolic/functor.py +++ b/pyglove/core/symbolic/functor.py @@ -71,6 +71,13 @@ def sum(a, b=1, *args, **kwargs): # Signature of this function. signature: pg_typing.Signature + def __new__(cls, *args, **kwargs): + instance = object.__new__(cls) + if flags.should_call_functors_during_init(): + instance.__init__(*args, **kwargs) + return instance() + return instance + def __init__( self, *args, diff --git a/pyglove/core/symbolic/functor_test.py b/pyglove/core/symbolic/functor_test.py index 1f7e83e..9eabdbe 100644 --- a/pyglove/core/symbolic/functor_test.py +++ b/pyglove/core/symbolic/functor_test.py @@ -19,6 +19,7 @@ from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core.symbolic import flags from pyglove.core.symbolic.base import from_json_str as pg_from_json_str from pyglove.core.symbolic.dict import Dict from pyglove.core.symbolic.functor import as_functor as pg_as_functor @@ -810,6 +811,19 @@ class A(Object): x = f.partial(x=A.partial()) self.assertEqual(x.missing_values(), {'x.x': MISSING_VALUE}) + def test_auto_call(self): + @pg_functor + def f(x, y): + return x + y + + with flags.auto_call_functors(): + self.assertEqual(f(1, 2), 3) + with self.assertRaisesRegex( + TypeError, '.* missing 1 required positional argument'): + _ = f(1) # pylint: disable=no-value-for-parameter + + self.assertIsInstance(f(1, 2), Functor) + if __name__ == '__main__': unittest.main()