Skip to content

Commit

Permalink
Use __schema__, __type_name__ and __serialization_key__ as the …
Browse files Browse the repository at this point in the history
…canonical names for symbolic class attributes.

This avoids name conflict with user defined class attributes.

PiperOrigin-RevId: 560511511
  • Loading branch information
daiyip authored and pyglove authors committed Aug 27, 2023
1 parent 1cc5aab commit bdffd58
Show file tree
Hide file tree
Showing 22 changed files with 452 additions and 380 deletions.
2 changes: 1 addition & 1 deletion docs/learn/soop/som/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ or different default values. For example::
# x : pg.typing.Int(min_value=1, max_value=10, default=1))
# y : pg.typing.Float(min_value=0)
# z : pg.typing.Str().noneable()
print(Bar.schema)
print(Bar.__schema__)


Symbolizing a Regular Class
Expand Down
6 changes: 3 additions & 3 deletions docs/learn/soop/som/validation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ during the declaration. For example::
class A(pg.Object):
pass

print(A.schema)
print(A.__schema__)

@pg.symbolize([
('a', pg.typing.Int()),
Expand All @@ -44,7 +44,7 @@ during the declaration. For example::
def foo(a, b):
return a + b

print(foo.schema)
print(foo.__schema__)


Key and Value Specifications
Expand Down Expand Up @@ -161,7 +161,7 @@ The code snippet below illustrates schema inheritance during subclassing::
class B(A):
pass

assert B.schema.fields.keys() == ['x', 'y', 'z']
assert B.__schema__.fields.keys() == ['x', 'y', 'z']

@pg.members([
# Raises: 'z' is frozen in class B and cannot be extended further.
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/ml/symbolic_ml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@
}
],
"source": [
"print(Experiment.schema)"
"print(Experiment.__schema__)"
]
},
{
Expand Down
4 changes: 3 additions & 1 deletion pyglove/core/geno/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,9 @@ def sym_jsonify(

if type_info:
json_value = {
object_utils.JSONConvertible.TYPE_NAME_KEY: self.__class__.type_name,
object_utils.JSONConvertible.TYPE_NAME_KEY: (
self.__class__.__serialization_key__
),
'format': 'compact',
'value': symbolic.to_json(value),
}
Expand Down
18 changes: 12 additions & 6 deletions pyglove/core/object_utils/json_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def to_json_dict(
exclude_keys: Optional[Set[str]] = None,
**kwargs) -> Dict[str, JSONValueType]:
"""Helper method to create JSON dict from class and field."""
json_dict = {JSONConvertible.TYPE_NAME_KEY: _type_name(cls)}
json_dict = {JSONConvertible.TYPE_NAME_KEY: _serialization_key(cls)}
exclude_keys = exclude_keys or set()
if exclude_default:
for k, (v, default) in fields.items():
Expand All @@ -253,15 +253,21 @@ def to_json_dict(
def __init_subclass__(cls):
super().__init_subclass__()
if not inspect.isabstract(cls) and cls.auto_register:
type_name = _type_name(cls)
type_name = _serialization_key(cls)
JSONConvertible.register(type_name, cls, override_existing=True)


def _type_name(type_or_function: Union[Type[Any], types.FunctionType]) -> str:
def _serialization_key(
type_or_function: Union[Type[Any], types.FunctionType]) -> str:
"""Returns the ID for a type or function."""
type_name = getattr(type_or_function, 'type_name', None)
if type_name is not None:
return type_name
serializaton_key = getattr(type_or_function, '__serialization_key__', None)
if serializaton_key is not None:
return serializaton_key
return _type_name(type_or_function)


def _type_name(
type_or_function: Union[Type[Any], types.FunctionType]) -> str:
return f'{type_or_function.__module__}.{type_or_function.__qualname__}'


Expand Down
18 changes: 10 additions & 8 deletions pyglove/core/patching/rule_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,25 @@ def _decorator(fn):
arg_specs = cls.signature.args
if len(arg_specs) < 1:
raise TypeError(
f'Patcher function should have at least 1 argument '
f'as patching target. (Patcher={cls.type_name!r})')
'Patcher function should have at least 1 argument '
f'as patching target. (Patcher={cls.__type_name__!r})'
)
if not _is_patcher_target_spec(arg_specs[0].value_spec):
raise TypeError(
f'{arg_specs[0].value_spec!r} cannot be used for constraining '
f'Patcher target. (Patcher={cls.type_name!r}, '
f'Patcher target. (Patcher={cls.__type_name__!r}, '
f'Argument={arg_specs[0].name!r})\n'
f'Acceptable value spec types are: '
f'Any, Callable, Dict, Functor, List, Object.')
'Acceptable value spec types are: '
'Any, Callable, Dict, Functor, List, Object.'
)
for arg_spec in arg_specs[1:]:
if not _is_patcher_parameter_spec(arg_spec.value_spec):
raise TypeError(
f'{arg_spec.value_spec!r} cannot be used for constraining '
f'Patcher argument. (Patcher={cls.type_name!r}, '
f'Patcher argument. (Patcher={cls.__type_name__!r}, '
f'Argument={arg_spec.name!r})\n'
f'Consider to treat it as string and parse yourself.')
'Consider to treat it as string and parse yourself.'
)
return cls
return _decorator

Expand Down Expand Up @@ -462,4 +465,3 @@ def parse_list(string: str,
if string:
return [convert_fn(i, piece) for i, piece in enumerate(string.split(':'))]
return []

40 changes: 21 additions & 19 deletions pyglove/core/symbolic/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,16 @@ class A(pg.Object):
value: Value that is used as the default value of the boilerplate class.
init_arg_list: An optional list of strings as __init__ positional arguments
names.
**kwargs: Keyword arguments for infrequently used options.
Acceptable keywords are:
* `serialization_key`: An optional string to be used as the serialization
key for the class during `sym_jsonify`. If None, `cls.type_name` will
be used. This is introduced for scenarios when we want to relocate a
class, before the downstream can recognize the new location, we need
the class to serialize it using previous key.
* `additional_keys`: An optional list of strings as additional keys to
deserialize an object of the registered class. This can be useful
when we need to relocate or rename the registered class while being able
to load existing serialized JSON values.
**kwargs: Keyword arguments for infrequently used options. Acceptable
keywords are: * `serialization_key`: An optional string to be used as the
serialization key for the class during `sym_jsonify`. If None,
`cls.__type_name__` will be used. This is introduced for scenarios when we
want to relocate a class, before the downstream can recognize the new
location, we need the class to serialize it using previous key. *
`additional_keys`: An optional list of strings as additional keys to
deserialize an object of the registered class. This can be useful when we
need to relocate or rename the registered class while being able to load
existing serialized JSON values.
Returns:
A class which extends the input value's type, with its schema's default
Expand Down Expand Up @@ -124,6 +122,8 @@ class _BoilerplateClass(base_cls):
cls_module = caller_module.__name__ if caller_module else '__main__'
cls = _BoilerplateClass
cls.__name__ = cls_name
cls.__qualname__ = cls.__qualname__.replace(
'boilerplate_class.<locals>._BoilerplateClass', cls_name)
cls.__module__ = cls_module

# Enable automatic registration for subclass.
Expand Down Expand Up @@ -151,17 +151,19 @@ def _freeze_field(path: object_utils.KeyPath,
pg_typing.MISSING_VALUE, use_default_apply=False)
return value

# NOTE(daiyip): we call `cls.schema.apply` to freeze fields that have default
# values. But we no longer need to formalize `cls.schema`, since it's
# copied from the boilerplate object's class which was already formalized.
# NOTE(daiyip): we call `cls.__schema__.apply` to freeze fields that have
# default values. But we no longer need to formalize `cls.__schema__`, since
# it's copied from the boilerplate object's class which was already
# formalized.
with flags.allow_writable_accessors():
cls.schema.apply(
cls.__schema__.apply(
value._sym_attributes, # pylint: disable=protected-access
allow_partial=allow_partial,
child_transform=_freeze_field)
child_transform=_freeze_field,
)

if init_arg_list is not None:
schema_utils.validate_init_arg_list(init_arg_list, cls.schema)
cls.schema.metadata['init_arg_list'] = init_arg_list
schema_utils.validate_init_arg_list(init_arg_list, cls.__schema__)
cls.__schema__.metadata['init_arg_list'] = init_arg_list
cls.register_for_deserialization(serialization_key, additional_keys)
return cls
78 changes: 50 additions & 28 deletions pyglove/core/symbolic/boilerplate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class BoilerplateClassTest(unittest.TestCase):

def test_basics(self):
self.assertTrue(issubclass(B, A))
self.assertEqual(B.type_name, 'pyglove.core.symbolic.boilerplate_test.B')
self.assertEqual(B.__type_name__, 'pyglove.core.symbolic.boilerplate_test.B')

with self.assertRaisesRegex(
ValueError,
Expand All @@ -76,32 +76,51 @@ def test_init_arg_list(self):
def test_schema(self):
# Boilerplate class' schema should carry the default value and be frozen.
self.assertEqual(
B.schema,
B.__schema__,
pg_typing.create_schema([
('a', pg_typing.Int()),
('b', pg_typing.Union(
[pg_typing.Int(), pg_typing.Str()], default='foo').freeze()),
('c', pg_typing.Dict([
('d', pg_typing.List(pg_typing.Dict([
('e', pg_typing.Float()),
('f', pg_typing.Bool())
]), default=List([Dict(e=1.0, f=True)])).freeze())
]).freeze())
]))
(
'b',
pg_typing.Union(
[pg_typing.Int(), pg_typing.Str()], default='foo'
).freeze(),
),
(
'c',
pg_typing.Dict([(
'd',
pg_typing.List(
pg_typing.Dict([
('e', pg_typing.Float()),
('f', pg_typing.Bool()),
]),
default=List([Dict(e=1.0, f=True)]),
).freeze(),
)]).freeze(),
),
]),
)

