From 3bafc0a99e24d162666eb9cf359765b40606677b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Nov 2024 17:47:56 +0100 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- docs/source/reference/tensorclass.rst | 1 + tensordict/__init__.py | 1 + tensordict/_lazy.py | 26 ++++- tensordict/_td.py | 150 +++++++++++++++++++++++--- tensordict/base.py | 120 ++++++++++++++++++++- tensordict/nn/params.py | 8 +- tensordict/persistent.py | 28 ++++- tensordict/tensorclass.py | 104 +++++++++++++++++- tensordict/tensorclass.pyi | 7 ++ tensordict/utils.py | 2 +- test/_utils_internal.py | 58 ++++++---- test/test_tensorclass.py | 53 +++++++++ test/test_tensordict.py | 50 +++++++++ 13 files changed, 564 insertions(+), 44 deletions(-) diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17dceff06..ea55aef40 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -282,6 +282,7 @@ Here is an example: TensorClass NonTensorData NonTensorStack + from_dataclass Auto-casting ------------ diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 364a11f5a..7fc9d349d 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -43,6 +43,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.persistent import PersistentTensorDict from tensordict.tensorclass import ( + from_dataclass, NonTensorData, NonTensorStack, tensorclass, diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index eb4248671..fc87258bf 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -329,15 +329,37 @@ def _reduce_get_metadata(self): @classmethod def from_dict( cls, - input_dict, + input_dict: List[Dict[NestedKey, Any]], + *other, + auto_batch_size: bool = False, batch_size=None, device=None, batch_dims=None, stack_dim_name=None, stack_dim=0, ): + if batch_size is not None: + batch_size = list(batch_size) + if stack_dim is None: + stack_dim = 0 + n = batch_size.pop(stack_dim) + if n != len(input_dict): + raise ValueError( + "The number of dicts and the corresponding batch-size must match, " + f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." + ) + batch_size = torch.Size(batch_size) return LazyStackedTensorDict( - *(input_dict[str(i)] for i in range(len(input_dict))), + *( + TensorDict.from_dict( + input_dict[str(i)], + *other, + auto_batch_size=auto_batch_size, + device=device, + batch_dims=batch_dims, + ) + for i in range(len(input_dict)) + ), stack_dim=stack_dim, stack_dim_name=stack_dim_name, ) diff --git a/tensordict/_td.py b/tensordict/_td.py index 7895fae4e..3029aeab0 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1957,8 +1957,46 @@ def _unsqueeze(tensor): @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + if device is not None: + raise TypeError( + "conflicting device values. Please use the keyword argument only." + ) + if batch_dims is not None: + raise TypeError( + "conflicting batch_dims values. Please use the keyword argument only." + ) + if names is not None: + raise TypeError( + "conflicting names values. Please use the keyword argument only." + ) + warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -1967,12 +2005,12 @@ def from_dict( batch_size_set = torch.Size(()) if batch_size is None else batch_size input_dict = dict(input_dict) for key, value in list(input_dict.items()): - if isinstance(value, (dict,)): - # we don't know if another tensor of smaller size is coming - # so we can't be sure that the batch-size will still be valid later - input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None - ) + # we don't know if another tensor of smaller size is coming + # so we can't be sure that the batch-size will still be valid later + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=False, + ) # regular __init__ breaks because a tensor may have the same batch-size as the tensordict out = cls( input_dict, @@ -1981,7 +2019,17 @@ def from_dict( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -1998,8 +2046,46 @@ def _from_dict_validated( ) def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None, names=None + self, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + if device is not None: + raise TypeError( + "conflicting device values. Please use the keyword argument only." + ) + if batch_dims is not None: + raise TypeError( + "conflicting batch_dims values. Please use the keyword argument only." + ) + if names is not None: + raise TypeError( + "conflicting names values. Please use the keyword argument only." + ) + warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -2014,14 +2100,24 @@ def from_dict_instance( cur_value = self.get(key, None) if cur_value is not None: input_dict[key] = cur_value.from_dict_instance( - value, batch_size=[], device=device, batch_dims=None + value, + device=device, + auto_batch_size=auto_batch_size, ) continue # we don't know if another tensor of smaller size is coming # so we can't be sure that the batch-size will still be valid later input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None + value, + device=device, + auto_batch_size=auto_batch_size, + ) + else: + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=auto_batch_size, ) + out = TensorDict.from_dict( input_dict, batch_size=batch_size_set, @@ -2029,7 +2125,17 @@ def from_dict_instance( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -3857,7 +3963,14 @@ def expand(self, *args: int, inplace: bool = False) -> T: @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): raise NotImplementedError(f"from_dict not implemented for {cls.__name__}.") @@ -4273,6 +4386,12 @@ def _items( (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict._source.keys() ) + from tensordict.persistent import PersistentTensorDict + + if isinstance(tensordict, PersistentTensorDict): + return ( + (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict.keys() + ) raise NotImplementedError(type(tensordict)) def _keys(self) -> _TensorDictKeysView: @@ -4697,7 +4816,9 @@ def from_modules( ) -def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None): +def from_dict( + input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None +): """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`. If ``batch_size`` is not specified, returns the maximum batch size possible. @@ -4762,6 +4883,7 @@ def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=N """ return TensorDict.from_dict( input_dict, + *others, batch_size=batch_size, device=device, batch_dims=batch_dims, diff --git a/tensordict/base.py b/tensordict/base.py index 39729eba4..bea8af10c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -12,6 +12,7 @@ import enum import gc import importlib +import importlib.util import os.path import queue import uuid @@ -112,6 +113,8 @@ except ImportError: from tensordict.utils import Buffer +_has_h5 = importlib.util.find_spec("h5py") is not None + # NO_DEFAULT is used as a placeholder whenever the default is not provided. # Using None is not an option since `td.get(key)` is a valid usage. @@ -1133,6 +1136,8 @@ def auto_device_(self) -> T: def from_dict( cls, input_dict, + *, + auto_batch_size: bool | None = None, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, @@ -1148,6 +1153,10 @@ def from_dict( Args: input_dict (dictionary, optional): a dictionary to use as a data source (nested keys compatible). + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. batch_dims (int, optional): the ``batch_dims`` (ie number of leading dimensions @@ -1213,6 +1222,7 @@ def _from_dict_validated(cls, *args, **kwargs): def from_dict_instance( self, input_dict, + *others, batch_size=None, device=None, batch_dims=None, @@ -9837,6 +9847,113 @@ def dict_to_namedtuple(dictionary): return dict_to_namedtuple(self.to_dict(retain_none=False)) + @classmethod + def from_any(cls, obj, *, auto_batch_size: bool = False): + """Converts any object to a TensorDict, recursively. + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. + + Support includes: + + - dataclasses through :meth:`~.from_dataclass` (dataclasses will be converted to TensorDict instances, not + tensorclasses). + - namedtuple through :meth:`~.from_namedtuple` + - dict through :meth:`~.from_dict` + - tuple through :meth:`~.from_tuple` + - numpy's structured arrays through :meth:`~.from_struct_array` + - h5 objects through :meth:`~.from_h5` + + """ + if isinstance(obj, dict): + return cls.from_dict(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): + return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) + from dataclasses import is_dataclass + + if is_dataclass(obj): + return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, tuple): + return cls.from_tuple(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, list): + return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) + if _has_h5: + import h5py + + if isinstance(obj, h5py.File): + from tensordict.persistent import PersistentTensorDict + + obj = PersistentTensorDict(group=obj) + if auto_batch_size: + obj.auto_batch_size_() + return obj + return obj + + @classmethod + def from_tuple(cls, obj, *, auto_batch_size: bool = False): + from tensordict import TensorDict + + result = TensorDict({str(i): cls.from_any(item) for i, item in enumerate(obj)}) + if auto_batch_size: + result.auto_batch_size_() + return result + + @classmethod + def from_dataclass( + cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + ): + """Converts a dataclass into a TensorDict instance. + + Args: + dataclass: The dataclass instance to be converted. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the + resulting TensorDict. Defaults to ``False``. + as_tensorclass (bool, optional): If ``True``, delegates the conversion to the free function + :func:`~tensordict.from_dataclass` and returns a tensor-compatible class + (:func:`~tensordict.tensorclass`) or instance instead of a ``TensorDict``. Defaults to ``False``. + + Returns: + A TensorDict instance derived from the provided dataclass, unless `as_tensorclass` is True, in which case a tensor-compatible class or instance is returned. + + Raises: + TypeError: If the provided input is not a dataclass instance. + + .. warning:: This method is distinct from the free function `from_dataclass` and serves a different purpose. + While the free function returns a tensor-compatible class or instance, this method returns a TensorDict instance. + + .. notes:: + + - This method creates a new TensorDict instance with keys corresponding to the fields of the input dataclass. + - Each key in the resulting TensorDict is initialized using the `cls.from_any` method. + - The `auto_batch_size` option allows for automatic batch size determination and application to the + resulting TensorDict. + + """ + if as_tensorclass: + from tensordict.tensorclass import from_dataclass + + return from_dataclass(dataclass, auto_batch_size=auto_batch_size) + from dataclasses import fields, is_dataclass + + from tensordict import TensorDict + + if not is_dataclass(dataclass): + raise TypeError( + f"Expected a dataclass input, got a {type(dataclass)} input instead." + ) + source = {} + for field in fields(dataclass): + source[field.name] = cls.from_any(getattr(dataclass, field.name)) + result = TensorDict(source) + if auto_batch_size: + result.auto_batch_size_() + return result + @classmethod def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): """Converts a namedtuple to a TensorDict recursively. @@ -9885,8 +10002,7 @@ def namedtuple_to_dict(namedtuple_obj): "indices": namedtuple_obj.indices, } for key, value in namedtuple_obj.items(): - if is_namedtuple(value): - namedtuple_obj[key] = namedtuple_to_dict(value) + namedtuple_obj[key] = cls.from_any(value) return dict(namedtuple_obj) result = TensorDict(namedtuple_to_dict(named_tuple)) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index df7bad0e6..0b6dca196 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -921,7 +921,13 @@ def _exclude( @_carry_over def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, ): ... @_carry_over diff --git a/tensordict/persistent.py b/tensordict/persistent.py index d5f59110a..332023587 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -207,12 +207,25 @@ def from_h5(cls, filename, mode="r"): return out @classmethod - def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs): + def from_dict( + cls, + input_dict, + filename, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + **kwargs, + ): """Converts a dictionary or a TensorDict to a h5 file. Args: input_dict (dict, TensorDict or compatible): data to be stored as h5. filename (str or path): path to the h5 file. + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. batch_size (tensordict batch-size, optional): if provided, batch size of the tensordict. If not, the batch size will be gathered from the input structure (if present) or determined automatically. @@ -225,6 +238,19 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs) A :class:`PersitentTensorDict` instance linked to the newly created file. """ + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + warnings.warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead." + ) + if len(others) == 2: + batch_size, device = others + else: + batch_size = others[0] + import h5py file = h5py.File(filename, "w", locking=cls.LOCKING) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2556729e5..fc188e7cf 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -237,6 +237,12 @@ def __subclasscheck__(self, subclass): "floor_", "frac", "frac_", + "from_any", + "from_dataclass", + "to_namedtuple", + "from_namedtuple", + "from_pytree", + "to_pytree", "gather", "isfinite", "isnan", @@ -379,6 +385,100 @@ def __call__(self, cls: T) -> T: return clz +def from_dataclass( + obj: Any, + *, + auto_batch_size: bool = False, + frozen: bool = False, + autocast: bool = False, + nocast: bool = False, +) -> Any: + """Converts a dataclass instance or a type into a tensorclass instance or type, respectively. + + This function takes a dataclass instance or a dataclass type and converts it into a tensor-compatible class, + optionally applying various configurations such as auto-batching, immutability, and type casting. + + Args: + obj (Any): The dataclass instance or type to be converted. If a type is provided, a new class is returned. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the resulting object. Defaults to ``False``. + frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``. + autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``. + nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``. + + Returns: + A tensor-compatible class or instance derived from the provided dataclass. + + Raises: + TypeError: If the provided input is not a dataclass instance or type. + + Examples: + >>> from dataclasses import dataclass + >>> import torch + >>> from tensordict.tensorclass import from_dataclass + >>> + >>> @dataclass + >>> class X: + ... a: int + ... b: torch.Tensor + ... + >>> x = X(0, 0) + >>> x2 = from_dataclass(x) + >>> print(x2) + X( + a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> X2 = from_dataclass(X, autocast=True) + >>> print(X2(a=0, b=0)) + X( + a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), + b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([]), + device=None, + is_shared=False) + + .. notes:: If a dataclass type is provided, a new class is returned with the specified configurations. + If a dataclass instance is provided, a new instance of the tensor-compatible class is returned. + The `auto_batch_size`, `frozen`, `autocast`, and `nocast` options allow for flexible configuration of the resulting class or instance. + + .. warning:: Whereas :meth:`~tensordict.TensorDict.from_dataclass` will return a :class:`~tensordict.TensorDict` instance + by default, this method will return a tensorclass instance or type. + + """ + from dataclasses import asdict, is_dataclass, make_dataclass + + if isinstance(obj, type): + if is_tensorclass(obj): + return obj + cls = make_dataclass( + obj.__name__ + "_tc", fields=obj.__dataclass_fields__, bases=obj.__bases__ + ) + clz = _tensorclass(cls, frozen=frozen) + clz._type_hints = get_type_hints(obj) + clz._autocast = autocast + clz._nocast = nocast + clz._frozen = frozen + return clz + + if not is_dataclass(obj): + raise TypeError(f"Expected a obj input, got a {type(obj)} input instead.") + name = obj.__class__.__name__ + "_tc" + clz = _tensorclass( + make_dataclass(name, fields=obj.__dataclass_fields__), frozen=frozen + ) + clz._autocast = autocast + clz._nocast = nocast + clz._frozen = frozen + result = clz(**asdict(obj)) + if auto_batch_size: + result = result.auto_batch_size_() + return result + + @dataclass_transform() def tensorclass( cls: T = None, @@ -532,7 +632,8 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) - cls = dataclass(cls, frozen=frozen) + if not dataclasses.is_dataclass(cls): + cls = dataclass(cls, frozen=frozen) _TENSORCLASS_MEMO[cls] = True expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) @@ -2483,7 +2584,6 @@ def __post_init__(self): data_inner = data.tolist() del _tensordict["data"] _non_tensordict["data"] = data_inner - # assert _tensordict.is_empty(), self._tensordict # TODO: this will probably fail with dynamo at some point, + it's terrible. # Make sure it's patched properly at init time diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 75678b4b6..a77ef185a 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -209,9 +209,16 @@ class TensorClass: def auto_batch_size_(self, batch_dims: int | None = None) -> T: ... def auto_device_(self) -> T: ... @classmethod + def from_dataclass( + cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + ): ... + @classmethod + def from_any(cls, obj, *, auto_batch_size: bool = False): ... + @classmethod def from_dict( cls, input_dict, + *, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, diff --git a/tensordict/utils.py b/tensordict/utils.py index cdc0756f8..54e00a174 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -858,7 +858,7 @@ def is_tensorclass(obj: type | Any) -> bool: def _is_tensorclass(cls: type) -> bool: - out = _TENSORCLASS_MEMO.get(cls, None) + out = _TENSORCLASS_MEMO.get(cls) if out is None: out = getattr(cls, "_is_tensorclass", False) if not is_dynamo_compiling(): diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8879f0e68..ad1a194cd 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -53,7 +53,8 @@ class TestTensorDictsBase: TYPES_DEVICES = [] TYPES_DEVICES_NOLAZY = [] - def td(self, device): + @classmethod + def td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -68,7 +69,8 @@ def td(self, device): TYPES_DEVICES += [["td", device]] TYPES_DEVICES_NOLAZY += [["td", device]] - def nested_td(self, device): + @classmethod + def nested_td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -86,7 +88,8 @@ def nested_td(self, device): TYPES_DEVICES += [["nested_td", device]] TYPES_DEVICES_NOLAZY += [["nested_td", device]] - def nested_tensorclass(self, device): + @classmethod + def nested_tensorclass(cls, device): nested_class = MyClass( X=torch.randn(4, 3, 2, 1), @@ -119,8 +122,9 @@ def nested_tensorclass(self, device): TYPES_DEVICES += [["nested_tensorclass", device]] TYPES_DEVICES_NOLAZY += [["nested_tensorclass", device]] + @classmethod @set_lazy_legacy(True) - def nested_stacked_td(self, device): + def nested_stacked_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -140,8 +144,9 @@ def nested_stacked_td(self, device): TYPES_DEVICES += [["nested_stacked_td", device]] TYPES_DEVICES_NOLAZY += [["nested_stacked_td", device]] + @classmethod @set_lazy_legacy(True) - def stacked_td(self, device): + def stacked_td(cls, device): td1 = TensorDict( source={ "a": torch.randn(4, 3, 1, 5), @@ -165,7 +170,8 @@ def stacked_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["stacked_td", device]] - def idx_td(self, device): + @classmethod + def idx_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -180,7 +186,8 @@ def idx_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["idx_td", device]] - def sub_td(self, device): + @classmethod + def sub_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -195,7 +202,8 @@ def sub_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["sub_td", device]] - def sub_td2(self, device): + @classmethod + def sub_td2(cls, device): td = TensorDict( source={ "a": torch.randn(4, 2, 3, 2, 1, 5), @@ -212,17 +220,19 @@ def sub_td2(self, device): temp_path_memmap = tempfile.TemporaryDirectory() - def memmap_td(self, device): - path = pathlib.Path(self.temp_path_memmap.name) + @classmethod + def memmap_td(cls, device): + path = pathlib.Path(cls.temp_path_memmap.name) shutil.rmtree(path) path.mkdir() - return self.td(device).memmap_(path) + return cls.td(device).memmap_(path) TYPES_DEVICES += [["memmap_td", torch.device("cpu")]] TYPES_DEVICES_NOLAZY += [["memmap_td", torch.device("cpu")]] + @classmethod @set_lazy_legacy(True) - def permute_td(self, device): + def permute_td(cls, device): return TensorDict( source={ "a": torch.randn(3, 1, 4, 2, 5), @@ -236,8 +246,9 @@ def permute_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["permute_td", device]] + @classmethod @set_lazy_legacy(True) - def unsqueezed_td(self, device): + def unsqueezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 5), @@ -252,8 +263,9 @@ def unsqueezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["unsqueezed_td", device]] + @classmethod @set_lazy_legacy(True) - def squeezed_td(self, device): + def squeezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 1, 2, 1, 5), @@ -268,7 +280,8 @@ def squeezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["squeezed_td", device]] - def td_reset_bs(self, device): + @classmethod + def td_reset_bs(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -285,13 +298,14 @@ def td_reset_bs(self, device): TYPES_DEVICES += [["td_reset_bs", device]] TYPES_DEVICES_NOLAZY += [["td_reset_bs", device]] + @classmethod def td_h5( - self, + cls, device, ): file = tempfile.NamedTemporaryFile() filename = file.name - nested_td = self.nested_td(device) + nested_td = cls.nested_td(device) td_h5 = PersistentTensorDict.from_dict( nested_td, filename=filename, device=device ) @@ -303,15 +317,17 @@ def td_h5( TYPES_DEVICES += [["td_h5", device]] TYPES_DEVICES_NOLAZY += [["td_h5", device]] - def td_params(self, device): - return TensorDictParams(self.td(device)) + @classmethod + def td_params(cls, device): + return TensorDictParams(cls.td(device)) for device in get_available_devices(): TYPES_DEVICES += [["td_params", device]] TYPES_DEVICES_NOLAZY += [["td_params", device]] - def td_with_non_tensor(self, device): - td = self.td(device) + @classmethod + def td_with_non_tensor(cls, device): + td = cls.td(device) return td.set_non_tensor( ("data", "non_tensor"), # this is allowed since nested NonTensorData are automatically unwrapped diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 4753c3704..127d4b77a 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -23,6 +23,7 @@ import tensordict.utils import torch from tensordict import TensorClass +from tensordict.tensorclass import from_dataclass try: import torchsnapshot @@ -94,6 +95,21 @@ class MyData2: z: list +@dataclasses.dataclass +class MyDataClass: + a: int + b: torch.Tensor + c: str + + +try: + MyTensorClass_autocast = from_dataclass(MyDataClass, autocast=True) + MyTensorClass_nocast = from_dataclass(MyDataClass, nocast=True) + MyTensorClass = from_dataclass(MyDataClass) +except Exception: + MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None + + class TestTensorClass: def test_all_any(self): @tensorclass @@ -517,6 +533,43 @@ class MyClass2: assert (a != c.clone().zero_()).any() assert (c != a.clone().zero_()).any() + def test_from_dataclass(self): + assert is_tensorclass(MyTensorClass_autocast) + assert MyTensorClass_nocast is not MyDataClass + assert MyTensorClass_autocast._autocast + x = MyTensorClass_autocast(a=0, b=0, c=0) + assert isinstance(x.a, int) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, str) + + assert is_tensorclass(MyTensorClass_nocast) + assert MyTensorClass_nocast is not MyTensorClass_autocast + assert MyTensorClass_nocast._nocast + + x = MyTensorClass_nocast(a=0, b=0, c=0) + assert is_tensorclass(MyTensorClass) + assert not MyTensorClass._autocast + assert not MyTensorClass._nocast + assert isinstance(x.a, int) + assert isinstance(x.b, int) + assert isinstance(x.c, int) + + x = MyTensorClass(a=0, b=0, c=0) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + + x = TensorDict.from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert isinstance(x, TensorDict) + assert isinstance(x["a"], torch.Tensor) + assert isinstance(x["b"], torch.Tensor) + assert isinstance(x["c"], torch.Tensor) + x = from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert is_tensorclass(x) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + def test_from_dict(self): td = TensorDict( { diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 257c2b712..7d4148cd9 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -20,6 +20,7 @@ import warnings from dataclasses import dataclass from pathlib import Path +from typing import Any import numpy as np import pytest @@ -949,6 +950,55 @@ def test_fromkeys(self): td = TensorDict.fromkeys({"a", "b", "c"}, 1) assert td["a"] == 1 + def test_from_any(self): + from dataclasses import dataclass + + @dataclass + class MyClass: + a: int + + pytree = ( + [torch.randint(10, (3,)), torch.zeros(2)], + { + "tensor": torch.randn( + 2, + ), + "td": TensorDict({"one": 1}), + "tuple": (1, 2, 3), + }, + {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, + {"dataclass": MyClass(a=0)}, + ) + if _has_h5py: + pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) + td = TensorDict.from_any(pytree) + assert set(td.keys(True, True)) == { + ("0", "0"), + ("0", "1"), + ("1", "td", "one"), + ("1", "tensor"), + ("1", "tuple", "0"), + ("1", "tuple", "1"), + ("1", "tuple", "2"), + ("2", "named_tuple", "two"), + ("4", "h5py", "a"), + ("4", "h5py", "b"), + ("4", "h5py", "c"), + ("4", "h5py", "my_nested_td", "inner"), + } + + def test_from_dataclass(self): + @dataclass + class MyClass: + a: int + b: Any + + obj = MyClass(a=0, b=1) + obj_td = TensorDict.from_dataclass(obj) + obj_tc = TensorDict.from_dataclass(obj, as_tensorclass=True) + assert is_tensorclass(obj_tc) + assert not is_tensorclass(obj_td) + @pytest.mark.parametrize("batch_size", [None, [3, 4]]) @pytest.mark.parametrize("batch_dims", [None, 1, 2]) @pytest.mark.parametrize("device", get_available_devices())