From 5349f2af46f25291e1a5ba536f6e8071d8287a0b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 9 Dec 2024 14:38:01 -0800 Subject: [PATCH] [Feature] super() calls within TensorClass subclasses ghstack-source-id: 060a89982413869c54e1fb4aa74f90e2b9cdaac4 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1133 --- tensordict/nn/distributions/continuous.py | 3 +- tensordict/nn/distributions/discrete.py | 1 + tensordict/tensorclass.py | 50 +++++++++++++++++++---- test/test_tensorclass.py | 26 ++++++++++-- test/test_tensordict.py | 1 + 5 files changed, 67 insertions(+), 14 deletions(-) diff --git a/tensordict/nn/distributions/continuous.py b/tensordict/nn/distributions/continuous.py index 2e6015acb..5210d10b1 100644 --- a/tensordict/nn/distributions/continuous.py +++ b/tensordict/nn/distributions/continuous.py @@ -15,6 +15,7 @@ from tensordict.nn.utils import mappings from torch import distributions as D, nn +# We need this to build the distribution maps __all__ = [ "NormalParamExtractor", "NormalParamWrapper", @@ -23,7 +24,7 @@ ] # speeds up distribution construction -D.Distribution.set_default_validate_args(False) +# D.Distribution.set_default_validate_args(False) class NormalParamWrapper(nn.Module): diff --git a/tensordict/nn/distributions/discrete.py b/tensordict/nn/distributions/discrete.py index 4bcbd5626..81db384dc 100644 --- a/tensordict/nn/distributions/discrete.py +++ b/tensordict/nn/distributions/discrete.py @@ -10,6 +10,7 @@ import torch from torch import distributions as D +# We need this to build the distribution maps __all__ = [ "OneHotCategorical", ] diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index f0f8a1cf4..9c0f24e56 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -692,20 +692,40 @@ def __torch_function__( cls.__getstate__ = _getstate cls.__setstate__ = _setstate # cls.__getattribute__ = object.__getattribute__ + # if "__getattr__" not in cls.__dict__: cls.__getattr__ = _getattr - cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys) - # cls.__getattr__ = _getattr - cls.__getitem__ = _getitem - cls.__getitems__ = _getitem - cls.__setitem__ = _setitem + if "__setattr__" not in cls.__dict__: + cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys) + if "__getitem__" not in cls.__dict__: + cls.__getitem__ = _getitem + if "__getitems__" not in cls.__dict__: + cls.__getitems__ = _getitem + if "__setitem__" not in cls.__dict__: + cls.__setitem__ = _setitem if not _is_non_tensor: cls.__repr__ = _repr - cls.__len__ = _len + if "__len__" not in cls.__dict__: + cls.__len__ = _len + # cls.__eq__ = _eq cls.__ne__ = _ne cls.__or__ = _or cls.__xor__ = _xor cls.__bool__ = _bool + + # cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys) + # # cls.__getattr__ = _getattr + # cls.__getitem__ = _getitem + # cls.__getitems__ = _getitem + # cls.__setitem__ = _setitem + # if not _is_non_tensor: + # cls.__repr__ = _repr + # cls.__len__ = _len + # cls.__eq__ = _eq + # cls.__ne__ = _ne + # cls.__or__ = _or + # cls.__xor__ = _xor + # cls.__bool__ = _bool if not hasattr(cls, "non_tensor_items"): cls.non_tensor_items = _non_tensor_items if not hasattr(cls, "set"): @@ -2915,15 +2935,21 @@ def to_tensordict(self, *, retain_none: bool | None = None): return self @classmethod - def _stack_non_tensor(cls, list_of_non_tensor, dim=0): + def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False): # checks have been performed previously, so we're sure the list is non-empty first = list_of_non_tensor[0] ids = set() firstdata = NO_DEFAULT + return_stack = False for data in list_of_non_tensor: if not isinstance(data, NonTensorData): - return_stack = True + if raise_if_non_unique: + data = cls._stack_non_tensor( + data, raise_if_non_unique=raise_if_non_unique + ) + else: + return_stack = True break if firstdata is NO_DEFAULT: firstdata = data.data @@ -2931,6 +2957,10 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0): if len(ids) > 1: if _check_equal(data.data, firstdata): continue + if raise_if_non_unique: + raise ValueError( + "More than one unique value has been found in the stack." + ) return_stack = True break else: @@ -3459,7 +3489,9 @@ def data(self): self.tensordicts, raise_if_non_unique=True ).data except ValueError: - raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.") + raise AttributeError( + "Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead." + ) _register_tensor_class(NonTensorStack) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 806a282a1..476e52c08 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -621,7 +621,7 @@ class X: assert isinstance(x.y, torch.Tensor) _ = {x: 0} assert x.is_locked - with pytest.raises(RuntimeError, match="locked"): + with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)): x.y = 0 @tensorclass(frozen=False, autocast=True) @@ -643,7 +643,7 @@ class X: assert isinstance(x.y, str) _ = {x: 0} assert x.is_locked - with pytest.raises(RuntimeError, match="locked"): + with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)): x.y = 0 @tensorclass(frozen=False, autocast=False) @@ -2585,7 +2585,7 @@ class SubClass(TensorClass, nocast=True, frozen=True): assert issubclass(SubClass, TensorClass) s = SubClass(1) assert isinstance(s.a, int) - with pytest.raises(RuntimeError): + with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)): s.a = 2 class SubClass(TensorClass["nocast", "frozen"]): @@ -2599,9 +2599,27 @@ class SubClass(TensorClass["nocast", "frozen"]): assert issubclass(SubClass, TensorClass) s = SubClass(1) assert isinstance(s.a, int) - with pytest.raises(RuntimeError): + with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)): s.a = 2 + def test_subclassing_super_call(self): + class SubClass(TensorClass, nocast=True): + a: int + b: int + + def __setattr__(self, key, value): + if key == "b": + return super().__setattr__("b", value + 1) + return super().__setattr__("a", value - 1) + + s = SubClass(a=torch.zeros(3), b=torch.zeros(3)) + assert (s.a == -1).all() + assert (s.b == 1).all() + s.a = torch.ones(()) + s.b = torch.ones(()) + assert (s.a == 0).all() + assert (s.b == 2).all() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 812a69f7d..de9fad1e7 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -8377,6 +8377,7 @@ def test_consolidate(self, device, use_file, tmpdir): assert hasattr(td_c, "_consolidated") assert type(td_c) == type(td) # noqa assert (td.to(td_c.device) == td_c).all() + assert td["d"] == [["a string!"] * 3] assert td_c["d"] == [["a string!"] * 3] storage = td_c._consolidated["storage"]