# Original class' schema should remain unchanged.
self.assertEqual(
A.schema,
A.__schema__,
pg_typing.create_schema([
('a', pg_typing.Int()),
('b', pg_typing.Union([pg_typing.Int(), pg_typing.Str()])),
('c', pg_typing.Dict([
('d', pg_typing.List(pg_typing.Dict([
('e', pg_typing.Float()),
('f', pg_typing.Bool())
])))
]))
]))
(
'c',
pg_typing.Dict([(
'd',
pg_typing.List(
pg_typing.Dict([
('e', pg_typing.Float()),
('f', pg_typing.Bool()),
])
),
)]),
),
]),
)

def test_init(self):
b = B(0)
Expand All @@ -124,21 +143,24 @@ def test_do_not_modify_original_object(self):

# Default value of the boilerplate class remain unchanged.
self.assertEqual(
B.schema['c'].default_value,
Dict.partial({'d': [{
'e': 1.0,
'f': True,
}]}, value_spec=B.schema['c'].value))
B.__schema__['c'].default_value,
Dict.partial(
{
'd': [{
'e': 1.0,
'f': True,
}]
},
value_spec=B.__schema__['c'].value,
),
)

# Original object remain unchanged.
self.assertTrue(template_object.c.d[0].f)

def test_serialization(self):
b = B(a=1)
self.assertEqual(b.to_json(), {
'_type': B.type_name,
'a': 1
})
self.assertEqual(b.to_json(), {'_type': B.__type_name__, 'a': 1})
self.assertEqual(pg_from_json_str(b.to_json_str()), b)


Expand Down
Loading

0 comments on commit bdffd58

Please sign in to comment.