Skip to content

Commit

Permalink
Introduce flags and enable auto call for functors during __init__.
Browse files Browse the repository at this point in the history
This allows functors to behave like normal functions under the `pg.auto_call_functors` context manager.

PiperOrigin-RevId: 519997836
  • Loading branch information
daiyip authored and pyglove authors committed Mar 28, 2023
1 parent 798de60 commit a00e787
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyglove/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
# Context manager for scoped flags.
allow_partial = symbolic.allow_partial
allow_writable_accessors = symbolic.allow_writable_accessors
auto_call_functors = symbolic.auto_call_functors
notify_on_change = symbolic.notify_on_change
enable_type_check = symbolic.enable_type_check
track_origin = symbolic.track_origin
Expand Down
2 changes: 2 additions & 0 deletions pyglove/core/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from pyglove.core.symbolic.flags import notify_on_change
from pyglove.core.symbolic.flags import is_change_notification_enabled

from pyglove.core.symbolic.flags import auto_call_functors
from pyglove.core.symbolic.flags import should_call_functors_during_init

# Symbolic types and their definition helpers.
from pyglove.core.symbolic.base import Symbolic
Expand Down
32 changes: 32 additions & 0 deletions pyglove/core/symbolic/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def get_save_handler() -> Optional[Callable[..., Any]]:
_TLS_ACCESSOR_WRITABLE = '_accessor_writable'
_TLS_ALLOW_PARTIAL = '_allow_partial'
_TLS_SEALED = '_sealed'
_TLS_AUTO_CALL_FUNCTORS = '_allow_auto_call_functors'


def notify_on_change(enabled: bool = True) -> ContextManager[None]:
Expand Down Expand Up @@ -332,3 +333,34 @@ class A(pg.Object):
def is_under_partial_scope() -> Optional[bool]:
"""Return True if partial value is allowed in current context."""
return thread_local.get_value(_TLS_ALLOW_PARTIAL, None)


def auto_call_functors(enabled: bool = True) -> ContextManager[None]:
"""Returns a context manager to enable or disable auto call for functors.
`auto_call_functors` is thread-safe and can be nested. For example::
@pg.symbolize
def foo(x, y):
return x + y
with pg.auto_call_functors(True):
a = foo(1, 2)
assert a == 3
with pg.auto_call_functors(False):
b = foo(1, 2)
assert isinstance(b, foo)
Args:
enabled: If True, enable auto call for functors.
Otherwise, auto call will be disabled.
Returns:
A context manager for enabling/disabling auto call for functors.
"""
return thread_local.value_scope(_TLS_AUTO_CALL_FUNCTORS, enabled, False)


def should_call_functors_during_init() -> Optional[bool]:
"""Return True functors should be automatically called during __init__."""
return thread_local.get_value(_TLS_AUTO_CALL_FUNCTORS, None)
9 changes: 9 additions & 0 deletions pyglove/core/symbolic/flags_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def test_allow_partial(self):
self.assertTrue(flags.is_under_partial_scope())
self.assertFalse(flags.is_under_partial_scope())

def test_auto_call_functors(self):
self.assertFalse(flags.should_call_functors_during_init())
with flags.auto_call_functors(True):
self.assertTrue(flags.should_call_functors_during_init())
with flags.auto_call_functors(False):
self.assertFalse(flags.should_call_functors_during_init())
self.assertTrue(flags.should_call_functors_during_init())
self.assertFalse(flags.should_call_functors_during_init())


if __name__ == '__main__':
unittest.main()
7 changes: 7 additions & 0 deletions pyglove/core/symbolic/functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def sum(a, b=1, *args, **kwargs):
# Signature of this function.
signature: pg_typing.Signature

def __new__(cls, *args, **kwargs):
instance = object.__new__(cls)
if flags.should_call_functors_during_init():
instance.__init__(*args, **kwargs)
return instance()
return instance

def __init__(
self,
*args,
Expand Down
14 changes: 14 additions & 0 deletions pyglove/core/symbolic/functor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pyglove.core import object_utils
from pyglove.core import typing as pg_typing
from pyglove.core.symbolic import flags
from pyglove.core.symbolic.base import from_json_str as pg_from_json_str
from pyglove.core.symbolic.dict import Dict
from pyglove.core.symbolic.functor import as_functor as pg_as_functor
Expand Down Expand Up @@ -810,6 +811,19 @@ class A(Object):
x = f.partial(x=A.partial())
self.assertEqual(x.missing_values(), {'x.x': MISSING_VALUE})

def test_auto_call(self):
@pg_functor
def f(x, y):
return x + y

with flags.auto_call_functors():
self.assertEqual(f(1, 2), 3)
with self.assertRaisesRegex(
TypeError, '.* missing 1 required positional argument'):
_ = f(1) # pylint: disable=no-value-for-parameter

self.assertIsInstance(f(1, 2), Functor)


if __name__ == '__main__':
unittest.main()

0 comments on commit a00e787

Please sign in to comment.