From b5704d8384daea61dcd36c30dd53dcf2c0992d59 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Fri, 20 Jan 2023 15:09:49 +0000
Subject: [PATCH 1/4] Drop _tensordict argument in constructor

---
 tensordict/prototype/tensorclass.py | 100 ++++++++++++----------
 test/test_tensorclass.py            |  11 +--
 test/test_tensorclass_nofuture.py   |  11 +--
 3 files changed, 46 insertions(+), 76 deletions(-)

diff --git a/tensordict/prototype/tensorclass.py b/tensordict/prototype/tensorclass.py
index 4b7e6b480..9abe51615 100644
--- a/tensordict/prototype/tensorclass.py
+++ b/tensordict/prototype/tensorclass.py
@@ -125,7 +125,7 @@ def __torch_function__(
         )
 
     cls.__init__ = _init_wrapper(cls.__init__, expected_keys)
-    cls._build_from_tensordict = classmethod(_build_from_tensordict)
+    cls.from_tensordict = classmethod(_from_tensordict)
     cls.__torch_function__ = classmethod(__torch_function__)
     cls.__getstate__ = _getstate
     cls.__setstate__ = _setstate
@@ -165,47 +165,32 @@ def _init_wrapper(init, expected_keys):
     required_params = [p.name for p in params[1:] if p.default is inspect._empty]
 
     @functools.wraps(init)
-    def wrapper(self, *args, batch_size=None, device=None, _tensordict=None, **kwargs):
-        if (args or kwargs) and _tensordict is not None:
-            raise ValueError("Cannot pass both args/kwargs and _tensordict.")
-
-        if _tensordict is not None:
-            if not all(key in expected_keys for key in _tensordict.keys()):
-                raise ValueError(
-                    f"Keys from the tensordict ({set(_tensordict.keys())}) must "
-                    f"correspond to the class attributes ({expected_keys})."
-                )
-            input_dict = {key: None for key in _tensordict.keys()}
-            init(self, **input_dict)
-            self.tensordict = _tensordict
-        else:
-            for value, key in zip(args, self.__dataclass_fields__):
-                if key in kwargs:
-                    raise ValueError(f"The key {key} is already set in kwargs")
-                kwargs[key] = value
-
-            for key, field in self.__dataclass_fields__.items():
-                if field.default_factory is not dataclasses.MISSING:
-                    default = field.default_factory()
-                else:
-                    default = field.default
-                if default not in (None, dataclasses.MISSING):
-                    kwargs.setdefault(key, default)
-
-            missing_params = [p for p in required_params if p not in kwargs]
-            if missing_params:
-                n_missing = len(missing_params)
-                raise TypeError(
-                    f"{self.__class__.__name__}.__init__() missing {n_missing} "
-                    f"required positional argument{'' if n_missing == 1 else 's'}: "
-                    f"""{", ".join(f"'{name}'" for name in missing_params)}"""
-                )
+    def wrapper(self, *args, batch_size=None, device=None, **kwargs):
+        for value, key in zip(args, self.__dataclass_fields__):
+            if key in kwargs:
+                raise ValueError(f"The key {key} is already set in kwargs")
+            kwargs[key] = value
+
+        for key, field in self.__dataclass_fields__.items():
+            if field.default_factory is not dataclasses.MISSING:
+                default = field.default_factory()
+            else:
+                default = field.default
+            if default not in (None, dataclasses.MISSING):
+                kwargs.setdefault(key, default)
 
-            self.tensordict = TensorDict({}, batch_size=batch_size, device=device)
-            init(
-                self, **{key: _get_typed_value(value) for key, value in kwargs.items()}
+        missing_params = [p for p in required_params if p not in kwargs]
+        if missing_params:
+            n_missing = len(missing_params)
+            raise TypeError(
+                f"{self.__class__.__name__}.__init__() missing {n_missing} "
+                f"required positional argument{'' if n_missing == 1 else 's'}: "
+                f"""{", ".join(f"'{name}'" for name in missing_params)}"""
             )
 
+        self.tensordict = TensorDict({}, batch_size=batch_size, device=device)
+        init(self, **{key: _get_typed_value(value) for key, value in kwargs.items()})
+
     new_params = [
         inspect.Parameter(
             "batch_size", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
@@ -213,17 +198,16 @@ def wrapper(self, *args, batch_size=None, device=None, _tensordict=None, **kwargs):
         inspect.Parameter(
             "device", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
         ),
-        inspect.Parameter(
-            "_tensordict", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
-        ),
     ]
     wrapper.__signature__ = init_sig.replace(parameters=params + new_params)
 
     return wrapper
 
 
-def _build_from_tensordict(cls, tensordict):
-    return cls(_tensordict=tensordict)
+def _from_tensordict(cls, tensordict):
+    tc = cls(**tensordict, batch_size=tensordict.batch_size)
+    tc.__dict__["tensordict"] = tensordict
+    return tc
 
 
 def _getstate(self):
@@ -274,7 +258,7 @@ def _getattr(self, attr):
     def wrapped_func(*args, **kwargs):
         res = func(*args, **kwargs)
         if isinstance(res, TensorDictBase):
-            new = self.__class__(_tensordict=res)
+            new = self.from_tensordict(res)
             return new
         else:
             return res
@@ -288,7 +272,7 @@ def _getitem(self, item):
     ):
         raise ValueError("Invalid indexing arguments.")
     res = self.tensordict[item]
-    return self.__class__(_tensordict=res)  # device=res.device)
+    return self.from_tensordict(res)  # device=res.device)
 
 
 def _setitem(self, item, value):
@@ -351,13 +335,13 @@ def _batch_size_setter(self, new_size: torch.Size) -> None:
 
 def _unbind(tdc, dim):
     tensordicts = torch.unbind(tdc.tensordict, dim)
-    out = [tdc.__class__(_tensordict=td) for td in tensordicts]
+    out = [tdc.from_tensordict(td) for td in tensordicts]
     return out
 
 
 def _full_like(tdc, fill_value):
     tensordict = torch.full_like(tdc.tensordict, fill_value)
-    out = tdc.__class__(_tensordict=tensordict)
+    out = tdc.from_tensordict(tensordict)
     return out
 
 
@@ -371,43 +355,43 @@ def _ones_like(tdc):
 
 def _clone(tdc):
     tensordict = torch.clone(tdc.tensordict)
-    out = tdc.__class__(_tensordict=tensordict)
+    out = tdc.from_tensordict(tensordict)
     return out
 
 
 def _squeeze(tdc):
     tensordict = torch.squeeze(tdc.tensordict)
-    out = tdc.__class__(_tensordict=tensordict)
+    out = tdc.from_tensordict(tensordict)
     return out
 
 
 def _unsqueeze(tdc, dim=0):
     tensordict = torch.unsqueeze(tdc.tensordict, dim)
-    out = tdc.__class__(_tensordict=tensordict)
+    out = tdc.from_tensordict(tensordict)
     return out
 
 
 def _permute(tdc, dims):
     tensordict = torch.permute(tdc.tensordict, dims)
-    out = tdc.__class__(_tensordict=tensordict)
+    out = tdc.from_tensordict(tensordict)
     return out
 
 
 def _split(tdc, split_size_or_sections, dim=0):
     tensordicts = torch.split(tdc.tensordict, split_size_or_sections, dim)
-    out = [tdc.__class__(_tensordict=td) for td in tensordicts]
+    out = [tdc.from_tensordict(td) for td in tensordicts]
     return out
 
 
 def _stack(list_of_tdc, dim):
     tensordict = torch.stack([tdc.tensordict for tdc in list_of_tdc], dim)
-    out = list_of_tdc[0].__class__(_tensordict=tensordict)
+    out = list_of_tdc[0].from_tensordict(tensordict)
    return out
 
 
 def _cat(list_of_tdc, dim):
     tensordict = torch.cat([tdc.tensordict for tdc in list_of_tdc], dim)
-    out = list_of_tdc[0].__class__(_tensordict=tensordict)
+    out = list_of_tdc[0].from_tensordict(tensordict)
     return out
 
 
@@ -425,18 +409,18 @@ def _get_typed_output(out, expected_type):
     # Otherwise, if the output is some TensorDictBase subclass, we check the type and if it
     # does not match, we map it. In all other cases, just return what has been gathered.
     if isinstance(expected_type, str) and expected_type in CLASSES_DICT:
-        out = CLASSES_DICT[expected_type](_tensordict=out)
+        out = CLASSES_DICT[expected_type].from_tensordict(out)
     elif (
         isinstance(expected_type, type)
         and not isinstance(out, expected_type)
         and isinstance(out, TensorDictBase)
     ):
-        out = expected_type(_tensordict=out)
+        out = expected_type.from_tensordict(out)
     elif isinstance(out, TensorDictBase):
         dest_dtype = _check_td_out_type(expected_type)
         if dest_dtype is not None:
             print(dest_dtype)
-            out = dest_dtype(_tensordict=out)
+            out = dest_dtype.from_tensordict(out)
     return out
 
 
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index ee4152600..d4d403a04 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -70,7 +70,7 @@ def test_type():
 
 def test_signature():
     sig = inspect.signature(MyData)
-    assert list(sig.parameters) == ["X", "y", "batch_size", "device", "_tensordict"]
+    assert list(sig.parameters) == ["X", "y", "batch_size", "device"]
 
     with pytest.raises(TypeError, match="missing 2 required positional arguments"):
         MyData()
@@ -85,13 +85,6 @@ def test_signature():
     with pytest.raises(ValueError, match="batch size was not specified"):
         MyData(X=torch.rand(10), y=torch.rand(10))
 
-    # instantiation via _tensordict ignores argument checks, no TypeError
-    MyData(
-        _tensordict=TensorDict(
-            {"X": torch.rand(10), "y": torch.rand(10)}, batch_size=[10]
-        )
-    )
-
     # all positional arguments + batch_size is fine
     MyData(X=torch.rand(10), y=torch.rand(10), batch_size=[10])
 
@@ -158,7 +151,7 @@ class MyUnionClass:
     subclass: Union[MyOptionalClass, TensorDict] = None
 
     data = MyUnionClass(
-        subclass=MyUnionClass(_tensordict=TensorDict({}, [3])), batch_size=[3]
+        subclass=MyUnionClass.from_tensordict(TensorDict({}, [3])), batch_size=[3]
     )
     with pytest.raises(TypeError, match="can't be deterministically cast."):
         assert data.subclass is not None
diff --git a/test/test_tensorclass_nofuture.py b/test/test_tensorclass_nofuture.py
index bb3fce755..9ee350bca 100644
--- a/test/test_tensorclass_nofuture.py
+++ b/test/test_tensorclass_nofuture.py
@@ -68,7 +68,7 @@ def test_type():
 
 def test_signature():
     sig = inspect.signature(MyData)
-    assert list(sig.parameters) == ["X", "y", "batch_size", "device", "_tensordict"]
+    assert list(sig.parameters) == ["X", "y", "batch_size", "device"]
 
     with pytest.raises(TypeError, match="missing 2 required positional arguments"):
         MyData()
@@ -83,13 +83,6 @@ def test_signature():
     with pytest.raises(ValueError, match="batch size was not specified"):
         MyData(X=torch.rand(10), y=torch.rand(10))
 
-    # instantiation via _tensordict ignores argument checks, no TypeError
-    MyData(
-        _tensordict=TensorDict(
-            {"X": torch.rand(10), "y": torch.rand(10)}, batch_size=[10]
-        )
-    )
-
     # all positional arguments + batch_size is fine
     MyData(X=torch.rand(10), y=torch.rand(10), batch_size=[10])
 
@@ -156,7 +149,7 @@ class MyUnionClass:
     subclass: Union[MyOptionalClass, TensorDict] = None
 
     data = MyUnionClass(
-        subclass=MyUnionClass(_tensordict=TensorDict({}, [3])), batch_size=[3]
+        subclass=MyUnionClass.from_tensordict(TensorDict({}, [3])), batch_size=[3]
     )
     with pytest.raises(TypeError, match="can't be deterministically cast."):
         assert data.subclass is not None

From bcace640243f9839d72c011db0f6eb09b9535511 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Fri, 20 Jan 2023 15:16:22 +0000
Subject: [PATCH 2/4] Make batch_size required keyword-only argument

---
 tensordict/prototype/tensorclass.py | 10 +++-------
 test/test_tensorclass.py            |  8 +++++---
 test/test_tensorclass_nofuture.py   |  8 +++++---
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/tensordict/prototype/tensorclass.py b/tensordict/prototype/tensorclass.py
index 9abe51615..0d539947f 100644
--- a/tensordict/prototype/tensorclass.py
+++ b/tensordict/prototype/tensorclass.py
@@ -165,7 +165,7 @@ def _init_wrapper(init, expected_keys):
     required_params = [p.name for p in params[1:] if p.default is inspect._empty]
 
     @functools.wraps(init)
-    def wrapper(self, *args, batch_size=None, device=None, **kwargs):
+    def wrapper(self, *args, batch_size, device=None, **kwargs):
         for value, key in zip(args, self.__dataclass_fields__):
             if key in kwargs:
                 raise ValueError(f"The key {key} is already set in kwargs")
@@ -192,12 +192,8 @@ def wrapper(self, *args, batch_size=None, device=None, **kwargs):
         init(self, **{key: _get_typed_value(value) for key, value in kwargs.items()})
 
     new_params = [
-        inspect.Parameter(
-            "batch_size", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
-        ),
-        inspect.Parameter(
-            "device", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
-        ),
+        inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY),
+        inspect.Parameter("device", inspect.Parameter.KEYWORD_ONLY, default=None),
     ]
     wrapper.__signature__ = init_sig.replace(parameters=params + new_params)
 
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index d4d403a04..6c337806e 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -73,16 +73,18 @@ def test_signature():
     assert list(sig.parameters) == ["X", "y", "batch_size", "device"]
 
     with pytest.raises(TypeError, match="missing 2 required positional arguments"):
-        MyData()
+        MyData(batch_size=[10])
 
     with pytest.raises(TypeError, match="missing 1 required positional argument"):
-        MyData(X=torch.rand(10))
+        MyData(X=torch.rand(10), batch_size=[10])
 
     with pytest.raises(TypeError, match="missing 1 required positional argument"):
         MyData(X=torch.rand(10), batch_size=[10], device="cpu")
 
     # if all positional arguments are specified, omitting batch_size gives error
-    with pytest.raises(ValueError, match="batch size was not specified"):
+    with pytest.raises(
+        TypeError, match="missing 1 required keyword-only argument: 'batch_size'"
+    ):
         MyData(X=torch.rand(10), y=torch.rand(10))
 
     # all positional arguments + batch_size is fine
diff --git a/test/test_tensorclass_nofuture.py b/test/test_tensorclass_nofuture.py
index 9ee350bca..3b730cac6 100644
--- a/test/test_tensorclass_nofuture.py
+++ b/test/test_tensorclass_nofuture.py
@@ -71,16 +71,18 @@ def test_signature():
     assert list(sig.parameters) == ["X", "y", "batch_size", "device"]
 
     with pytest.raises(TypeError, match="missing 2 required positional arguments"):
-        MyData()
+        MyData(batch_size=[10])
 
     with pytest.raises(TypeError, match="missing 1 required positional argument"):
-        MyData(X=torch.rand(10))
+        MyData(X=torch.rand(10), batch_size=[10])
 
     with pytest.raises(TypeError, match="missing 1 required positional argument"):
         MyData(X=torch.rand(10), batch_size=[10], device="cpu")
 
     # if all positional arguments are specified, omitting batch_size gives error
-    with pytest.raises(ValueError, match="batch size was not specified"):
+    with pytest.raises(
+        TypeError, match="missing 1 required keyword-only argument: 'batch_size'"
+    ):
         MyData(X=torch.rand(10), y=torch.rand(10))
 
     # all positional arguments + batch_size is fine

From 7228d9b26d314e43fdbde82ce0030671d32dcb0a Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Mon, 23 Jan 2023 10:56:16 +0000
Subject: [PATCH 3/4] Check for expected keys
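Restore the key validation that previously happened in __init__ when
constructing from a TensorDict: from_tensordict is now built by a wrapper
that closes over the expected keys, so a TensorDict with stray keys is
rejected up front. A minimal sketch of the behaviour this adds (MyData is
a hypothetical tensorclass with a single field X):

    @tensorclass
    class MyData:
        X: torch.Tensor

    td = TensorDict({"X": torch.rand(3), "z": torch.rand(3)}, batch_size=[3])
    MyData.from_tensordict(td)
    # ValueError: Keys from the tensordict ({'X', 'z'}) must correspond
    # to the class attributes ...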
---
 tensordict/prototype/tensorclass.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/tensordict/prototype/tensorclass.py b/tensordict/prototype/tensorclass.py
index 0d539947f..6f8773821 100644
--- a/tensordict/prototype/tensorclass.py
+++ b/tensordict/prototype/tensorclass.py
@@ -124,8 +124,8 @@ def __torch_function__(
             f"Attribute name {attr} can't be used with @tensorclass"
         )
 
-    cls.__init__ = _init_wrapper(cls.__init__, expected_keys)
-    cls.from_tensordict = classmethod(_from_tensordict)
+    cls.__init__ = _init_wrapper(cls.__init__)
+    cls.from_tensordict = classmethod(_from_tensordict_wrapper(expected_keys))
     cls.__torch_function__ = classmethod(__torch_function__)
     cls.__getstate__ = _getstate
     cls.__setstate__ = _setstate
@@ -158,7 +158,7 @@ def __torch_function__(
     return cls
 
 
-def _init_wrapper(init, expected_keys):
+def _init_wrapper(init):
     init_sig = inspect.signature(init)
     params = list(init_sig.parameters.values())
     # drop first entry of params which corresponds to self and isn't passed by the user
@@ -200,10 +200,18 @@ def wrapper(self, *args, batch_size, device=None, **kwargs):
     return wrapper
 
 
-def _from_tensordict(cls, tensordict):
-    tc = cls(**tensordict, batch_size=tensordict.batch_size)
-    tc.__dict__["tensordict"] = tensordict
-    return tc
+def _from_tensordict_wrapper(expected_keys):
+    def wrapper(cls, tensordict):
+        if not all(key in expected_keys for key in tensordict.keys()):
+            raise ValueError(
+                f"Keys from the tensordict ({set(tensordict.keys())}) must "
+                f"correspond to the class attributes ({expected_keys})."
+            )
+        tc = cls(**tensordict, batch_size=tensordict.batch_size)
+        tc.__dict__["tensordict"] = tensordict
+        return tc
+
+    return wrapper
 
 
 def _getstate(self):

From 7f8d19768ec70be4339d0a55ab59bb4f5b80fc36 Mon Sep 17 00:00:00 2001
From: Tom Begley
Date: Mon, 23 Jan 2023 15:58:14 +0000
Subject: [PATCH 4/4] Update tests

---
 test/test_tensorclass.py          | 4 ++--
 test/test_tensorclass_nofuture.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index 9c05f574e..064868d1f 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -546,7 +546,7 @@ def __post_init__(self):
     assert (data.y == y.abs()).all()
 
     # initialising from tensordict is fine
-    data = MyDataPostInit._build_from_tensordict(
+    data = MyDataPostInit.from_tensordict(
         TensorDict({"X": torch.rand(3, 4), "y": y}, batch_size=[3, 4])
     )
 
@@ -554,7 +554,7 @@ def __post_init__(self):
         MyDataPostInit(X=-torch.ones(2), y=torch.rand(2), batch_size=[2])
 
     with pytest.raises(AssertionError):
-        MyDataPostInit._build_from_tensordict(
+        MyDataPostInit.from_tensordict(
             TensorDict({"X": -torch.ones(2), "y": torch.rand(2)}, batch_size=[2])
         )
 
diff --git a/test/test_tensorclass_nofuture.py b/test/test_tensorclass_nofuture.py
index 6cf51af89..48c593900 100644
--- a/test/test_tensorclass_nofuture.py
+++ b/test/test_tensorclass_nofuture.py
@@ -544,7 +544,7 @@ def __post_init__(self):
     assert (data.y == y.abs()).all()
 
     # initialising from tensordict is fine
-    data = MyDataPostInit._build_from_tensordict(
+    data = MyDataPostInit.from_tensordict(
         TensorDict({"X": torch.rand(3, 4), "y": y}, batch_size=[3, 4])
     )
 
@@ -552,7 +552,7 @@ def __post_init__(self):
         MyDataPostInit(X=-torch.ones(2), y=torch.rand(2), batch_size=[2])
 
     with pytest.raises(AssertionError):
-        MyDataPostInit._build_from_tensordict(
+        MyDataPostInit.from_tensordict(
             TensorDict({"X": -torch.ones(2), "y": torch.rand(2)}, batch_size=[2])
         )
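Taken together, the series replaces construction via the private
_tensordict keyword argument with a public from_tensordict classmethod,
makes batch_size a required keyword-only argument, and validates the
TensorDict's keys against the class attributes. A minimal sketch of the
resulting usage, assuming the decorator is exposed as
tensordict.prototype.tensorclass:

    import torch
    from tensordict import TensorDict
    from tensordict.prototype import tensorclass

    @tensorclass
    class MyData:
        X: torch.Tensor
        y: torch.Tensor

    # batch_size must now be passed by keyword (patch 2)
    data = MyData(X=torch.rand(3, 4), y=torch.rand(3, 4), batch_size=[3])

    # building from an existing TensorDict goes through the classmethod
    # added in patch 1; its keys are checked against the class attributes
    # (patch 3)
    td = TensorDict({"X": torch.rand(3, 4), "y": torch.rand(3, 4)}, batch_size=[3])
    data = MyData.from_tensordict(td)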