Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generic dataclasses with bound parameters. #257

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 83 additions & 38 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
import dataclasses
import functools
import inspect

from absl import logging
import jax
Expand All @@ -27,32 +28,55 @@
_RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple"))


def mappable_dataclass(cls):
"""Exposes dataclass as ``collections.abc.Mapping`` descendent.
def _make_mappable(cls):
"""Create type that implements and inherits from ``collections.abc.Mapping``.

Allows to traverse dataclasses in methods from `dm-tree` library.
Note that this does not require the class to be a dataclass, as it is supposed
to be applied before creating the dataclass.

NOTE: changes dataclasses constructor to dict-type
(i.e. positional args aren't supported; however can use generators/iterables).
Allows to traverse dataclasses in methods from `dm-tree` library.

Args:
cls: A dataclass to mutate.
cls: A class to use as a base for the new type.

Returns:
Mutated dataclass implementing ``collections.abc.Mapping`` interface.
type implementing and inheriting from ``collections.abc.Mapping``.
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

# Define methods for compatibility with `collections.abc.Mapping`.
setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
setattr(cls, "__len__", lambda self: len(self.__dict__))
setattr(cls, "__iter__", lambda self: iter(self.__dict__))

# Update constructor.
orig_init = cls.__init__
all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init]
# Update base class to derive from Mapping
dct = dict(cls.__dict__)
if "__dict__" in dct:
dct.pop("__dict__") # Avoid self-references.

# Remove object from the sequence of base classes. Deriving from both Mapping
# and object will cause a failure to create an MRO for the updated class.
bases = tuple(b for b in cls.__bases__ if b != object)
return type(cls.__name__, bases + (collections.abc.Mapping,), dct)


def _convert_kw_only_dataclass_init(dcls):
"""Create wrapped initializer that converts everything to keyword arguments.

This should be equivalent to passing `kw_only=True` when creating the
dataclass in Python >= 3.10, where that parameter was introduced.

Args:
dcls: the dataclass to take the constructor from.

Returns:
Initializer wrapping the original initializer but which requires
keyword-only arguments.

Raises:
ValueError: if all required arguments are not provided as keyword-only.
"""
orig_init = dcls.__init__
all_fields = set(f.name for f in dcls.__dataclass_fields__.values())
init_fields = [f.name for f in dcls.__dataclass_fields__.values() if f.init]

@functools.wraps(orig_init)
def new_init(self, *orig_args, **orig_kwargs):
Expand All @@ -69,17 +93,28 @@ def new_init(self, *orig_args, **orig_kwargs):
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
orig_init(self, **valid_kwargs)

cls.__init__ = new_init
return new_init

# Update base class to derive from Mapping
dct = dict(cls.__dict__)
if "__dict__" in dct:
dct.pop("__dict__") # Avoid self-references.

# Remove object from the sequence of base classes. Deriving from both Mapping
# and object will cause a failure to create a MRO for the updated class
bases = tuple(b for b in cls.__bases__ if b != object)
cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct)
def mappable_dataclass(cls):
  """Exposes dataclass as ``collections.abc.Mapping`` descendant.

  Allows traversing dataclasses with methods from the `dm-tree` library.

  NOTE: changes the dataclass constructor to dict-type
  (i.e. positional args aren't supported; however can use generators/iterables).

  Args:
    cls: A dataclass type (not an instance of one) to convert.

  Returns:
    The class returned by ``_make_mappable(cls)``, with its ``__init__``
    replaced by the wrapper from ``_convert_kw_only_dataclass_init``.

  Raises:
    ValueError: if ``cls`` is not a dataclass type.
  """
  # `dataclasses.is_dataclass` is also true for dataclass *instances*, which
  # cannot be rebuilt as a type below, so require an actual class here.
  if not (isinstance(cls, type) and dataclasses.is_dataclass(cls)):
    raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

  mappable_cls = _make_mappable(cls)
  mappable_cls.__init__ = _convert_kw_only_dataclass_init(mappable_cls)
  return mappable_cls


Expand Down Expand Up @@ -159,37 +194,40 @@ def __init__(
def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""

if self.mappable_dataclass:
cls = _make_mappable(cls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(cls, attr, None) # redefine to avoid AttributeError on delattr
delattr(cls, attr) # delete

# Remove once https://github.com/python/cpython/pull/24484 is merged.
for base in cls.__bases__:
if (dataclasses.is_dataclass(base) and
getattr(base, "__dataclass_params__").frozen and not self.frozen):
raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

# Check for invalid field names.
annotations = inspect.get_annotations(cls)
fields_names = set(name for name in annotations.keys())
invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
if invalid_fields:
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({cls}).")

# pytype: disable=wrong-keyword-args
dcls = dataclasses.dataclass(
cls,
init=self.init,
repr=self.repr,
eq=self.eq,
order=self.order,
# kw_only=self.mappable_dataclass,
unsafe_hash=self.unsafe_hash,
frozen=self.frozen)
# pytype: enable=wrong-keyword-args

fields_names = set(f.name for f in dataclasses.fields(dcls))
invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
if invalid_fields:
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({dcls}).")

if self.mappable_dataclass:
dcls = mappable_dataclass(dcls)
# We remove `collection.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(dcls, attr, None) # redefine
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))

Expand All @@ -212,6 +250,9 @@ def _setstate(self, state):
self.__dict__.update(state)

orig_init = dcls.__init__
is_mappable_dataclass = self.mappable_dataclass
if self.mappable_dataclass:
kw_only_init = _convert_kw_only_dataclass_init(dcls)

# Patch object's __init__ such that the class is registered on creation if
# it is not registered on deserialization.
Expand All @@ -220,7 +261,11 @@ def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

if is_mappable_dataclass:
return kw_only_init(self, *args, **kwargs)
else:
return orig_init(self, *args, **kwargs)

setattr(dcls, "from_tuple", _from_tuple)
setattr(dcls, "to_tuple", _to_tuple)
Expand Down
14 changes: 9 additions & 5 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,6 @@ class ValidMappable:
get: int

with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):

@chex_dataclass(mappable_dataclass=True)
class InvalidMappable:
get: int
Expand Down Expand Up @@ -571,19 +570,24 @@ class Bar:
self.assertLen(jax.tree_util.tree_flatten(Bar())[0], 2)

@parameterized.named_parameters(
('mappable', True),
('not_mappable', False),
('mappable_frozen', True, True),
('not_mappable_frozen', False, True),
('mappable_not_frozen', True, False),
('not_mappable_not_frozen', False, False),
)
def test_generic_dataclass(self, mappable):
def test_generic_dataclass(self, mappable, frozen):
T = TypeVar('T')

@chex_dataclass(mappable_dataclass=mappable)
@chex_dataclass(mappable_dataclass=mappable, frozen=frozen)
class GenericDataclass(Generic[T]):
a: T # pytype: disable=invalid-annotation # enable-bare-annotations

obj = GenericDataclass(a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

obj = GenericDataclass[np.array](a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

def test_mappable_eq_override(self):

@chex_dataclass(mappable_dataclass=True)
Expand Down