Skip to content

Commit

Permalink
pg.Object: Allow member methods to be used as the default value for…
Browse files Browse the repository at this point in the history
… callable symbolic attributes.

This allows convenient provision of callable symbolic attributes.

PiperOrigin-RevId: 539811905
  • Loading branch information
daiyip authored and pyglove authors committed Jun 13, 2023
1 parent e6a9066 commit 9b2bd3b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
8 changes: 7 additions & 1 deletion pyglove/core/symbolic/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,20 @@ def _infer_fields_from_annotations(cls) -> List[pg_typing.Field]:
return fields

def _update_default_values_from_class_attributes(cls):
"""Updates the symbolic attribute defaults from class attributes."""
for field in cls.schema.fields.values():
if isinstance(field.key, pg_typing.ConstStrKey):
attr_name = field.key.text
attr_value = getattr(cls, attr_name, pg_typing.MISSING_VALUE)
if (
attr_value != pg_typing.MISSING_VALUE
and not isinstance(attr_value, property)
and not inspect.isfunction(attr_value)
and (
# This allows class methods to be used as callable
# symbolic attributes.
not inspect.isfunction(attr_value)
or isinstance(field.value, pg_typing.Callable)
)
):
field.value.set_default(attr_value)

Expand Down
22 changes: 22 additions & 0 deletions pyglove/core/symbolic/object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,28 @@ def x(self):
pass
self.assertEqual(F().sym_init_args.x, 3)

@pg_members([
('x', pg_typing.Callable([pg_typing.Int()])),
('y', pg_typing.Int())
])
class G(Object):
pass

class H(G):

# Member method as the default value for callable symbolic attribute.
def x(self, v):
return self.sym_init_args.y + v * 2

# Member method will not override non-callable symbolic attribute.
def y(self):
return self.sym_init_args.y * 2

h = H(y=1)
self.assertEqual(h.x(1), 3)
self.assertEqual(h.y(), 2)
self.assertEqual(h.sym_init_args.x(h, 1), 3)

def test_override_symbolic_attribute_with_property(self):

@pg_members([
Expand Down

0 comments on commit 9b2bd3b

Please sign in to comment.