From 11f5df156e3f8b26b641c43f205653ccefde47bd Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Fri, 8 Sep 2023 10:28:45 -0700 Subject: [PATCH] Supporting subclassed `pg.Functor`. Usage: ```python import pyglove as pg class Foo(pg.Functor): x: int y: int def _call(self) -> int: return self.x + self.y Foo(1, 2)() # Early bound. Foo()(1, 2) # Late bound. Foo(1)(y=2) # Partial bound. Foo(1, 2)(y=3, override_args=True) # Partial bound with override. ... ``` This CL also renames `pg.Functor.signature` to `pg.Functor.__signature__`. PiperOrigin-RevId: 563788916 --- pyglove/core/patching/rule_based.py | 7 +- pyglove/core/symbolic/functor.py | 225 ++++++++++++------ pyglove/core/symbolic/functor_test.py | 138 ++++++++--- pyglove/core/symbolic/symbolize_test.py | 13 +- pyglove/core/typing/callable_signature.py | 4 +- .../core/typing/callable_signature_test.py | 7 + pyglove/core/typing/value_specs_test.py | 4 +- 7 files changed, 278 insertions(+), 120 deletions(-) diff --git a/pyglove/core/patching/rule_based.py b/pyglove/core/patching/rule_based.py index f527b22..8d8b151 100644 --- a/pyglove/core/patching/rule_based.py +++ b/pyglove/core/patching/rule_based.py @@ -137,7 +137,8 @@ def validate(self, x: symbolic.Symbolic) -> None: def __call__( self, x: symbolic.Symbolic - ) -> Union[Dict[str, Any], Tuple[Dict[str, Any], Callable[[], None]]]: + ) -> Union[Dict[str, Any], + Tuple[Dict[str, Any], Callable[[Any], None]]]: """Override __call__ to get rebind dict.""" return super().__call__(x, override_args=True) @@ -218,7 +219,7 @@ def _decorator(fn): cls = functor_decorator(fn) _PATCHER_REGISTRY.register(name or fn.__name__, typing.cast(Type[Patcher], cls)) - arg_specs = cls.signature.args + arg_specs = cls.__signature__.args if len(arg_specs) < 1: raise TypeError( 'Patcher function should have at least 1 argument ' @@ -337,7 +338,7 @@ def from_uri(uri: str) -> Patcher: """Create a Patcher object from a URI-like string.""" name, args, kwargs = parse_uri(uri) patcher_cls = typing.cast(Type[Any], _PATCHER_REGISTRY.get(name)) - args, kwargs = parse_args(patcher_cls.signature, args, kwargs) + args, kwargs = parse_args(patcher_cls.__signature__, args, kwargs) return patcher_cls(object_utils.MISSING_VALUE, *args, **kwargs) diff --git a/pyglove/core/symbolic/functor.py b/pyglove/core/symbolic/functor.py index 4bb61b5..0cdfb84 100644 --- a/pyglove/core/symbolic/functor.py +++ b/pyglove/core/symbolic/functor.py @@ -14,9 +14,10 @@ """Symbolic function (Functor).""" import abc +import contextlib import functools import inspect - +import threading import types import typing from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union @@ -64,29 +65,52 @@ def sum(a, b=1, *args, **kwargs): sum()(1, 2, 3, 4) # returns 10: a=1, b=2, *args=[3, 4] sum(c=4)(1, 2, 3) # returns 10: a=1, b=2, *args=[3], **kwargs={'c': 4} + + Or created by subclassing ``pg.Functor``:: + + class Sum(pg.Functor): + a: int + b: int = 1 + + def _call(self) -> int: + return self.a + self.b + + Usage on subclassed functors is the same as functors created from functions. """ # Allow assignment on symbolic attributes. allow_symbolic_assignment = True - # Do not infer symbolic fields from annotations, since functors are - # created from function definition which does not have class-level attributes. - infer_symbolic_fields_from_annotations = False - - # Functor's schema will be inferred from the function signature based on - # `pg.functor` or `pg.symbolize`. Therefore we do not infer the schema - # automatically during class creation. - auto_schema = False - - # Signature of this function. - signature: pg_typing.Signature + # Key for storing override members during call. + _TLS_OVERRIDE_MEMBERS_KEY = '__override_members__' # # Customizable class traits. # + @classmethod + @property + def is_subclassed_functor(cls) -> bool: + """Returns True if this class is a subclassed Functor.""" + return cls.auto_schema + @classmethod def _update_signatures_based_on_schema(cls): + # Update the return value of subclassed functors. + if cls.is_subclassed_functor: # pylint: disable=using-constant-test + private_call_signature = pg_typing.Signature.from_callable( + cls._call, auto_typing=True + ) + if ( + len(private_call_signature.args) > 1 + or private_call_signature.kwonlyargs + ): + raise TypeError( + '`_call` of a subclassed Functor should take no argument. ' + f'Encountered: {cls._call}.' + ) + cls.__schema__.metadata['returns'] = private_call_signature.return_value + # Update __init_ signature. init_signature = pg_typing.Signature.from_schema( cls.__schema__, @@ -97,10 +121,14 @@ def _update_signatures_based_on_schema(cls): pseudo_init = init_signature.make_function(['pass']) + # Save the original `Functor.__init__` before overriding it. + if not hasattr(cls, '__orig_init__'): + setattr(cls, '__orig_init__', cls.__init__) + @object_utils.explicit_method_override @functools.wraps(pseudo_init) def _init(self, *args, **kwargs): - Functor.__init__(self, *args, **kwargs) + self.__class__.__orig_init__(self, *args, **kwargs) setattr(cls, '__init__', _init) @@ -112,7 +140,7 @@ def _init(self, *args, **kwargs): qualname=cls.__qualname__, is_method=False, ) - setattr(cls, 'signature', call_signature) + setattr(cls, '__signature__', call_signature) def __new__(cls, *args, **kwargs): instance = object.__new__(cls) @@ -149,38 +177,40 @@ def __init__( _ = kwargs.pop('allow_partial', None) varargs = None - if len(args) > len(self.signature.args): - if self.signature.varargs: - varargs = list(args[len(self.signature.args):]) - args = args[:len(self.signature.args)] + signature = self.__signature__ + if len(args) > len(signature.args): + if signature.varargs: + varargs = list(args[len(signature.args) :]) + args = args[: len(signature.args)] else: - arg_phrase = object_utils.auto_plural( - len(self.signature.args), 'argument') + arg_phrase = object_utils.auto_plural(len(signature.args), 'argument') was_phrase = object_utils.auto_plural(len(args), 'was', 'were') raise TypeError( - f'{self.signature.id}() takes {len(self.signature.args)} ' - f'positional {arg_phrase} but {len(args)} {was_phrase} given.') + f'{signature.id}() takes {len(signature.args)} ' + f'positional {arg_phrase} but {len(args)} {was_phrase} given.' + ) bound_kwargs = dict() for i, v in enumerate(args): if pg_typing.MISSING_VALUE != v: - bound_kwargs[self.signature.args[i].name] = v + bound_kwargs[signature.args[i].name] = v if varargs is not None: - bound_kwargs[self.signature.varargs.name] = varargs + bound_kwargs[signature.varargs.name] = varargs for k, v in kwargs.items(): if pg_typing.MISSING_VALUE != v: if k in bound_kwargs: raise TypeError( - f'{self.signature.id}() got multiple values for keyword ' - f'argument {k!r}.') + f'{signature.id}() got multiple values for keyword ' + f'argument {k!r}.' + ) bound_kwargs[k] = v default_args = set() non_default_args = set(bound_kwargs) - for arg_spec in self.signature.named_args: + for arg_spec in signature.named_args: if not arg_spec.value_spec.has_default: continue arg_name = arg_spec.name @@ -190,8 +220,8 @@ def __init__( default_args.add(arg_name) non_default_args.discard(arg_name) - if self.signature.varargs and not varargs: - default_args.add(self.signature.varargs.name) + if signature.varargs and not varargs: + default_args.add(signature.varargs.name) super().__init__(allow_partial=True, root_path=root_path, @@ -203,6 +233,19 @@ def __init__( self._override_args = override_args self._ignore_extra_args = ignore_extra_args + # For subclassed Functor, we use thread-local storage for storing temporary + # member overrides from the arguments during functor call. + self._tls = threading.local() if self.is_subclassed_functor else None + + def _sym_inferred(self, key: str, **kwargs: Any) -> Any: + """Overrides method to allow member overrides during call.""" + if self._tls is not None: + overrides = getattr(self._tls, Functor._TLS_OVERRIDE_MEMBERS_KEY, {}) + v = overrides.get(key, pg_typing.MISSING_VALUE) + if pg_typing.MISSING_VALUE != v: + return overrides[key] + return super()._sym_inferred(key, **kwargs) + def _sym_clone(self, deep: bool, memo: Any = None) -> 'Functor': """Override to copy bound args.""" other = super()._sym_clone(deep, memo) @@ -239,7 +282,7 @@ def _on_change( def __delattr__(self, name: str) -> None: """Discard a previously bound argument and reset to its default value.""" del self._sym_attributes[name] - if self.signature.get_value_spec(name).has_default: + if self.__signature__.get_value_spec(name).has_default: self._default_args.add(name) self._specified_args.discard(name) self._non_default_args.discard(name) @@ -316,19 +359,60 @@ def __call__(self, *args, **kwargs) -> Any: Raises: TypeError: got multiple values for arguments or extra argument name. """ + args, kwargs = self._parse_call_time_overrides(*args, **kwargs) + signature = self.__signature__ + + if self.is_subclassed_functor: + for arg_spec, arg_value in zip(signature.args, args): + kwargs[arg_spec.name] = arg_value + + # Temporarily override members with argument values from the call. + with self._apply_call_time_overrides_to_members(**kwargs): + return_value = self._call() + else: + return_value = self._call(*args, **kwargs) + + # Return value check. + if ( + signature.return_value + and flags.is_type_check_enabled() + and return_value != pg_typing.MISSING_VALUE + ): + return_value = signature.return_value.apply( + return_value, root_path=self.sym_path + 'returns' + ) + if flags.is_tracking_origin() and isinstance(return_value, base.Symbolic): + return_value.sym_setorigin(self, 'return') + return return_value + + @contextlib.contextmanager + def _apply_call_time_overrides_to_members(self, **kwargs): + """Overrides member values within the scope.""" + assert self._tls is not None + setattr(self._tls, Functor._TLS_OVERRIDE_MEMBERS_KEY, kwargs) + try: + yield + finally: + delattr(self._tls, Functor._TLS_OVERRIDE_MEMBERS_KEY) + + def _parse_call_time_overrides( + self, *args, **kwargs + ) -> Tuple[List[Any], Dict[str, Any]]: + """Parses positional and keyword arguments from call-time overrides.""" override_args = kwargs.pop('override_args', self._override_args) ignore_extra_args = kwargs.pop('ignore_extra_args', self._ignore_extra_args) - if len(args) > len(self.signature.args) and not self.signature.has_varargs: + signature = self.__signature__ + if len(args) > len(signature.args) and not signature.has_varargs: if ignore_extra_args: - args = args[:len(self.signature.args)] + args = args[: len(signature.args)] else: - arg_phrase = object_utils.auto_plural( - len(self.signature.args), 'argument') + arg_phrase = object_utils.auto_plural(len(signature.args), 'argument') was_phrase = object_utils.auto_plural(len(args), 'was', 'were') raise TypeError( - f'{self.signature.id}() takes {len(self.signature.args)} ' - f'positional {arg_phrase} but {len(args)} {was_phrase} given.') + f'{signature.id}() takes {len(signature.args)} ' + f'positional {arg_phrase} but {len(args)} {was_phrase} given.' + ) keyword_args = { k: v for k, v in self._sym_attributes.items() @@ -338,27 +422,29 @@ def __call__(self, *args, **kwargs) -> Any: # Work out varargs when positional arguments are provided. varargs = None - if self.signature.has_varargs: - varargs = list(args[len(self.signature.args):]) + if signature.has_varargs: + varargs = list(args[len(signature.args) :]) if flags.is_type_check_enabled(): varargs = [ - self.signature.varargs.value_spec.apply( - v, root_path=self.sym_path + self.signature.varargs.name) + signature.varargs.value_spec.apply( + v, root_path=self.sym_path + signature.varargs.name + ) for v in varargs ] - args = args[:len(self.signature.args)] + args = args[: len(signature.args)] # Convert positional arguments to keyword arguments so we can map them back # later. for i in range(len(args)): - arg_spec = self.signature.args[i] + arg_spec = signature.args[i] arg_name = arg_spec.name if arg_name in self._specified_args: if not override_args: raise TypeError( - f'{self.signature.id}() got new value for argument {arg_name!r} ' - f'from position {i}, but \'override_args\' is set to False. ' - f'Old value: {keyword_args[arg_name]!r}, new value: {args[i]!r}.') + f'{signature.id}() got new value for argument {arg_name!r} ' + f"from position {i}, but 'override_args' is set to False. " + f'Old value: {keyword_args[arg_name]!r}, new value: {args[i]!r}.' + ) arg_value = args[i] if flags.is_type_check_enabled(): arg_value = arg_spec.value_spec.apply( @@ -369,25 +455,26 @@ def __call__(self, *args, **kwargs) -> Any: if arg_name in self._specified_args: if not override_args: raise TypeError( - f'{self.signature.id}() got new value for argument {arg_name!r} ' - f'from keyword argument, while \'override_args\' is set to ' + f'{signature.id}() got new value for argument {arg_name!r} ' + "from keyword argument, while 'override_args' is set to " f'False. Old value: {keyword_args[arg_name]!r}, ' - f'new value: {arg_value!r}.') - arg_spec = self.signature.get_value_spec(arg_name) + f'new value: {arg_value!r}.' + ) + arg_spec = signature.get_value_spec(arg_name) if arg_spec and flags.is_type_check_enabled(): arg_value = arg_spec.apply( arg_value, root_path=self.sym_path + arg_name) keyword_args[arg_name] = arg_value elif not ignore_extra_args: raise TypeError( - f'{self.signature.id}() got an unexpected ' - f'keyword argument {arg_name!r}.') + f'{signature.id}() got an unexpected keyword argument {arg_name!r}.' + ) # Use positional arguments if possible. This allows us to handle varargs # with simplicity. list_args = [] missing_required_arg_names = [] - for arg in self.signature.args: + for arg in signature.args: if arg.name in keyword_args: list_args.append(keyword_args[arg.name]) del keyword_args[arg.name] @@ -401,26 +488,16 @@ def __call__(self, *args, **kwargs) -> Any: len(missing_required_arg_names), 'argument') args_str = object_utils.comma_delimited_str(missing_required_arg_names) raise TypeError( - f'{self.signature.id}() missing {len(missing_required_arg_names)} ' - f'required positional {arg_phrase}: {args_str}.') + f'{signature.id}() missing {len(missing_required_arg_names)} ' + f'required positional {arg_phrase}: {args_str}.' + ) - if self.signature.has_varargs: - prebound_varargs = keyword_args.pop(self.signature.varargs.name, None) + if signature.has_varargs: + prebound_varargs = keyword_args.pop(signature.varargs.name, None) varargs = varargs or prebound_varargs if varargs: list_args.extend(varargs) - - return_value = self._call(*list_args, **keyword_args) - if ( - self.signature.return_value - and flags.is_type_check_enabled() - and return_value != pg_typing.MISSING_VALUE - ): - return_value = self.signature.return_value.apply( - return_value, root_path=self.sym_path + 'returns') - if flags.is_tracking_origin() and isinstance(return_value, base.Symbolic): - return_value.sym_setorigin(self, 'return') - return return_value + return list_args, keyword_args def functor( @@ -536,7 +613,7 @@ def functor_class( @pg.functor([('c', pg.typing.Int(min_value=0), 'Arg c')]) def foo(a, b, c=1, **kwargs): return a + b + c + sum(kwargs.values()) - + assert foo.schema.fields() == [ pg.typing.Field('a', pg.Any(), 'Argument a'.), pg.typing.Field('b', pg.Any(), 'Argument b'.), @@ -584,6 +661,16 @@ def foo(a, b, c=1, **kwargs): class _Functor(base_class or Functor): """Functor wrapper for input function.""" + # The schema for function-based Functor will be inferred from the function + # signature. Therefore we do not infer the schema automatically during class + # creation. + auto_schema = False + + # Do not infer symbolic fields from annotations, since this functor is + # created from function definition which does not have class-level + # attributes. + infer_symbolic_fields_from_annotations = True + def _call(self, *args, **kwargs): return func(*args, **kwargs) diff --git a/pyglove/core/symbolic/functor_test.py b/pyglove/core/symbolic/functor_test.py index b32bd67..03f323f 100644 --- a/pyglove/core/symbolic/functor_test.py +++ b/pyglove/core/symbolic/functor_test.py @@ -15,6 +15,7 @@ import inspect import io +import typing import unittest from pyglove.core import object_utils @@ -104,20 +105,25 @@ def f(a, b, *args, c=0, **kwargs): pg_typing.Field(pg_typing.StrKey(), pg_typing.Any()), ], ) - self.assertEqual(f.signature.args, [ - pg_typing.Argument('a', pg_typing.Any()), - pg_typing.Argument('b', pg_typing.Any()) - ]) self.assertEqual( - f.signature.varargs, - pg_typing.Argument('args', pg_typing.Any())) + f.__signature__.args, + [ + pg_typing.Argument('a', pg_typing.Any()), + pg_typing.Argument('b', pg_typing.Any()), + ], + ) + self.assertEqual( + f.__signature__.varargs, pg_typing.Argument('args', pg_typing.Any()) + ) self.assertEqual( - f.signature.varkw, pg_typing.Argument('kwargs', pg_typing.Any())) + f.__signature__.varkw, pg_typing.Argument('kwargs', pg_typing.Any()) + ) self.assertEqual( - f.signature.kwonlyargs, - [pg_typing.Argument('c', pg_typing.Any(default=0))]) - self.assertIsNone(f.signature.return_value, None) - self.assertTrue(f.signature.has_varargs) + f.__signature__.kwonlyargs, + [pg_typing.Argument('c', pg_typing.Any(default=0))], + ) + self.assertIsNone(f.__signature__.return_value, None) + self.assertTrue(f.__signature__.has_varargs) self.assertIsInstance(f.partial(), Functor) self.assertEqual(f.partial()(1, 2), 3) self.assertEqual(f.partial(b=1)(1), 2) @@ -139,10 +145,13 @@ def test_full_typing(self): def f(a=1, b=2): return a + b - self.assertEqual(f.signature.args, [ - pg_typing.Argument('a', pg_typing.Int(default=1)), - pg_typing.Argument('b', pg_typing.Int(default=2)), - ]) + self.assertEqual( + f.__signature__.args, + [ + pg_typing.Argument('a', pg_typing.Int(default=1)), + pg_typing.Argument('b', pg_typing.Int(default=2)), + ], + ) self.assertEqual( list(f.__schema__.values()), [ @@ -150,9 +159,9 @@ def f(a=1, b=2): pg_typing.Field('b', pg_typing.Int(default=2)), ], ) - self.assertEqual(f.signature.return_value, pg_typing.Int()) - self.assertFalse(f.signature.has_varargs) - self.assertFalse(f.signature.has_varkw) + self.assertEqual(f.__signature__.return_value, pg_typing.Int()) + self.assertFalse(f.__signature__.has_varargs) + self.assertFalse(f.__signature__.has_varkw) self.assertEqual(f.partial()(), 3) self.assertEqual(f.partial(a=2)(b=2), 4) self.assertEqual(f.partial(a=3, b=2)(), 5) @@ -182,18 +191,19 @@ def f(a: int, *args, b: int = 2, **kwargs) -> int: ], ) self.assertEqual( - f.signature.args, - [pg_typing.Argument('a', pg_typing.Int())]) + f.__signature__.args, [pg_typing.Argument('a', pg_typing.Int())] + ) self.assertEqual( - f.signature.varargs, - pg_typing.Argument('args', pg_typing.Any())) + f.__signature__.varargs, pg_typing.Argument('args', pg_typing.Any()) + ) self.assertEqual( - f.signature.kwonlyargs, - [pg_typing.Argument('b', pg_typing.Int(default=2))]) + f.__signature__.kwonlyargs, + [pg_typing.Argument('b', pg_typing.Int(default=2))], + ) self.assertEqual( - f.signature.varkw, - pg_typing.Argument('kwargs', pg_typing.Any())) - self.assertEqual(f.signature.return_value, pg_typing.Int()) + f.__signature__.varkw, pg_typing.Argument('kwargs', pg_typing.Any()) + ) + self.assertEqual(f.__signature__.return_value, pg_typing.Int()) # Test runtime value check. with self.assertRaisesRegex(TypeError, 'Expect .* but encountered .*'): @@ -224,9 +234,9 @@ def f(a=1, b=2): pg_typing.Field('b', pg_typing.Int(default=2), 'another integer.'), ], ) - self.assertEqual(f.signature.return_value, pg_typing.Int()) - self.assertFalse(f.signature.has_varargs) - self.assertFalse(f.signature.has_varkw) + self.assertEqual(f.__signature__.return_value, pg_typing.Int()) + self.assertFalse(f.__signature__.has_varargs) + self.assertFalse(f.__signature__.has_varkw) self.assertEqual(f.partial()(), 3) self.assertEqual(f.partial(a=2)(b=2), 4) self.assertEqual(f.partial(a=3, b=2)(), 5) @@ -240,13 +250,16 @@ def test_partial_typing(self): def f(a, b=1, c=1): return a + b + c - self.assertEqual(f.signature.args, [ - pg_typing.Argument('a', pg_typing.Int()), - pg_typing.Argument('b', pg_typing.Any(default=1)), - pg_typing.Argument('c', pg_typing.Int(default=1)), - ]) - self.assertFalse(f.signature.has_varargs) - self.assertFalse(f.signature.has_varkw) + self.assertEqual( + f.__signature__.args, + [ + pg_typing.Argument('a', pg_typing.Int()), + pg_typing.Argument('b', pg_typing.Any(default=1)), + pg_typing.Argument('c', pg_typing.Int(default=1)), + ], + ) + self.assertFalse(f.__signature__.has_varargs) + self.assertFalse(f.__signature__.has_varkw) self.assertEqual( list(f.__schema__.values()), [ @@ -264,6 +277,52 @@ def f(a, b=1, c=1): TypeError, 'missing 1 required positional argument'): f.partial()() + def test_subclassed_functor_class(self): + class Foo(Functor): + x: int + y: int + + __kwargs__: typing.Any + + def _call(self) -> int: + return self.x + self.y + + print(Foo.__signature__) + self.assertEqual(Foo.__signature__.name, '__call__') + self.assertEqual(Foo.__signature__.qualname, Foo.__qualname__) + self.assertEqual(Foo.__signature__.module_name, Foo.__module__) + self.assertEqual( + Foo.__signature__.args, + [ + pg_typing.Argument('x', pg_typing.Int()), + pg_typing.Argument('y', pg_typing.Int()), + ], + ) + self.assertEqual( + Foo.__signature__.varkw, + pg_typing.Argument('kwargs', pg_typing.Any(annotation=typing.Any)), + ) + self.assertEqual(Foo.__signature__.return_value, pg_typing.Int()) + + foo = Foo(1, 2) + self.assertEqual(foo(), 3) + self.assertEqual(foo(2, override_args=True), 4) + self.assertEqual(foo.x, 1) + self.assertEqual(foo(y=3, override_args=True), 4) + + # Partially bound. + foo = Foo(y=2) + self.assertEqual(foo(1), 3) + + # Bad subclassed Functor. + + with self.assertRaisesRegex( + TypeError, '`_call` of a subclassed Functor should take no argument'): + + class Bar(Functor): # pylint: disable=unused-variable + def _call(self, x: int) -> int: + pass + def test_runtime_type_check(self): @pg_functor([ ('a', pg_typing.Int(min_value=0)), @@ -363,8 +422,9 @@ def test_as_functor(self): f = pg_as_functor(lambda x: x) self.assertIsInstance(f, Functor) self.assertEqual( - f.signature.args, [pg_typing.Argument('x', pg_typing.Any())]) - self.assertIsNone(f.signature.return_value) + f.__signature__.args, [pg_typing.Argument('x', pg_typing.Any())] + ) + self.assertIsNone(f.__signature__.return_value) self.assertEqual(f(1), 1) def test_bad_definition(self): diff --git a/pyglove/core/symbolic/symbolize_test.py b/pyglove/core/symbolic/symbolize_test.py index 39cc857..801f8d6 100644 --- a/pyglove/core/symbolic/symbolize_test.py +++ b/pyglove/core/symbolic/symbolize_test.py @@ -70,11 +70,14 @@ def test_symbolize_a_function_by_decorator_with_typing(self): def f(x, y): del x, y self.assertTrue(issubclass(f, Functor)) - self.assertEqual(f.signature.args, [ - (pg_typing.Argument('x', pg_typing.Int())), - (pg_typing.Argument('y', pg_typing.Str())) - ]) - self.assertEqual(f.signature.return_value, pg_typing.Int()) + self.assertEqual( + f.__signature__.args, + [ + (pg_typing.Argument('x', pg_typing.Int())), + (pg_typing.Argument('y', pg_typing.Str())), + ], + ) + self.assertEqual(f.__signature__.return_value, pg_typing.Int()) def test_symbolize_with_serialization_key(self): @pg_symbolize(serialization_key='BAR', additional_keys=['RRR']) diff --git a/pyglove/core/typing/callable_signature.py b/pyglove/core/typing/callable_signature.py index d6ea6e1..0f230a7 100644 --- a/pyglove/core/typing/callable_signature.py +++ b/pyglove/core/typing/callable_signature.py @@ -256,8 +256,8 @@ def from_callable( raise TypeError(f'{callable_object!r} is not callable.') if isinstance(callable_object, object_utils.Functor): - assert callable_object.signature is not None - return callable_object.signature + assert callable_object.__signature__ is not None + return callable_object.__signature__ func = callable_object if not inspect.isroutine(func): diff --git a/pyglove/core/typing/callable_signature_test.py b/pyglove/core/typing/callable_signature_test.py index ef6cda9..300befe 100644 --- a/pyglove/core/typing/callable_signature_test.py +++ b/pyglove/core/typing/callable_signature_test.py @@ -299,6 +299,13 @@ def test_bad_cases(self): class_schema.Schema([], metadata=dict(init_arg_list=['a'])), '__main__', 'foo') + class Foo: + __call__ = 1 + + with self.assertRaisesRegex( + TypeError, '.*__call__ is not a method'): + callable_signature.Signature.from_callable(Foo) + if __name__ == '__main__': unittest.main() diff --git a/pyglove/core/typing/value_specs_test.py b/pyglove/core/typing/value_specs_test.py index f33b9b7..7dd8cdc 100644 --- a/pyglove/core/typing/value_specs_test.py +++ b/pyglove/core/typing/value_specs_test.py @@ -2432,7 +2432,7 @@ def test_apply_on_functor(self): class FunctorWithRegularArgs(object_utils.Functor): - signature = callable_signature.Signature( + __signature__ = callable_signature.Signature( callable_type=callable_signature.CallableType.FUNCTION, name='foo', module_name='__main__', @@ -2478,7 +2478,7 @@ def test_apply_on_functor_with_varargs(self): class FunctorWithVarArgs(object_utils.Functor): - signature = callable_signature.Signature( + __signature__ = callable_signature.Signature( callable_type=callable_signature.CallableType.FUNCTION, name='foo', module_name='__main__',