Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix yaml serialization in io mixin #11106

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions nemo/lightning/io/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,33 @@ def _partial_representer_with_defaults(dumper, data):


def _safe_object_representer(dumper, data):
if not inspect.isclass(data):
cls = data.__class__
call = True
else:
cls = data
"""
Represent a given object as YAML using the specified dumper.

This function is a fallback for objects that don't have specific representers.
If the object has __qualname__ attr, the __target__ is set to f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}".
If the object does not have a __qualname__ attr, the __target__ is set from its __class__ attr.
The __call__ key is used to indicate whether the target should be called to create an instance.

Args:
dumper (yaml.Dumper): The YAML dumper to use for serialization.
data (Any): The data to serialize. This can be any Python object,
but if it's a class or a class instance, special handling will be applied.

Returns:
str: The YAML representation of the data.
"""
try:
obj = data
target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}"
call = False
except AttributeError:
obj = data.__class__
target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}"
call = True

value = {
"_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}", # type: ignore
"_target_": target, # type: ignore
"_call_": call,
}
return dumper.represent_data(value)
Expand Down
Loading