From 0f4f0a6b0dcf550d2a755a166c32699dfd559595 Mon Sep 17 00:00:00 2001 From: Stephen Spencer Date: Mon, 3 Apr 2023 03:15:54 -0700 Subject: [PATCH] Fix generic dataclasses with bound parameters. This alters the way in which mappable dataclasses are created in order to fix a crash when a mappable, frozen generic dataclass is instantiated with a bound type parameter. PiperOrigin-RevId: 521409933 --- chex/_src/dataclass.py | 121 +++++++++++++++++++++++++----------- chex/_src/dataclass_test.py | 14 +++-- 2 files changed, 92 insertions(+), 43 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index bd743ae6..6f653261 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -17,6 +17,7 @@ import collections import dataclasses import functools +import inspect from absl import logging import jax @@ -27,32 +28,55 @@ _RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple")) -def mappable_dataclass(cls): - """Exposes dataclass as ``collections.abc.Mapping`` descendent. +def _make_mappable(cls): + """Create type that implements and inherits from ``collections.abc.Mapping``. - Allows to traverse dataclasses in methods from `dm-tree` library. + Note that this does not require the class to be a dataclass, as it is supposed + to be applied before creating the dataclass. - NOTE: changes dataclasses constructor to dict-type - (i.e. positional args aren't supported; however can use generators/iterables). + Allows to traverse dataclasses in methods from `dm-tree` library. Args: - cls: A dataclass to mutate. + cls: A class to use as a base for the new type. Returns: - Mutated dataclass implementing ``collections.abc.Mapping`` interface. + type implementing and inheriting from ``collections.abc.Mapping``. """ - if not dataclasses.is_dataclass(cls): - raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).") - # Define methods for compatibility with `collections.abc.Mapping`. setattr(cls, "__getitem__", lambda self, x: self.__dict__[x]) setattr(cls, "__len__", lambda self: len(self.__dict__)) setattr(cls, "__iter__", lambda self: iter(self.__dict__)) - # Update constructor. - orig_init = cls.__init__ - all_fields = set(f.name for f in cls.__dataclass_fields__.values()) - init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init] + # Update base class to derive from Mapping + dct = dict(cls.__dict__) + if "__dict__" in dct: + dct.pop("__dict__") # Avoid self-references. + + # Remove object from the sequence of base classes. Deriving from both Mapping + # and object will cause a failure to create a MRO for the updated class + bases = tuple(b for b in cls.__bases__ if b != object) + return type(cls.__name__, bases + (collections.abc.Mapping,), dct) + + +def _convert_kw_only_dataclass_init(dcls): + """Create wrapped initializer that converts everything to keyword arguments. + + This should be equivalent to passing `kw_only=True` when creating the + dataclass in Python <= 3.10. + + Args: + dcls: the dataclass to take the constructor from. + + Returns: + Initializer wrapping the original initializer but which requires + keyword-only arguments. + + Throws: + ValueError: if all required arguments are not provided as keyword-only. + """ + orig_init = dcls.__init__ + all_fields = set(f.name for f in dcls.__dataclass_fields__.values()) + init_fields = [f.name for f in dcls.__dataclass_fields__.values() if f.init] @functools.wraps(orig_init) def new_init(self, *orig_args, **orig_kwargs): @@ -69,17 +93,28 @@ def new_init(self, *orig_args, **orig_kwargs): valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields} orig_init(self, **valid_kwargs) - cls.__init__ = new_init + return new_init - # Update base class to derive from Mapping - dct = dict(cls.__dict__) - if "__dict__" in dct: - dct.pop("__dict__") # Avoid self-references. - # Remove object from the sequence of base classes. Deriving from both Mapping - # and object will cause a failure to create a MRO for the updated class - bases = tuple(b for b in cls.__bases__ if b != object) - cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct) +def mappable_dataclass(cls): + """Exposes dataclass as ``collections.abc.Mapping`` descendent. + + Allows to traverse dataclasses in methods from `dm-tree` library. + + NOTE: changes dataclasses constructor to dict-type + (i.e. positional args aren't supported; however can use generators/iterables). + + Args: + cls: A dataclass to mutate. + + Returns: + Mutated dataclass implementing ``collections.abc.Mapping`` interface. + """ + if not dataclasses.is_dataclass(cls): + raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).") + + cls = _make_mappable(cls) + cls.__init__ = _convert_kw_only_dataclass_init(cls) return cls @@ -159,12 +194,28 @@ def __init__( def __call__(self, cls): """Forwards class to dataclasses's wrapper and registers it with JAX.""" + if self.mappable_dataclass: + cls = _make_mappable(cls) + # We remove `collection.abc.Mapping` mixin methods here to allow + # fields with these names. + for attr in ("values", "keys", "get", "items"): + setattr(cls, attr, None) # redefine to avoid AttributeError on delattr + delattr(cls, attr) # delete + # Remove once https://github.com/python/cpython/pull/24484 is merged. for base in cls.__bases__: if (dataclasses.is_dataclass(base) and getattr(base, "__dataclass_params__").frozen and not self.frozen): raise TypeError("cannot inherit non-frozen dataclass from a frozen one") + # Check for invalid field names. + annotations = inspect.get_annotations(cls) + fields_names = set(name for name in annotations.keys()) + invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES) + if invalid_fields: + raise ValueError(f"The following dataclass fields are disallowed: " + f"{invalid_fields} ({cls}).") + # pytype: disable=wrong-keyword-args dcls = dataclasses.dataclass( cls, @@ -172,24 +223,11 @@ def __call__(self, cls): repr=self.repr, eq=self.eq, order=self.order, + # kw_only=self.mappable_dataclass, unsafe_hash=self.unsafe_hash, frozen=self.frozen) # pytype: enable=wrong-keyword-args - fields_names = set(f.name for f in dataclasses.fields(dcls)) - invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES) - if invalid_fields: - raise ValueError(f"The following dataclass fields are disallowed: " - f"{invalid_fields} ({dcls}).") - - if self.mappable_dataclass: - dcls = mappable_dataclass(dcls) - # We remove `collection.abc.Mapping` mixin methods here to allow - # fields with these names. - for attr in ("values", "keys", "get", "items"): - setattr(dcls, attr, None) # redefine - delattr(dcls, attr) # delete - def _from_tuple(args): return dcls(zip(dcls.__dataclass_fields__.keys(), args)) @@ -212,6 +250,9 @@ def _setstate(self, state): self.__dict__.update(state) orig_init = dcls.__init__ + is_mappable_dataclass = self.mappable_dataclass + if self.mappable_dataclass: + kw_only_init = _convert_kw_only_dataclass_init(dcls) # Patch object's __init__ such that the class is registered on creation if # it is not registered on deserialization. @@ -220,7 +261,11 @@ def _init(self, *args, **kwargs): if not class_self.registered: register_dataclass_type_with_jax_tree_util(dcls) class_self.registered = True - return orig_init(self, *args, **kwargs) + + if is_mappable_dataclass: + return kw_only_init(self, *args, **kwargs) + else: + return orig_init(self, *args, **kwargs) setattr(dcls, "from_tuple", _from_tuple) setattr(dcls, "to_tuple", _to_tuple) diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index 614c650c..20849c28 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -507,7 +507,6 @@ class ValidMappable: get: int with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'): - @chex_dataclass(mappable_dataclass=True) class InvalidMappable: get: int @@ -571,19 +570,24 @@ class Bar: self.assertLen(jax.tree_util.tree_flatten(Bar())[0], 2) @parameterized.named_parameters( - ('mappable', True), - ('not_mappable', False), + ('mappable_frozen', True, True), + ('not_mappable_frozen', False, True), + ('mappable_not_frozen', True, False), + ('not_mappable_not_frozen', False, False), ) - def test_generic_dataclass(self, mappable): + def test_generic_dataclass(self, mappable, frozen): T = TypeVar('T') - @chex_dataclass(mappable_dataclass=mappable) + @chex_dataclass(mappable_dataclass=mappable, frozen=frozen) class GenericDataclass(Generic[T]): a: T # pytype: disable=invalid-annotation # enable-bare-annotations obj = GenericDataclass(a=np.array([1.0, 1.0])) asserts.assert_tree_all_close(obj.a, 1.0) + obj = GenericDataclass[np.array](a=np.array([1.0, 1.0])) + asserts.assert_tree_all_close(obj.a, 1.0) + def test_mappable_eq_override(self): @chex_dataclass(mappable_dataclass=True)