Skip to content

Commit

Permalink
Introducing pg.use_init_args decorator of pg.Object subclasses.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 549457752
  • Loading branch information
daiyip authored and pyglove authors committed Jul 19, 2023
1 parent e31fa99 commit 627990e
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyglove/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@
# Decorator for declaring symbolic. members for `pg.Object`.
members = symbolic.members

# Decorator for updating the __init__ signature of `pg.Object`.
use_init_args = symbolic.use_init_args

#
# Methods for making symbolic types.
#
Expand Down
1 change: 1 addition & 0 deletions pyglove/core/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from pyglove.core.symbolic.object import ObjectMeta
from pyglove.core.symbolic.object import Object
from pyglove.core.symbolic.object import members
from pyglove.core.symbolic.object import use_init_args

from pyglove.core.symbolic.functor import Functor
from pyglove.core.symbolic.functor import functor
Expand Down
34 changes: 34 additions & 0 deletions pyglove/core/symbolic/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,37 @@ def _decorator(cls):
)
return cls
return typing.cast(pg_typing.Decorator, _decorator)


def use_init_args(init_arg_list: Sequence[str]) -> pg_typing.Decorator:
"""Decorator for updating the `__init__` signature of a `pg.Object` subclass.
Examples::
@pg.use_init_args(['x', 'y', '*z'])
class Foo(pg.Object):
y: int
x: str
z: list[int]
f = Foo('abc', 1, 2, 3)
assert f.x == 'abc'
assert f.y == 1
assert f.z == [2, 3]
Args:
init_arg_list: A sequence of attribute names that will be used as the
positional arguments of `__init__`. The last element could be the name of
a list-type attribute, indicating it's used as `*args`. Keyword-only
arguments are not needed to be present in this list, which will be figured
out automatically based on class' schema.
Returns:
a decorator function that updates the `__init__` signature.
"""
def _decorator(cls):
schema_utils.update_schema(
cls, [], extend=True, init_arg_list=init_arg_list
)
return cls
return typing.cast(pg_typing.Decorator, _decorator)
18 changes: 18 additions & 0 deletions pyglove/core/symbolic/object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pyglove.core.symbolic.list import List
from pyglove.core.symbolic.object import members as pg_members
from pyglove.core.symbolic.object import Object
from pyglove.core.symbolic.object import use_init_args as pg_use_init_args
from pyglove.core.symbolic.origin import Origin
from pyglove.core.symbolic.pure_symbolic import NonDeterministic
from pyglove.core.symbolic.pure_symbolic import PureSymbolic
Expand Down Expand Up @@ -1693,6 +1694,23 @@ class B(Object):
TypeError, 'got unexpected keyword argument: \'z\''):
_ = B(1, z=2)

def test_use_init_args(self):

@pg_use_init_args(['x', 'y', '*z'])
class A(Object):
y: int
x: str
z: list[str]
p: str
q: int

a = A('foo', 1, 'a', 'b', p='bar', q=2)
self.assertEqual(a.x, 'foo')
self.assertEqual(a.y, 1)
self.assertEqual(a.z, ['a', 'b'])
self.assertEqual(a.p, 'bar')
self.assertEqual(a.q, 2)

def test_serialization_key(self):

@pg_members([
Expand Down

0 comments on commit 627990e

Please sign in to comment.