Skip to content

Commit

Permalink
Fix init_arg_list inference: subclasses with new fields should not …
Browse files Browse the repository at this point in the history
…inherit base class `init_arg_list`.

PiperOrigin-RevId: 533608676
  • Loading branch information
daiyip authored and pyglove authors committed May 20, 2023
1 parent 8141116 commit 25e4c23
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
19 changes: 16 additions & 3 deletions pyglove/core/symbolic/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def serialization_key(cls) -> str:
@property
def init_arg_list(cls) -> List[str]:
"""Gets __init__ positional argument list."""
return cls.schema.metadata['init_arg_list']
return typing.cast(List[str], cls.schema.metadata['init_arg_list'])

def apply_schema(cls, schema: pg_typing.Schema) -> None:
"""Applies a schema to a symbolic class.
Expand Down Expand Up @@ -320,14 +320,27 @@ def __init_subclass__(cls):
if isinstance(base_schema, pg_typing.Schema):
base_schema_list.append(base_schema)

new_fields = cls._infer_fields_from_annotations()
cls_schema = schema_utils.formalize_schema(
pg_typing.create_schema(
maybe_field_list=cls._infer_fields_from_annotations(),
maybe_field_list=new_fields,
name=cls.type_name,
base_schema_list=base_schema_list,
allow_nonconst_keys=True,
metadata={}))
metadata={},
)
)

# NOTE(daiyip): When new fields are added through class attributes.
# We invalidate `init_arg_list` so PyGlove could recompute it based
# on its schema during `apply_schema`. Otherwise, we inherit the
# `init_arg_list` from the base class.
# TODO(daiyip): detect new fields based on the differences from the base
# schema.
if new_fields:
cls_schema.metadata['init_arg_list'] = None
cls.apply_schema(cls_schema)

setattr(cls, '__serialization_key__', cls.type_name)

@classmethod
Expand Down
26 changes: 24 additions & 2 deletions pyglove/core/symbolic/object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,12 @@ def __init__(self, x): # pylint: disable=super-init-not-called

def test_symbolic_fields_from_annotations(self):

class A(Object):
class X(Object):
pass

self.assertEqual(X.init_arg_list, [])

class A(X):
x: int
y: typing.Annotated[float, 'field y'] = 0.0
z = 2
Expand All @@ -226,6 +231,7 @@ class A(Object):
# _q starts with _, which will not be treated as symbolic field either.
_q: int = 2

self.assertEqual(A.init_arg_list, ['x', 'y'])
self.assertEqual(
list(A.schema.fields.keys()),
['x', 'y'])
Expand All @@ -243,6 +249,7 @@ class B(A):
p: str = 'foo'
q: typing.Any = None

self.assertEqual(B.init_arg_list, ['x', 'y', 'p', 'q'])
self.assertEqual(
list(B.schema.fields.keys()),
['x', 'y', 'p', 'q'],
Expand All @@ -260,6 +267,12 @@ class C(B):
# Override the default value of 'y'.
y: float = 1.0

self.assertEqual(
list(C.schema.fields.keys()),
['x', 'y', 'p', 'q', 'k'],
)
self.assertEqual(C.init_arg_list, ['x', 'y', 'p', 'q', 'k'])

c = C(1, q=2, k=3)
self.assertEqual(c.x, 1)
self.assertEqual(c.y, 1.0)
Expand All @@ -272,6 +285,11 @@ class C(B):
class D(C):
f: int = 5

self.assertEqual(D.init_arg_list, ['x', 'y', 'p', 'q', 'k', 'f', 'e'])
self.assertEqual(
list(D.schema.fields.keys()),
['x', 'y', 'p', 'q', 'k', 'f', 'e'],
)
d = D(1, q=2, k=3, e=4)
self.assertEqual(d.x, 1)
self.assertEqual(d.y, 1.0)
Expand All @@ -284,7 +302,11 @@ class E(Object):
__kwargs__: typing.Any
x: int

self.assertIsNotNone(E.schema.dynamic_field)
self.assertEqual(E.init_arg_list, ['x'])
self.assertEqual(
list(E.schema.fields.keys()),
[pg_typing.StrKey(), 'x'],
)
e = E(1, y=3)
self.assertEqual(e.x, 1)
self.assertEqual(e.y, 3)
Expand Down
4 changes: 4 additions & 0 deletions pyglove/core/symbolic/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def augment_schema(
if init_arg_list is None:
init_arg_list = metadata.get('init_arg_list', None)
metadata = object_utils.merge([schema.metadata, metadata])

# NOTE(daiyip): Consider to inherit `init_arg_list` from the parent when
# there is no new field.
metadata['init_arg_list'] = init_arg_list

return formalize_schema(
pg_typing.create_schema(
maybe_field_list=fields,
Expand Down

0 comments on commit 25e4c23

Please sign in to comment.