Skip to content

Commit

Permalink
pg.Object to support symbolic field inference from annotations.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 532525691
  • Loading branch information
daiyip authored and pyglove authors committed May 16, 2023
1 parent 85e8e31 commit 27a7392
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 46 deletions.
4 changes: 4 additions & 0 deletions pyglove/core/symbolic/class_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ class SubclassedWrapper(
# `user_cls` will be used.
use_symbolic_comparison = use_symbolic_comp

# Do not infer symbolic fields from annotations. This is because that
# symbolic fields are inspected from the `__init__`` signature.
infer_symbolic_fields_from_annotations = False

cls = SubclassedWrapper
cls.__name__ = class_name or user_cls.__name__
cls.__module__ = module_name or user_cls.__module__
Expand Down
4 changes: 4 additions & 0 deletions pyglove/core/symbolic/functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def sum(a, b=1, *args, **kwargs):
# Allow assignment on symbolic attributes.
allow_symbolic_assignment = True

# Do not infer symbolic fields from annotations, since functors are
# created from function definition which does not have class-level attributes.
infer_symbolic_fields_from_annotations = False

# Signature of this function.
signature: pg_typing.Signature

Expand Down
49 changes: 46 additions & 3 deletions pyglove/core/symbolic/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import abc
import functools

import inspect
import typing
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -137,6 +137,21 @@ class Bar(Foo)
# `sym_ne` for `__ne__`, and `sym_hash` for `__hash__`.
use_symbolic_comparison = True

# If True, symbolic fields will be inferred from class annotations.
# It's an alternative way of declaring symbolic fields other than
# `pg.members`.
#
# e.g.::
#
# class A(pg.Object):
# x: int
# y: str
#
# Please note that class attributes in UPPER_CASE or starting with '_' will
# not be considered as symbolic fields even if they have annotations.

infer_symbolic_fields_from_annotations = True

@classmethod
def __init_subclass__(cls):
super().__init_subclass__()
Expand All @@ -153,7 +168,7 @@ def __init_subclass__(cls):

cls_schema = schema_utils.formalize_schema(
pg_typing.create_schema(
maybe_field_list=[],
maybe_field_list=cls._infer_fields_from_annotations(),
name=cls.type_name,
base_schema_list=base_schema_list,
allow_nonconst_keys=True,
Expand All @@ -165,6 +180,31 @@ def __init_subclass__(cls):
cls._update_init_signature_based_on_schema()
cls._generate_sym_attributes_if_enabled()

@classmethod
def _infer_fields_from_annotations(cls) -> List[pg_typing.Field]:
"""Infers symbolic fields from class annotations."""
if not cls.infer_symbolic_fields_from_annotations:
return []

# NOTE(daiyip): refer to https://docs.python.org/3/howto/annotations.html.
if hasattr(inspect, 'get_annotations'):
annotations = inspect.get_annotations(cls)
else:
annotations = cls.__dict__.get('__annotations__', {})

fields = []
for attr_name, attr_annotation in annotations.items():
# We consider class-level attributes in upper cases non-fields even
# when they appear with annotations.
if attr_name.isupper() or attr_name.startswith('_'):
continue
field = pg_typing.create_field((attr_name, attr_annotation))
attr_value = getattr(cls, attr_name, pg_typing.MISSING_VALUE)
if attr_value != pg_typing.MISSING_VALUE:
field.value.set_default(attr_value)
fields.append(field)
return fields

@classmethod
def _update_init_signature_based_on_schema(cls):
"""Updates the signature of `__init__` if needed."""
Expand Down Expand Up @@ -206,7 +246,10 @@ def _create_sym_attribute(attr_name, field):
for key, field in cls.schema.fields.items():
if isinstance(key, pg_typing.ConstStrKey):
attr_name = str(key)
if not hasattr(cls, attr_name):
attr_value = getattr(cls, attr_name, pg_typing.MISSING_VALUE)
if (attr_value == pg_typing.MISSING_VALUE
or (not inspect.isfunction(attr_value)
and not isinstance(attr_value, property))):
setattr(cls, attr_name, _create_sym_attribute(attr_name, field))

@classmethod
Expand Down
87 changes: 87 additions & 0 deletions pyglove/core/symbolic/object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,93 @@ def __init__(self, x): # pylint: disable=super-init-not-called
ValueError, '.* should call `super.*__init__`'):
_ = B(1)

def test_infer_symbolic_fields_from_annotations(self):

class A(Object):
x: int
y: float = 0.0
z = 2
# P is upper-case, thus will not be treated as symbolic field.
P: int = 1
# _q starts with _, which will not be treated as symbolic field either.
_q: int = 2

self.assertEqual(
list(A.schema.fields.keys()),
['x', 'y'])

a = A(1)
self.assertEqual(a.x, 1)
self.assertEqual(a.y, 0.0)

a = A(2, y=1.0)
self.assertEqual(a.x, 2)
self.assertEqual(a.y, 1.0)

class B(A):
p: str = 'foo'
q: typing.Any = None

self.assertEqual(
list(B.schema.fields.keys()),
['x', 'y', 'p', 'q'],
)
b = B(1, q=2)
self.assertEqual(b.x, 1)
self.assertEqual(b.y, 0.0)
self.assertEqual(b.p, 'foo')
self.assertEqual(b.q, 2)

@pg_members([
('k', pg_typing.Int())
])
class C(B):
# Override the default value of 'y'.
y: float = 1.0

c = C(1, q=2, k=3)
self.assertEqual(c.x, 1)
self.assertEqual(c.y, 1.0)
self.assertEqual(c.q, 2)
self.assertEqual(c.k, 3)

@pg_members([
('e', pg_typing.Int())
])
class D(C):
f: int = 5

d = D(1, q=2, k=3, e=4)
self.assertEqual(d.x, 1)
self.assertEqual(d.y, 1.0)
self.assertEqual(d.q, 2)
self.assertEqual(d.k, 3)
self.assertEqual(d.e, 4)
self.assertEqual(d.f, 5)

def test_override_symbolic_attribute_with_property(self):

@pg_members([
('x', pg_typing.Int()),
('y', pg_typing.Int()),
('z', pg_typing.Int()),
])
class A(Object):

@property
def x(self):
return self.sym_init_args.x + 1

def z(self):
return self.sym_init_args.z + 2

a = A(1, 2, 3)
self.assertEqual(a.x, 2)
self.assertEqual(a.sym_init_args.x, 1)
self.assertEqual(a.y, 2)
self.assertEqual(a.z(), 5)
self.assertEqual(a.sym_init_args.z, 3)

def test_runtime_type_check(self):

@pg_members([
Expand Down
1 change: 1 addition & 0 deletions pyglove/core/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class Foo(pg.Object):
from pyglove.core.typing.class_schema import ValueSpec
from pyglove.core.typing.class_schema import Field
from pyglove.core.typing.class_schema import Schema
from pyglove.core.typing.class_schema import create_field
from pyglove.core.typing.class_schema import create_schema
from pyglove.core.typing.class_schema import ForwardRef

Expand Down
112 changes: 69 additions & 43 deletions pyglove/core/typing/class_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,74 @@ def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)


def create_field(
maybe_field: Union[Field, Tuple], # pylint: disable=g-bare-generic
auto_typing: bool = True,
accept_value_as_annotation: bool = True
) -> Field:
"""Creates ``Field`` from its equivalence.
Args:
maybe_field: a ``Field`` object or its equivalence, which is a tuple of
2 - 4 elements:
`(<key>, <value>, [description], [metadata])`.
`key` can be a KeySpec subclass object or string. `value` can be a
ValueSpec subclass object or equivalent value. (see
``ValueSpec.from_value`` method). `description` is the description of this
field. It can be optional when this field overrides the default value of a
field defined in parent schema. `metadata` is an optional field which is a
dict of user objects.
auto_typing: If True, infer value spec from Python annotations. Otherwise,
``pg.typing.Any()`` will be used.
accept_value_as_annotation: If True, allow default values to be used as
annotations when creating the value spec.
Returns:
A ``Field`` object.
"""
if isinstance(maybe_field, Field):
return maybe_field

if not isinstance(maybe_field, tuple):
raise TypeError(
f'Field definition should be tuples with 2 to 4 elements. '
f'Encountered: {maybe_field}.')

if len(maybe_field) == 4:
maybe_key_spec, maybe_value_spec, description, field_metadata = maybe_field
elif len(maybe_field) == 3:
maybe_key_spec, maybe_value_spec, description = maybe_field
field_metadata = {}
elif len(maybe_field) == 2:
maybe_key_spec, maybe_value_spec = maybe_field
description = None
field_metadata = {}
else:
raise TypeError(
f'Field definition should be tuples with 2 to 4 elements. '
f'Encountered: {maybe_field}.')
key = None
if isinstance(maybe_key_spec, (str, KeySpec)):
key = maybe_key_spec
else:
raise TypeError(
f'The 1st element of field definition should be of '
f'<class \'str\'> or KeySpec. Encountered: {maybe_key_spec}.')
value = ValueSpec.from_annotation(
maybe_value_spec,
auto_typing=auto_typing,
accept_value_as_annotation=accept_value_as_annotation)
if (description is not None and
not isinstance(description, str)):
raise TypeError(f'Description (the 3rd element) of field definition '
f'should be text type. Encountered: {description}')
if not isinstance(field_metadata, dict):
raise TypeError(f'Metadata (the 4th element) of field definition '
f'should be a dict of objects. '
f'Encountered: {field_metadata}')
return Field(key, value, description, field_metadata)


def create_schema(
maybe_field_list: List[Union[Field, Tuple]], # pylint: disable=g-bare-generic
name: Optional[str] = None,
Expand Down Expand Up @@ -1247,50 +1315,8 @@ def create_schema(
raise TypeError(f'Metadata of schema should be a dict. '
f'Encountered: {metadata}.')

fields = []
for maybe_field in maybe_field_list:
if isinstance(maybe_field, Field):
fields.append(maybe_field)
continue
if not isinstance(maybe_field, tuple):
raise TypeError(
f'Field definition should be tuples with 2 to 4 elements. '
f'Encountered: {maybe_field}.')

if len(maybe_field) == 4:
(maybe_key_spec, maybe_value_spec, description,
field_metadata) = maybe_field
elif len(maybe_field) == 3:
maybe_key_spec, maybe_value_spec, description = maybe_field
field_metadata = {}
elif len(maybe_field) == 2:
maybe_key_spec, maybe_value_spec = maybe_field
description = None
field_metadata = {}
else:
raise TypeError(
f'Field definition should be tuples with 2 to 4 elements. '
f'Encountered: {maybe_field}.')
key = None
if isinstance(maybe_key_spec, (str, KeySpec)):
key = maybe_key_spec
else:
raise TypeError(
f'The 1st element of field definition should be of '
f'<class \'str\'> or KeySpec. Encountered: {maybe_key_spec}.')
value = ValueSpec.from_annotation(
maybe_value_spec, True, accept_value_as_annotation=True)
if (description is not None and
not isinstance(description, str)):
raise TypeError(f'Description (the 3rd element) of field definition '
f'should be text type. Encountered: {description}')
if not isinstance(field_metadata, dict):
raise TypeError(f'Metadata (the 4th element) of field definition '
f'should be a dict of objects. '
f'Encountered: {field_metadata}')
fields.append(Field(key, value, description, field_metadata))
return Schema(
fields=fields,
fields=[create_field(maybe_field) for maybe_field in maybe_field_list],
name=name,
base_schema_list=base_schema_list,
allow_nonconst_keys=allow_nonconst_keys,
Expand Down

0 comments on commit 27a7392

Please sign in to comment.