Skip to content

Commit

Permalink
Improve annotation inference to be compatible with module reloading.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723558958
  • Loading branch information
daiyip authored and pyglove authors committed Feb 5, 2025
1 parent 97d6446 commit 65c7e52
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 26 deletions.
66 changes: 53 additions & 13 deletions pyglove/core/typing/annotation_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,64 @@ def _type_id() -> str:
return t_id

def _resolve(type_id: str):

def _as_forward_ref() -> typing.ForwardRef:
return typing.ForwardRef(type_id, False, parent_module) # pytype: disable=not-callable

def _resolve_name(name: str, parent_obj: typing.Any):
if name == 'None':
return None
return None, True
if parent_obj is not None and hasattr(parent_obj, name):
return getattr(parent_obj, name)
return getattr(parent_obj, name), False
if hasattr(builtins, name):
return getattr(builtins, name)
return getattr(builtins, name), True
if type_id == '...':
return ...
return utils.MISSING_VALUE
parent_obj = parent_module
for name in type_id.split('.'):
parent_obj = _resolve_name(name, parent_obj)
if parent_obj == utils.MISSING_VALUE:
return typing.ForwardRef( # pytype: disable=not-callable
type_id, False, parent_module
)
return parent_obj
return ..., True
return utils.MISSING_VALUE, False

names = type_id.split('.')
if len(names) == 1:
reference, is_builtin = _resolve_name(names[0], parent_module)
if is_builtin:
return reference
if not is_builtin and (
# When reference is not found, we should treat it as a forward
# reference.
reference == utils.MISSING_VALUE
# When module is being reloaded, we should treat all non-builtin
# references as forward references.
or getattr(parent_module, '__reloading__', False)
):
return _as_forward_ref()
return reference

root_obj, _ = _resolve_name(names[0], parent_module)
# When root object is not found, we should treat it as a forward reference.
if root_obj == utils.MISSING_VALUE:
return _as_forward_ref()

parent_obj = root_obj
# When root object is a module, we should treat reference to its children
# as non-forward references.
if inspect.ismodule(root_obj):
for name in names[1:]:
parent_obj, _ = _resolve_name(name, parent_obj)
if parent_obj == utils.MISSING_VALUE:
raise TypeError(f'{type_id!r} does not exist.')
return parent_obj
# When root object is non-module variable of current module, and when the
# module is being reloaded, we should treat reference to its children as
# forward references.
elif getattr(parent_module, '__reloading__', False):
return _as_forward_ref()
# When root object is non-module variable of current module, we should treat
# unresolved reference to its children as forward references.
else:
for name in names[1:]:
parent_obj, _ = _resolve_name(name, parent_obj)
if parent_obj == utils.MISSING_VALUE:
return _as_forward_ref()
return parent_obj

root = _maybe_union()
if _pos() != len(s):
Expand Down
40 changes: 39 additions & 1 deletion pyglove/core/typing/annotation_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@


class Foo:
pass
class Bar:
pass


_MODULE = sys.modules[__name__]
Expand Down Expand Up @@ -58,6 +59,18 @@ def test_basic_types(self):
annotation_conversion.annotation_from_str('tuple[int, str]'),
tuple[int, str]
)
self.assertEqual(
annotation_conversion.annotation_from_str('list[Foo]', _MODULE),
list[Foo]
)
self.assertEqual(
annotation_conversion.annotation_from_str('list[Foo.Bar]', _MODULE),
list[Foo.Bar]
)
self.assertEqual(
annotation_conversion.annotation_from_str('list[Foo.Baz]', _MODULE),
list[typing.ForwardRef('Foo.Baz', False, _MODULE)]
)

def test_generic_types(self):
self.assertEqual(
Expand Down Expand Up @@ -139,6 +152,28 @@ def test_forward_ref(self):
]
)

def test_reloading(self):
setattr(_MODULE, '__reloading__', True)
self.assertEqual(
annotation_conversion.annotation_from_str(
'typing.List[Foo]', _MODULE),
typing.List[
typing.ForwardRef(
'Foo', False, _MODULE
)
]
)
self.assertEqual(
annotation_conversion.annotation_from_str(
'typing.List[Foo.Bar]', _MODULE),
typing.List[
typing.ForwardRef(
'Foo.Bar', False, _MODULE
)
]
)
delattr(_MODULE, '__reloading__')

def test_bad_annotation(self):
with self.assertRaisesRegex(SyntaxError, 'Expected type identifier'):
annotation_conversion.annotation_from_str('typing.List[]')
Expand All @@ -152,6 +187,9 @@ def test_bad_annotation(self):
with self.assertRaisesRegex(SyntaxError, 'Expected "]"'):
annotation_conversion.annotation_from_str('typing.Callable[[x')

with self.assertRaisesRegex(TypeError, '.* does not exist'):
annotation_conversion.annotation_from_str('typing.Foo', _MODULE)


class FieldFromAnnotationTest(unittest.TestCase):
"""Tests for Field.fromAnnotation."""
Expand Down
20 changes: 8 additions & 12 deletions pyglove/core/typing/class_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class ForwardRef(utils.Formattable):
def __init__(self, module: types.ModuleType, qualname: str):
self._module = module
self._qualname = qualname
self._resolved_value = None

@property
def module(self) -> types.ModuleType:
Expand Down Expand Up @@ -129,9 +128,7 @@ def as_annotation(self) -> Union[Type[Any], str]:
@property
def resolved(self) -> bool:
"""Returns True if the symbol for the name is resolved.."""
if self._resolved_value is None:
self._resolved_value = self._resolve()
return self._resolved_value is not None
return self._resolve() is not None

def _resolve(self) -> Optional[Any]:
names = self._qualname.split('.')
Expand All @@ -150,14 +147,13 @@ def _resolve(self) -> Optional[Any]:
@property
def cls(self) -> Type[Any]:
"""Returns the resolved reference class.."""
if self._resolved_value is None:
self._resolved_value = self._resolve()
if self._resolved_value is None:
raise TypeError(
f'{self.qualname!r} does not exist in '
f'module {self.module.__name__!r}'
)
return self._resolved_value
reference = self._resolve()
if reference is None:
raise TypeError(
f'{self.qualname!r} does not exist in '
f'module {self.module.__name__!r}'
)
return reference

def format(
self,
Expand Down

0 comments on commit 65c7e52

Please sign in to comment.