From c95a703944f3b761987c173c6fbcdb1f2c2fb712 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 20:23:27 +0100 Subject: [PATCH 1/8] [Feature,Refactor] Refactor from_dict, add from_any, from_dataclass ghstack-source-id: eb25fe4b201fd7f27d60b140278820c0d5d51eb8 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1102 --- docs/source/reference/tensorclass.rst | 1 + tensordict/__init__.py | 1 + tensordict/_lazy.py | 27 +++- tensordict/_td.py | 177 ++++++++++++++++++++++---- tensordict/base.py | 125 +++++++++++++++++- tensordict/functional.py | 7 +- tensordict/nn/common.py | 4 +- tensordict/nn/params.py | 8 +- tensordict/persistent.py | 28 +++- tensordict/tensorclass.py | 149 ++++++++++++++++++++-- tensordict/tensorclass.pyi | 7 + tensordict/utils.py | 13 +- test/_utils_internal.py | 58 ++++++--- test/test_tensorclass.py | 63 ++++++++- test/test_tensordict.py | 112 +++++++++++++--- 15 files changed, 695 insertions(+), 85 deletions(-) diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17dceff06..ea55aef40 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -282,6 +282,7 @@ Here is an example: TensorClass NonTensorData NonTensorStack + from_dataclass Auto-casting ------------ diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 364a11f5a..7fc9d349d 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -43,6 +43,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.persistent import PersistentTensorDict from tensordict.tensorclass import ( + from_dataclass, NonTensorData, NonTensorStack, tensorclass, diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index eb4248671..73c316981 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -329,15 +329,38 @@ def _reduce_get_metadata(self): @classmethod def from_dict( cls, - input_dict, + input_dict: List[Dict[NestedKey, Any]], + *other, + auto_batch_size: bool = False, batch_size=None, device=None, batch_dims=None, stack_dim_name=None, stack_dim=0, ): + # if batch_size is not None: + # batch_size = list(batch_size) + # if stack_dim is None: + # stack_dim = 0 + # n = batch_size.pop(stack_dim) + # if n != len(input_dict): + # raise ValueError( + # "The number of dicts and the corresponding batch-size must match, " + # f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." + # ) + # batch_size = torch.Size(batch_size) return LazyStackedTensorDict( - *(input_dict[str(i)] for i in range(len(input_dict))), + *( + TensorDict.from_dict( + input_dict[str(i)], + *other, + auto_batch_size=auto_batch_size, + device=device, + batch_dims=batch_dims, + batch_size=batch_size, + ) + for i in range(len(input_dict)) + ), stack_dim=stack_dim, stack_dim_name=stack_dim_name, ) diff --git a/tensordict/_td.py b/tensordict/_td.py index 7895fae4e..07a98cdfb 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -615,7 +615,7 @@ def __ne__(self, other: object) -> T | bool: if is_tensorclass(other): return other != self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -639,7 +639,7 @@ def __xor__(self, other: object) -> T | bool: if is_tensorclass(other): return other ^ self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -663,7 +663,7 @@ def __or__(self, other: object) -> T | bool: if is_tensorclass(other): return other | self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -687,7 +687,7 @@ def __eq__(self, other: object) -> T | bool: if is_tensorclass(other): return other == self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -709,7 +709,7 @@ def __ge__(self, other: object) -> T | bool: if is_tensorclass(other): return other <= self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -731,7 +731,7 @@ def __gt__(self, other: object) -> T | bool: if is_tensorclass(other): return other < self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -753,7 +753,7 @@ def __le__(self, other: object) -> T | bool: if is_tensorclass(other): return other >= self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -775,7 +775,7 @@ def __lt__(self, other: object) -> T | bool: if is_tensorclass(other): return other > self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -1957,8 +1957,46 @@ def _unsqueeze(tensor): @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + if device is not None: + raise TypeError( + "conflicting device values. Please use the keyword argument only." + ) + if batch_dims is not None: + raise TypeError( + "conflicting batch_dims values. Please use the keyword argument only." + ) + if names is not None: + raise TypeError( + "conflicting names values. Please use the keyword argument only." + ) + warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -1967,12 +2005,12 @@ def from_dict( batch_size_set = torch.Size(()) if batch_size is None else batch_size input_dict = dict(input_dict) for key, value in list(input_dict.items()): - if isinstance(value, (dict,)): - # we don't know if another tensor of smaller size is coming - # so we can't be sure that the batch-size will still be valid later - input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None - ) + # we don't know if another tensor of smaller size is coming + # so we can't be sure that the batch-size will still be valid later + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=False, + ) # regular __init__ breaks because a tensor may have the same batch-size as the tensordict out = cls( input_dict, @@ -1981,7 +2019,19 @@ def from_dict( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None and batch_dims is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -1998,8 +2048,46 @@ def _from_dict_validated( ) def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None, names=None + self, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + if device is not None: + raise TypeError( + "conflicting device values. Please use the keyword argument only." + ) + if batch_dims is not None: + raise TypeError( + "conflicting batch_dims values. Please use the keyword argument only." + ) + if names is not None: + raise TypeError( + "conflicting names values. Please use the keyword argument only." + ) + warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -2014,14 +2102,25 @@ def from_dict_instance( cur_value = self.get(key, None) if cur_value is not None: input_dict[key] = cur_value.from_dict_instance( - value, batch_size=[], device=device, batch_dims=None + value, + device=device, + auto_batch_size=False, ) continue - # we don't know if another tensor of smaller size is coming - # so we can't be sure that the batch-size will still be valid later - input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None + else: + # we don't know if another tensor of smaller size is coming + # so we can't be sure that the batch-size will still be valid later + input_dict[key] = TensorDict.from_dict( + value, + device=device, + auto_batch_size=False, + ) + else: + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=False, ) + out = TensorDict.from_dict( input_dict, batch_size=batch_size_set, @@ -2029,7 +2128,19 @@ def from_dict_instance( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None and batch_dims is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -3857,7 +3968,14 @@ def expand(self, *args: int, inplace: bool = False) -> T: @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): raise NotImplementedError(f"from_dict not implemented for {cls.__name__}.") @@ -4273,6 +4391,12 @@ def _items( (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict._source.keys() ) + from tensordict.persistent import PersistentTensorDict + + if isinstance(tensordict, PersistentTensorDict): + return ( + (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict.keys() + ) raise NotImplementedError(type(tensordict)) def _keys(self) -> _TensorDictKeysView: @@ -4697,7 +4821,9 @@ def from_modules( ) -def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None): +def from_dict( + input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None +): """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`. If ``batch_size`` is not specified, returns the maximum batch size possible. @@ -4762,6 +4888,7 @@ def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=N """ return TensorDict.from_dict( input_dict, + *others, batch_size=batch_size, device=device, batch_dims=batch_dims, diff --git a/tensordict/base.py b/tensordict/base.py index 39729eba4..6c600b11f 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -12,6 +12,7 @@ import enum import gc import importlib +import importlib.util import os.path import queue import uuid @@ -54,6 +55,7 @@ _CloudpickleWrapper, _DTYPE2STRDTYPE, _GENERIC_NESTED_ERR, + _is_dataclass as is_dataclass, _is_non_tensor, _is_number, _is_tensorclass, @@ -112,6 +114,8 @@ except ImportError: from tensordict.utils import Buffer +_has_h5 = importlib.util.find_spec("h5py") is not None + # NO_DEFAULT is used as a placeholder whenever the default is not provided. # Using None is not an option since `td.get(key)` is a valid usage. @@ -120,7 +124,6 @@ class _NoDefault(enum.IntEnum): NO_DEFAULT = _NoDefault.ZERO -assert not NO_DEFAULT T = TypeVar("T", bound="TensorDictBase") @@ -1133,6 +1136,8 @@ def auto_device_(self) -> T: def from_dict( cls, input_dict, + *, + auto_batch_size: bool | None = None, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, @@ -1148,6 +1153,10 @@ def from_dict( Args: input_dict (dictionary, optional): a dictionary to use as a data source (nested keys compatible). + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. batch_dims (int, optional): the ``batch_dims`` (ie number of leading dimensions @@ -1207,12 +1216,15 @@ def _from_dict_validated(cls, *args, **kwargs): By default, falls back on :meth:`~.from_dict`. """ + kwargs.setdefault("auto_batch_size", True) return cls.from_dict(*args, **kwargs) @abc.abstractmethod def from_dict_instance( self, input_dict, + *others, + auto_batch_size: bool | None = None, batch_size=None, device=None, batch_dims=None, @@ -4284,7 +4296,6 @@ def _view_and_pad(tensor): elif k[-1].startswith(""): # NJT/NT always comes before offsets/shapes nt = oldv - assert not v.numel() nt_lengths = None del flat_dict[k] elif k[-1].startswith(""): @@ -9837,6 +9848,113 @@ def dict_to_namedtuple(dictionary): return dict_to_namedtuple(self.to_dict(retain_none=False)) + @classmethod + def from_any(cls, obj, *, auto_batch_size: bool = False): + """Converts any object to a TensorDict, recursively. + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. + + Support includes: + + - dataclasses through :meth:`~.from_dataclass` (dataclasses will be converted to TensorDict instances, not + tensorclasses). + - namedtuple through :meth:`~.from_namedtuple` + - dict through :meth:`~.from_dict` + - tuple through :meth:`~.from_tuple` + - numpy's structured arrays through :meth:`~.from_struct_array` + - h5 objects through :meth:`~.from_h5` + + """ + if is_tensor_collection(obj): + return obj + if isinstance(obj, dict): + return cls.from_dict(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): + return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, tuple): + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) + return cls.from_tuple(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, list): + return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) + if is_dataclass(obj): + return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) + if _has_h5: + import h5py + + if isinstance(obj, h5py.File): + from tensordict.persistent import PersistentTensorDict + + obj = PersistentTensorDict(group=obj) + if auto_batch_size: + obj.auto_batch_size_() + return obj + return obj + + @classmethod + def from_tuple(cls, obj, *, auto_batch_size: bool = False): + from tensordict import TensorDict + + result = TensorDict({str(i): cls.from_any(item) for i, item in enumerate(obj)}) + if auto_batch_size: + result.auto_batch_size_() + return result + + @classmethod + def from_dataclass( + cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + ): + """Converts a dataclass into a TensorDict instance. + + Args: + dataclass: The dataclass instance to be converted. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the + resulting TensorDict. Defaults to ``False``. + as_tensorclass (bool, optional): If ``True``, delegates the conversion to the free function + :func:`~tensordict.from_dataclass` and returns a tensor-compatible class + (:func:`~tensordict.tensorclass`) or instance instead of a ``TensorDict``. Defaults to ``False``. + + Returns: + A TensorDict instance derived from the provided dataclass, unless `as_tensorclass` is True, in which case a tensor-compatible class or instance is returned. + + Raises: + TypeError: If the provided input is not a dataclass instance. + + .. warning:: This method is distinct from the free function `from_dataclass` and serves a different purpose. + While the free function returns a tensor-compatible class or instance, this method returns a TensorDict instance. + + .. notes:: + + - This method creates a new TensorDict instance with keys corresponding to the fields of the input dataclass. + - Each key in the resulting TensorDict is initialized using the `cls.from_any` method. + - The `auto_batch_size` option allows for automatic batch size determination and application to the + resulting TensorDict. + + """ + if as_tensorclass: + from tensordict.tensorclass import from_dataclass + + return from_dataclass(dataclass, auto_batch_size=auto_batch_size) + from dataclasses import fields + + from tensordict import TensorDict + + if not is_dataclass(dataclass): + raise TypeError( + f"Expected a dataclass input, got a {type(dataclass)} input instead." + ) + source = {} + for field in fields(dataclass): + source[field.name] = cls.from_any(getattr(dataclass, field.name)) + result = TensorDict(source) + if auto_batch_size: + result.auto_batch_size_() + return result + @classmethod def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): """Converts a namedtuple to a TensorDict recursively. @@ -9885,8 +10003,7 @@ def namedtuple_to_dict(namedtuple_obj): "indices": namedtuple_obj.indices, } for key, value in namedtuple_obj.items(): - if is_namedtuple(value): - namedtuple_obj[key] = namedtuple_to_dict(value) + namedtuple_obj[key] = cls.from_any(value) return dict(namedtuple_obj) result = TensorDict(namedtuple_to_dict(named_tuple)) diff --git a/tensordict/functional.py b/tensordict/functional.py index a40095141..2699f36bb 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -437,6 +437,7 @@ def make_tensordict( input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, + auto_batch_size: bool | None = None, **kwargs: CompatibleType, # source ) -> TensorDict: """Returns a TensorDict created from the keyword arguments or an input dictionary. @@ -453,6 +454,8 @@ def make_tensordict( (incompatible with nested keys). batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. Examples: >>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} @@ -500,4 +503,6 @@ def make_tensordict( """ if input_dict is not None: kwargs.update(input_dict) - return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device) + return TensorDict.from_dict( + kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size + ) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0b55d1cef..ffedba9ad 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -297,9 +297,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: f"The key {expected_key} wasn't found in the keyword arguments " f"but is expected to execute that function." ) + batch_size = torch.Size([]) if not self.auto_batch_size else None tensordict = make_tensordict( tensordict_values, - batch_size=torch.Size([]) if not self.auto_batch_size else None, + batch_size=batch_size, + auto_batch_size=False, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 00d984330..bc07b7689 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -928,7 +928,13 @@ def _exclude( @_carry_over def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, ): ... @_carry_over diff --git a/tensordict/persistent.py b/tensordict/persistent.py index d5f59110a..332023587 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -207,12 +207,25 @@ def from_h5(cls, filename, mode="r"): return out @classmethod - def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs): + def from_dict( + cls, + input_dict, + filename, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + **kwargs, + ): """Converts a dictionary or a TensorDict to a h5 file. Args: input_dict (dict, TensorDict or compatible): data to be stored as h5. filename (str or path): path to the h5 file. + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. batch_size (tensordict batch-size, optional): if provided, batch size of the tensordict. If not, the batch size will be gathered from the input structure (if present) or determined automatically. @@ -225,6 +238,19 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs) A :class:`PersitentTensorDict` instance linked to the newly created file. """ + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + warnings.warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead." + ) + if len(others) == 2: + batch_size, device = others + else: + batch_size = others[0] + import h5py file = h5py.File(filename, "w", locking=cls.LOCKING) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2556729e5..e1c8e77b4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -27,6 +27,7 @@ from textwrap import indent from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar +from warnings import warn import numpy as np import orjson as json @@ -45,6 +46,7 @@ CompatibleType, ) from tensordict.utils import ( # @manual=//pytorch/tensordict:_C + _is_dataclass as is_dataclass, _is_json_serializable, _is_tensorclass, _LOCK_ERROR, @@ -237,6 +239,12 @@ def __subclasscheck__(self, subclass): "floor_", "frac", "frac_", + "from_any", + "from_dataclass", + "to_namedtuple", + "from_namedtuple", + "from_pytree", + "to_pytree", "gather", "isfinite", "isnan", @@ -379,6 +387,100 @@ def __call__(self, cls: T) -> T: return clz +def from_dataclass( + obj: Any, + *, + auto_batch_size: bool = False, + frozen: bool = False, + autocast: bool = False, + nocast: bool = False, +) -> Any: + """Converts a dataclass instance or a type into a tensorclass instance or type, respectively. + + This function takes a dataclass instance or a dataclass type and converts it into a tensor-compatible class, + optionally applying various configurations such as auto-batching, immutability, and type casting. + + Args: + obj (Any): The dataclass instance or type to be converted. If a type is provided, a new class is returned. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the resulting object. Defaults to ``False``. + frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``. + autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``. + nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``. + + Returns: + A tensor-compatible class or instance derived from the provided dataclass. + + Raises: + TypeError: If the provided input is not a dataclass instance or type. + + Examples: + >>> from dataclasses import dataclass + >>> import torch + >>> from tensordict.tensorclass import from_dataclass + >>> + >>> @dataclass + >>> class X: + ... a: int + ... b: torch.Tensor + ... + >>> x = X(0, 0) + >>> x2 = from_dataclass(x) + >>> print(x2) + X( + a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> X2 = from_dataclass(X, autocast=True) + >>> print(X2(a=0, b=0)) + X( + a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), + b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([]), + device=None, + is_shared=False) + + .. notes:: If a dataclass type is provided, a new class is returned with the specified configurations. + If a dataclass instance is provided, a new instance of the tensor-compatible class is returned. + The `auto_batch_size`, `frozen`, `autocast`, and `nocast` options allow for flexible configuration of the resulting class or instance. + + .. warning:: Whereas :meth:`~tensordict.TensorDict.from_dataclass` will return a :class:`~tensordict.TensorDict` instance + by default, this method will return a tensorclass instance or type. + + """ + from dataclasses import asdict, make_dataclass + + if isinstance(obj, type): + if is_tensorclass(obj): + return obj + cls = make_dataclass( + obj.__name__ + "_tc", fields=obj.__dataclass_fields__, bases=obj.__bases__ + ) + clz = _tensorclass(cls, frozen=frozen) + clz._type_hints = get_type_hints(obj) + clz._autocast = autocast + clz._nocast = nocast + clz._frozen = frozen + return clz + + if not is_dataclass(obj): + raise TypeError(f"Expected a obj input, got a {type(obj)} input instead.") + name = obj.__class__.__name__ + "_tc" + clz = _tensorclass( + make_dataclass(name, fields=obj.__dataclass_fields__), frozen=frozen + ) + clz._autocast = autocast + clz._nocast = nocast + clz._frozen = frozen + result = clz(**asdict(obj)) + if auto_batch_size: + result = result.auto_batch_size_() + return result + + @dataclass_transform() def tensorclass( cls: T = None, @@ -532,6 +634,8 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) + # Breaks some tests, don't do that: + # if not dataclasses.is_dataclass(cls): cls = dataclass(cls, frozen=frozen) _TENSORCLASS_MEMO[cls] = True @@ -1266,7 +1370,7 @@ def _update( non_blocking: bool = False, ): if isinstance(input_dict_or_td, dict): - input_dict_or_td = self.from_dict(input_dict_or_td) + input_dict_or_td = self.from_dict(input_dict_or_td, auto_batch_size=False) if is_tensorclass(input_dict_or_td): non_tensordict = { @@ -1478,7 +1582,15 @@ def _to_dict(self, *, retain_none: bool = True) -> dict: return td_dict -def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): +def _from_dict( + cls, + input_dict, + *, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, +): # we pass through a tensordict because keys could be passed as NestedKeys # We can't assume all keys are strings, otherwise calling cls(**kwargs) # would work ok @@ -1492,7 +1604,11 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): non_tensordict=input_dict, ) td = TensorDict.from_dict( - input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims + input_dict, + batch_size=batch_size, + device=device, + batch_dims=batch_dims, + auto_batch_size=auto_batch_size, ) non_tensordict = {} @@ -1500,7 +1616,13 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): def _from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, ): if batch_dims is not None and batch_size is not None: raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.") @@ -1510,7 +1632,7 @@ def _from_dict_instance( # TODO: this is a bit slow and will be a bottleneck every time td[idx] = dict(subtd) # is called when there are non tensor data in it if not _is_tensor_collection(type(input_dict)): - input_tdict = TensorDict.from_dict(input_dict) + input_tdict = TensorDict.from_dict(input_dict, auto_batch_size=auto_batch_size) else: input_tdict = input_dict trsf_dict = {} @@ -1538,7 +1660,19 @@ def _from_dict_instance( ) # check that if batch_size is None: - out._tensordict.auto_batch_size_() + if auto_batch_size is None and batch_dims is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True + if auto_batch_size: + out.auto_batch_size_() return out @@ -1658,7 +1792,7 @@ def _is_castable(datatype): if isinstance(value, dict): if _is_tensor_collection(target_cls): - cast_val = target_cls.from_dict(value) + cast_val = target_cls.from_dict(value, auto_batch_size=False) self._tensordict.set( key, cast_val, inplace=inplace, non_blocking=non_blocking ) @@ -2483,7 +2617,6 @@ def __post_init__(self): data_inner = data.tolist() del _tensordict["data"] _non_tensordict["data"] = data_inner - # assert _tensordict.is_empty(), self._tensordict # TODO: this will probably fail with dynamo at some point, + it's terrible. # Make sure it's patched properly at init time diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 75678b4b6..a77ef185a 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -209,9 +209,16 @@ class TensorClass: def auto_batch_size_(self, batch_dims: int | None = None) -> T: ... def auto_device_(self) -> T: ... @classmethod + def from_dataclass( + cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + ): ... + @classmethod + def from_any(cls, obj, *, auto_batch_size: bool = False): ... + @classmethod def from_dict( cls, input_dict, + *, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, diff --git a/tensordict/utils.py b/tensordict/utils.py index cdc0756f8..81ab2fa0c 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -20,6 +20,7 @@ from collections import defaultdict from collections.abc import KeysView from copy import copy +from dataclasses import _FIELDS, GenericAlias from functools import wraps from importlib import import_module from numbers import Number @@ -858,7 +859,7 @@ def is_tensorclass(obj: type | Any) -> bool: def _is_tensorclass(cls: type) -> bool: - out = _TENSORCLASS_MEMO.get(cls, None) + out = _TENSORCLASS_MEMO.get(cls) if out is None: out = getattr(cls, "_is_tensorclass", False) if not is_dynamo_compiling(): @@ -2813,3 +2814,13 @@ def _mismatch_keys(keys1, keys2): if sub2 is not None: main.append(sub2) raise KeyError(r" ".join(main)) + + +def _is_dataclass(obj): + """Like dataclasses.is_dataclass but compatible with compile.""" + cls = ( + obj + if isinstance(obj, type) and not isinstance(obj, GenericAlias) + else type(obj) + ) + return hasattr(cls, _FIELDS) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8879f0e68..ad1a194cd 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -53,7 +53,8 @@ class TestTensorDictsBase: TYPES_DEVICES = [] TYPES_DEVICES_NOLAZY = [] - def td(self, device): + @classmethod + def td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -68,7 +69,8 @@ def td(self, device): TYPES_DEVICES += [["td", device]] TYPES_DEVICES_NOLAZY += [["td", device]] - def nested_td(self, device): + @classmethod + def nested_td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -86,7 +88,8 @@ def nested_td(self, device): TYPES_DEVICES += [["nested_td", device]] TYPES_DEVICES_NOLAZY += [["nested_td", device]] - def nested_tensorclass(self, device): + @classmethod + def nested_tensorclass(cls, device): nested_class = MyClass( X=torch.randn(4, 3, 2, 1), @@ -119,8 +122,9 @@ def nested_tensorclass(self, device): TYPES_DEVICES += [["nested_tensorclass", device]] TYPES_DEVICES_NOLAZY += [["nested_tensorclass", device]] + @classmethod @set_lazy_legacy(True) - def nested_stacked_td(self, device): + def nested_stacked_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -140,8 +144,9 @@ def nested_stacked_td(self, device): TYPES_DEVICES += [["nested_stacked_td", device]] TYPES_DEVICES_NOLAZY += [["nested_stacked_td", device]] + @classmethod @set_lazy_legacy(True) - def stacked_td(self, device): + def stacked_td(cls, device): td1 = TensorDict( source={ "a": torch.randn(4, 3, 1, 5), @@ -165,7 +170,8 @@ def stacked_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["stacked_td", device]] - def idx_td(self, device): + @classmethod + def idx_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -180,7 +186,8 @@ def idx_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["idx_td", device]] - def sub_td(self, device): + @classmethod + def sub_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -195,7 +202,8 @@ def sub_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["sub_td", device]] - def sub_td2(self, device): + @classmethod + def sub_td2(cls, device): td = TensorDict( source={ "a": torch.randn(4, 2, 3, 2, 1, 5), @@ -212,17 +220,19 @@ def sub_td2(self, device): temp_path_memmap = tempfile.TemporaryDirectory() - def memmap_td(self, device): - path = pathlib.Path(self.temp_path_memmap.name) + @classmethod + def memmap_td(cls, device): + path = pathlib.Path(cls.temp_path_memmap.name) shutil.rmtree(path) path.mkdir() - return self.td(device).memmap_(path) + return cls.td(device).memmap_(path) TYPES_DEVICES += [["memmap_td", torch.device("cpu")]] TYPES_DEVICES_NOLAZY += [["memmap_td", torch.device("cpu")]] + @classmethod @set_lazy_legacy(True) - def permute_td(self, device): + def permute_td(cls, device): return TensorDict( source={ "a": torch.randn(3, 1, 4, 2, 5), @@ -236,8 +246,9 @@ def permute_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["permute_td", device]] + @classmethod @set_lazy_legacy(True) - def unsqueezed_td(self, device): + def unsqueezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 5), @@ -252,8 +263,9 @@ def unsqueezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["unsqueezed_td", device]] + @classmethod @set_lazy_legacy(True) - def squeezed_td(self, device): + def squeezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 1, 2, 1, 5), @@ -268,7 +280,8 @@ def squeezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["squeezed_td", device]] - def td_reset_bs(self, device): + @classmethod + def td_reset_bs(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -285,13 +298,14 @@ def td_reset_bs(self, device): TYPES_DEVICES += [["td_reset_bs", device]] TYPES_DEVICES_NOLAZY += [["td_reset_bs", device]] + @classmethod def td_h5( - self, + cls, device, ): file = tempfile.NamedTemporaryFile() filename = file.name - nested_td = self.nested_td(device) + nested_td = cls.nested_td(device) td_h5 = PersistentTensorDict.from_dict( nested_td, filename=filename, device=device ) @@ -303,15 +317,17 @@ def td_h5( TYPES_DEVICES += [["td_h5", device]] TYPES_DEVICES_NOLAZY += [["td_h5", device]] - def td_params(self, device): - return TensorDictParams(self.td(device)) + @classmethod + def td_params(cls, device): + return TensorDictParams(cls.td(device)) for device in get_available_devices(): TYPES_DEVICES += [["td_params", device]] TYPES_DEVICES_NOLAZY += [["td_params", device]] - def td_with_non_tensor(self, device): - td = self.td(device) + @classmethod + def td_with_non_tensor(cls, device): + td = cls.td(device) return td.set_non_tensor( ("data", "non_tensor"), # this is allowed since nested NonTensorData are automatically unwrapped diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 4753c3704..0f71bd743 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -23,6 +23,7 @@ import tensordict.utils import torch from tensordict import TensorClass +from tensordict.tensorclass import from_dataclass try: import torchsnapshot @@ -94,6 +95,21 @@ class MyData2: z: list +@dataclasses.dataclass +class MyDataClass: + a: int + b: torch.Tensor + c: str + + +try: + MyTensorClass_autocast = from_dataclass(MyDataClass, autocast=True) + MyTensorClass_nocast = from_dataclass(MyDataClass, nocast=True) + MyTensorClass = from_dataclass(MyDataClass) +except Exception: + MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None + + class TestTensorClass: def test_all_any(self): @tensorclass @@ -517,6 +533,43 @@ class MyClass2: assert (a != c.clone().zero_()).any() assert (c != a.clone().zero_()).any() + def test_from_dataclass(self): + assert is_tensorclass(MyTensorClass_autocast) + assert MyTensorClass_nocast is not MyDataClass + assert MyTensorClass_autocast._autocast + x = MyTensorClass_autocast(a=0, b=0, c=0) + assert isinstance(x.a, int) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, str) + + assert is_tensorclass(MyTensorClass_nocast) + assert MyTensorClass_nocast is not MyTensorClass_autocast + assert MyTensorClass_nocast._nocast + + x = MyTensorClass_nocast(a=0, b=0, c=0) + assert is_tensorclass(MyTensorClass) + assert not MyTensorClass._autocast + assert not MyTensorClass._nocast + assert isinstance(x.a, int) + assert isinstance(x.b, int) + assert isinstance(x.c, int) + + x = MyTensorClass(a=0, b=0, c=0) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + + x = TensorDict.from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert isinstance(x, TensorDict) + assert isinstance(x["a"], torch.Tensor) + assert isinstance(x["b"], torch.Tensor) + assert isinstance(x["c"], torch.Tensor) + x = from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert is_tensorclass(x) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + def test_from_dict(self): td = TensorDict( { @@ -531,7 +584,7 @@ def test_from_dict(self): class MyClass: a: TensorDictBase - tc = MyClass.from_dict(d) + tc = MyClass.from_dict(d, auto_batch_size=True) assert isinstance(tc, MyClass) assert isinstance(tc.a, TensorDict) assert tc.batch_size == torch.Size([10]) @@ -2095,7 +2148,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3] ) - assert (test_class == TestClass.from_dict(test_class.to_dict())).all() + assert ( + test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True) + ).all() # Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such # test_class2 = TestClass( @@ -2108,7 +2163,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3] ) - assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all() + assert not ( + test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True) + ).all() @tensorclass(autocast=True) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 257c2b712..73d401c03 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -20,6 +20,7 @@ import warnings from dataclasses import dataclass from pathlib import Path +from typing import Any import numpy as np import pytest @@ -949,6 +950,64 @@ def test_fromkeys(self): td = TensorDict.fromkeys({"a", "b", "c"}, 1) assert td["a"] == 1 + def test_from_any(self): + from dataclasses import dataclass + + @dataclass + class MyClass: + a: int + + pytree = ( + [torch.randint(10, (3,)), torch.zeros(2)], + { + "tensor": torch.randn( + 2, + ), + "td": TensorDict({"one": 1}), + "tuple": (1, 2, 3), + }, + {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, + {"dataclass": MyClass(a=0)}, + ) + if _has_h5py: + pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) + td = TensorDict.from_any(pytree) + expected = { + ("0", "0"), + ("0", "1"), + ("1", "td", "one"), + ("1", "tensor"), + ("1", "tuple", "0"), + ("1", "tuple", "1"), + ("1", "tuple", "2"), + ("2", "named_tuple", "two"), + ("3", "dataclass", "a"), + } + if _has_h5py: + expected = expected.union( + { + ("4", "h5py", "a"), + ("4", "h5py", "b"), + ("4", "h5py", "c"), + ("4", "h5py", "my_nested_td", "inner"), + } + ) + assert set(td.keys(True, True)) == expected, set( + td.keys(True, True) + ).symmetric_difference(expected) + + def test_from_dataclass(self): + @dataclass + class MyClass: + a: int + b: Any + + obj = MyClass(a=0, b=1) + obj_td = TensorDict.from_dataclass(obj) + obj_tc = TensorDict.from_dataclass(obj, as_tensorclass=True) + assert is_tensorclass(obj_tc) + assert not is_tensorclass(obj_td) + @pytest.mark.parametrize("batch_size", [None, [3, 4]]) @pytest.mark.parametrize("batch_dims", [None, 1, 2]) @pytest.mark.parametrize("device", get_available_devices()) @@ -967,7 +1026,11 @@ def test_from_dict(self, batch_size, batch_dims, device): ) return data = TensorDict.from_dict( - data, batch_size=batch_size, batch_dims=batch_dims, device=device + data, + batch_size=batch_size, + batch_dims=batch_dims, + device=device, + auto_batch_size=True, ) assert data.device == device assert "a" in data.keys() @@ -1001,7 +1064,7 @@ class MyClass: assert isinstance(td_dict["b"]["y"], int) assert isinstance(td_dict["b"]["z"], dict) assert isinstance(td_dict["b"]["z"]["y"], int) - td_recon = td.from_dict_instance(td_dict) + td_recon = td.from_dict_instance(td_dict, auto_batch_size=True) assert isinstance(td_recon["a"], torch.Tensor) assert isinstance(td_recon["b"], MyClass) assert isinstance(td_recon["b"].x, torch.Tensor) @@ -6443,7 +6506,7 @@ def recursive_checker(cur_dict): assert recursive_checker(td_dict) if td_name == "td_with_non_tensor": assert td_dict["data"]["non_tensor"] == "some text data" - assert (TensorDict.from_dict(td_dict) == td).all() + assert (TensorDict.from_dict(td_dict, auto_batch_size=False) == td).all() def test_to_namedtuple(self, td_name, device): def is_namedtuple(obj): @@ -7771,7 +7834,7 @@ def test_mp(self, td_type, unbind_as): class TestMakeTensorDict: def test_create_tensordict(self): - tensordict = make_tensordict(a=torch.zeros(3, 4)) + tensordict = make_tensordict(a=torch.zeros(3, 4), auto_batch_size=True) assert (tensordict["a"] == torch.zeros(3, 4)).all() def test_nested(self): @@ -7779,7 +7842,7 @@ def test_nested(self): "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)}, "d": torch.randn(3), } - tensordict = make_tensordict(input_dict) + tensordict = make_tensordict(input_dict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) input_tensordict = TensorDict( @@ -7789,7 +7852,7 @@ def test_nested(self): }, [], ) - tensordict = make_tensordict(input_tensordict) + tensordict = make_tensordict(input_tensordict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) input_dict = { @@ -7797,30 +7860,40 @@ def test_nested(self): ("a", "c"): torch.randn(3, 4, 5), "d": torch.randn(3), } - tensordict = make_tensordict(input_dict) + tensordict = make_tensordict(input_dict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) def test_tensordict_batch_size(self): - tensordict = make_tensordict() + tensordict = make_tensordict(auto_batch_size=True) assert tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4)) + tensordict = make_tensordict(a=torch.randn(3, 4), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5)) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5)) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(3, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([3]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5)) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(4, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5)) + tensordict = make_tensordict( + a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1)) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([]) tensordict = make_tensordict( @@ -7836,7 +7909,10 @@ def test_tensordict_batch_size(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_device(self, device): tensordict = make_tensordict( - a=torch.randn(3, 4), b=torch.randn(3, 4), device=device + a=torch.randn(3, 4), + b=torch.randn(3, 4), + device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device @@ -7847,6 +7923,7 @@ def test_tensordict_device(self, device): b=torch.randn(3, 4), c=torch.randn(3, 4, device="cpu"), device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device @@ -10584,7 +10661,8 @@ def test_non_tensor_call(self): def test_nontensor_dict(self, non_tensor_data): assert ( - TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data + TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True) + == non_tensor_data ).all() def test_nontensor_tensor(self): @@ -11125,7 +11203,7 @@ def _to_float(td, td_name, tmpdir): td._source = td._source.float() elif td_name in ("td_h5",): td = PersistentTensorDict.from_dict( - td.float().to_dict(), filename=tmpdir + "/file.t" + td.float().to_dict(), filename=tmpdir + "/file.t", auto_batch_size=True ) elif td_name in ("td_params",): td = TensorDictParams(td.data.float()) From df61d64a34ab6b3ccb6b3ed2b876e1ef827d2f72 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 20:23:28 +0100 Subject: [PATCH 2/8] [BugFix] select_out_keys for Prob sequential ghstack-source-id: a566ae225c54f07a680b4bf380b16d8e797f62ea Pull Request resolved: https://github.com/pytorch/tensordict/pull/1103 --- tensordict/nn/probabilistic.py | 29 +++++++++++++------ tensordict/nn/sequence.py | 6 ++-- test/test_nn.py | 53 ++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 12 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 3b254731d..e13c43f6d 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -283,7 +283,6 @@ def __init__( in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey], out_keys: NestedKey | List[NestedKey] | None = None, *, - default_interaction_mode: str | None = None, default_interaction_type: InteractionType = InteractionType.DETERMINISTIC, distribution_class: type = Delta, distribution_kwargs: dict | None = None, @@ -332,11 +331,7 @@ def __init__( log_prob_key = "sample_log_prob" self.log_prob_key = log_prob_key - if default_interaction_mode is not None: - raise ValueError( - "default_interaction_mode is deprecated, use default_interaction_type instead." - ) - self.default_interaction_type = default_interaction_type + self.default_interaction_type = InteractionType(default_interaction_type) if isinstance(distribution_class, str): distribution_class = distributions_maps.get(distribution_class.lower()) @@ -356,7 +351,7 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution: for dist_key, td_key in _zip_strict(self.dist_keys, self.in_keys): if isinstance(dist_key, tuple): dist_key = dist_key[-1] - dist_kwargs[dist_key] = tensordict.get(td_key) + dist_kwargs[dist_key] = tensordict.get(td_key, None) dist = self.distribution_class( **dist_kwargs, **_dynamo_friendly_to_dict(self.distribution_kwargs) ) @@ -630,8 +625,24 @@ def forward( tensordict_out: TensorDictBase | None = None, **kwargs, ) -> TensorDictBase: - tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs) - return self.module[-1](tensordict_out, _requires_sample=self._requires_sample) + if (tensordict_out is None and self._select_before_return) or ( + tensordict_out is not None + ): + tensordict_exec = tensordict.copy() + else: + tensordict_exec = tensordict + tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs) + tensordict_exec = self.module[-1]( + tensordict_exec, _requires_sample=self._requires_sample + ) + if tensordict_out is not None: + result = tensordict_out + result.update(tensordict_exec, keys_to_update=self.out_keys) + else: + result = tensordict_exec + if self._select_before_return: + return tensordict.update(result, keys_to_update=self.out_keys) + return result def _dynamo_friendly_to_dict(data): diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index d33577975..5f8c84bb9 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -470,13 +470,13 @@ def forward( tensordict_out: TensorDictBase | None = None, **kwargs: Any, ) -> TensorDictBase: - if tensordict_out is None and self._select_before_return: + if (tensordict_out is None and self._select_before_return) or ( + tensordict_out is not None + ): tensordict_exec = tensordict.copy() else: tensordict_exec = tensordict if not len(kwargs): - if tensordict_out is not None: - tensordict_exec = tensordict_exec.copy() for module in self.module: tensordict_exec = self._run_module(module, tensordict_exec, **kwargs) else: diff --git a/test/test_nn.py b/test/test_nn.py index 274d08b29..8c78dd3b2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -890,6 +890,59 @@ def test_stateful_probabilistic_deprec(self, lazy): dist = tdmodule.get_dist(td) assert dist.rsample().shape[: td.ndimension()] == td.shape + @pytest.mark.parametrize("return_log_prob", [True, False]) + @pytest.mark.parametrize("td_out", [True, False]) + def test_probtdseq(self, return_log_prob, td_out): + mod = ProbabilisticTensorDictSequential( + TensorDictModule(lambda x: x + 2, in_keys=["a"], out_keys=["c"]), + TensorDictModule(lambda x: (x + 2, x), in_keys=["b"], out_keys=["d", "e"]), + ProbabilisticTensorDictModule( + in_keys={"loc": "d", "scale": "e"}, + out_keys=["f"], + distribution_class=Normal, + return_log_prob=return_log_prob, + default_interaction_type="random", + ), + ) + inp = TensorDict({"a": 0.0, "b": 1.0}) + inp_clone = inp.clone() + if td_out: + out = TensorDict() + else: + out = None + out2 = mod(inp, tensordict_out=out) + assert not mod._select_before_return + if td_out: + assert out is out2 + else: + assert out2 is inp + assert set(out2.keys()) - {"a", "b"} == set(mod.out_keys), ( + td_out, + return_log_prob, + ) + + inp = inp_clone.clone() + mod.select_out_keys("f") + if td_out: + out = TensorDict() + else: + out = None + out2 = mod(inp, tensordict_out=out) + assert mod._select_before_return + if td_out: + assert out is out2 + else: + assert out2 is inp + expected = {"f"} + if td_out: + assert set(out2.keys()) == set(mod.out_keys) == expected + else: + assert ( + set(out2.keys()) - set(inp_clone.keys()) + == set(mod.out_keys) + == expected + ) + @pytest.mark.parametrize("lazy", [True, False]) def test_stateful_probabilistic(self, lazy): torch.manual_seed(0) From d5fcace5326891641a8145499ee5068141c15f03 Mon Sep 17 00:00:00 2001 From: Emmanuel Ferdman Date: Sat, 23 Nov 2024 22:53:15 +0200 Subject: [PATCH 3/8] [Doc] Update `distributed_replay_buffer.py` reference (#1105) Signed-off-by: Emmanuel Ferdman --- docs/source/distributed.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 832b5d800..2eaa971e3 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -131,7 +131,7 @@ Although the call to :obj:`rpc.rpc_sync` involved passing the entire tensordict, updating specific indices of this object and return it to the original worker, the execution of this snippet is extremely fast (even more so if the reference to the memory location is already passed beforehand, see `torchrl's distributed -replay buffer documentation `_ to learn more). +replay buffer documentation `_ to learn more). The script contains additional RPC configuration steps that are beyond the purpose of this document. From 3485c2c890ea4672a7775ff0cdc5c87ee41729c4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 08:51:03 +0000 Subject: [PATCH 4/8] [Feature] from_any with UserDict ghstack-source-id: 420464209cff29c3a1c58ec521fbf4ed69d1355f Pull Request resolved: https://github.com/pytorch/tensordict/pull/1106 --- tensordict/base.py | 3 +++ test/test_tensordict.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index 6c600b11f..f3034fd96 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -18,6 +18,7 @@ import uuid import warnings import weakref +from collections import UserDict from collections.abc import MutableMapping from concurrent.futures import Future, ThreadPoolExecutor, wait @@ -9871,6 +9872,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, UserDict): + return cls.from_dict(dict(obj), auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) if isinstance(obj, tuple): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 73d401c03..50fc7170e 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -18,6 +18,7 @@ import sys import uuid import warnings +from collections import UserDict from dataclasses import dataclass from pathlib import Path from typing import Any @@ -996,6 +997,13 @@ class MyClass: td.keys(True, True) ).symmetric_difference(expected) + def test_from_any_userdict(self): + class D(UserDict): ... + + d = D(a=0) + assert TensorDict.from_any(d)["a"] == 0 + assert isinstance(TensorDict.from_any(d)["a"], torch.Tensor) + def test_from_dataclass(self): @dataclass class MyClass: From f924afcd1a82dbf9c1cf6cde2147c9a0a5f42e55 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 08:51:03 +0000 Subject: [PATCH 5/8] [Feature] NonTensorStack.from_list ghstack-source-id: e8f349cb06a72dcb69a639420b14406c9c08aa99 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1107 --- tensordict/tensorclass.py | 13 +++++++++++++ test/test_tensordict.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index e1c8e77b4..728fc4055 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -194,6 +194,7 @@ def __subclasscheck__(self, subclass): "any", "apply", "apply_", + "as_tensor", "asin", "asin_", "atan", @@ -3114,6 +3115,18 @@ def maybe_to_stack(self): stack_dim=self.stack_dim, ) + @classmethod + def from_list(cls, non_tensors: List[Any]): + # Use local function because refers to cls + def _maybe_from_list(nontensor): + if isinstance(nontensor, list): + return cls.from_list(nontensor) + if is_non_tensor(nontensor): + return nontensor + return NonTensorData(nontensor) + + return cls(*[_maybe_from_list(nontensor) for nontensor in non_tensors]) + @classmethod def from_nontensordata(cls, non_tensor: NonTensorData): data = non_tensor.data diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 50fc7170e..737ff4f24 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -142,8 +142,8 @@ def device_fixture(): device = torch.get_default_device() if torch.cuda.is_available(): torch.set_default_device(torch.device("cuda:0")) - elif torch.backends.mps.is_available(): - torch.set_default_device(torch.device("mps:0")) + # elif torch.backends.mps.is_available(): + # torch.set_default_device(torch.device("mps:0")) yield torch.set_default_device(device) @@ -1468,8 +1468,8 @@ def check_meta(tensor): if torch.cuda.is_available(): device = "cuda:0" - elif torch.backends.mps.is_available(): - device = "mps:0" + # elif torch.backends.mps.is_available(): + # device = "mps:0" else: pytest.skip("no device to test") device_state_dict = TensorDict.load(tmpdir, device=device) @@ -1717,8 +1717,8 @@ def test_no_batch_size(self): def test_non_blocking(self): if torch.cuda.is_available(): device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" + # elif torch.backends.mps.is_available(): + # device = "mps" else: pytest.skip("No device found") for _ in range(10): @@ -1792,9 +1792,9 @@ def test_non_blocking_single_sync(self, _path_td_sync): TensorDict(td_dict, device="cpu") assert _SYNC_COUNTER == 0 - if torch.backends.mps.is_available(): - device = "mps" - elif torch.cuda.is_available(): + # if torch.backends.mps.is_available(): + # device = "mps" + if torch.cuda.is_available(): device = "cuda" else: device = None @@ -9857,7 +9857,8 @@ def check_weakref_count(weakref_list, expected): assert count == expected, {id(ref()) for ref in weakref_list} @pytest.mark.skipif( - not torch.cuda.is_available() and not torch.backends.mps.is_available(), + not torch.cuda.is_available(), + # and not torch.backends.mps.is_available(), reason="a device is required.", ) def test_cached_data_lock_device(self): @@ -10659,6 +10660,19 @@ def test_comparison(self, non_tensor_data): ("nested", "bool") ) + def test_from_list(self): + nd = NonTensorStack.from_list( + [[True, "b", torch.randn(())], ["another", 0, NonTensorData("final")]] + ) + assert isinstance(nd, NonTensorStack) + assert nd.shape == (2, 3) + assert nd[0, 0].data + assert nd[0, 1].data == "b" + assert isinstance(nd[0, 2].data, torch.Tensor) + assert nd[1, 0].data == "another" + assert nd[1, 1].data == 0 + assert nd[1, 2].data == "final" + def test_non_tensor_call(self): td0 = TensorDict({"a": 0, "b": 0}) td1 = TensorDict({"a": 1, "b": 1}) From 1ffc463728de060ac0ee0b3eefdec8bf3be12774 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 08:51:04 +0000 Subject: [PATCH 6/8] [Feature] Better list casting in TensorDict.from_any ghstack-source-id: 427d19d5ef7c0d2779e064e64522fc0094a885af Pull Request resolved: https://github.com/pytorch/tensordict/pull/1108 --- tensordict/base.py | 10 +++++++++- tensordict/nn/common.py | 2 +- tensordict/utils.py | 33 +++++++++++++++++++++++++++++++++ test/test_tensordict.py | 7 +++++++ 4 files changed, 50 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index f3034fd96..159777e28 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -57,6 +57,7 @@ _DTYPE2STRDTYPE, _GENERIC_NESTED_ERR, _is_dataclass as is_dataclass, + _is_list_tensor_compatible, _is_non_tensor, _is_number, _is_tensorclass, @@ -9869,6 +9870,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): """ if is_tensor_collection(obj): + if is_non_tensor(obj): + return cls.from_any(obj.data, auto_batch_size=auto_batch_size) return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) @@ -9881,7 +9884,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) return cls.from_tuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, list): - return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) + if _is_list_tensor_compatible(obj)[0]: + return torch.tensor(obj) + else: + from tensordict.tensorclass import NonTensorStack + + return NonTensorStack.from_list(obj) if is_dataclass(obj): return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) if _has_h5: diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index ffedba9ad..c38bd31de 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1003,7 +1003,7 @@ def _write_to_tensordict( tensordict_out = tensordict for _out_key, _tensor in zip(out_keys, tensors): if _out_key != "_": - tensordict_out.set(_out_key, _tensor) + tensordict_out.set(_out_key, TensorDict.from_any(_tensor)) return tensordict_out def _call_module( diff --git a/tensordict/utils.py b/tensordict/utils.py index 81ab2fa0c..2344da517 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2824,3 +2824,36 @@ def _is_dataclass(obj): else type(obj) ) return hasattr(cls, _FIELDS) + + +def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]: + length_t = len(t) + dtypes = set() + sizes = set() + for i in t: + if isinstance(i, (float, int, torch.SymInt, Number)): + dtypes.add(type(i)) + if len(dtypes) > 1: + return False, None, None + continue + elif isinstance(i, list): + is_compat, size_i, dtype = _is_list_tensor_compatible(i) + if not is_compat: + return False, None, None + if dtype is not None: + dtypes.add(dtype) + if len(dtypes) > 1: + return False, None, None + sizes.add(size_i) + if len(sizes) > 1: + return False, None, None + continue + return False, None + else: + if len(dtypes): + dtype = list(dtypes)[0] + else: + dtype = None + if len(sizes): + return True, (length_t, *list(sizes)[0]), dtype + return True, (length_t,), dtype diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 737ff4f24..27b53b686 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -997,6 +997,13 @@ class MyClass: td.keys(True, True) ).symmetric_difference(expected) + def test_from_any_list(self): + t = torch.randn(3, 4, 5) + t = t.tolist() + assert isinstance(TensorDict.from_any(t), torch.Tensor) + t[0].extend([0, 2]) + assert isinstance(TensorDict.from_any(t), TensorDict) + def test_from_any_userdict(self): class D(UserDict): ... From 2728dbf9863d3a010addaf5e5bd6bf44a7f96013 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 08:51:05 +0000 Subject: [PATCH 7/8] [BugFix] auto-batch-size in dipatch ghstack-source-id: ca5b36195c28da65a20d42699346fbc06083181c Pull Request resolved: https://github.com/pytorch/tensordict/pull/1109 --- tensordict/nn/common.py | 2 +- test/test_nn.py | 19 +++++++++++++++++++ test/test_tensordict.py | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index c38bd31de..e6d150528 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -301,7 +301,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: tensordict = make_tensordict( tensordict_values, batch_size=batch_size, - auto_batch_size=False, + auto_batch_size=self.auto_batch_size, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/test/test_nn.py b/test/test_nn.py index 8c78dd3b2..3134932e9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -559,6 +559,25 @@ def forward(self, tensordict): with pytest.raises(RuntimeError, match="Duplicated argument"): module(torch.zeros(1, 2), a_c=torch.ones(1, 2)) + @pytest.mark.parametrize("auto_batch_size", [True, False]) + def test_dispatch_auto_batch_size(self, auto_batch_size): + class MyModuleNest(nn.Module): + in_keys = [("a", "c"), "d"] + out_keys = ["b"] + + @dispatch(auto_batch_size=auto_batch_size) + def forward(self, tensordict): + if auto_batch_size: + assert tensordict.shape == (2, 3) + else: + assert tensordict.shape == () + tensordict["b"] = tensordict["a", "c"] + tensordict["d"] + return tensordict + + module = MyModuleNest() + b = module(torch.zeros(2, 3), d=torch.ones(2, 3)) + assert (b == 1).all() + def test_dispatch_nested_extra_args(self): class MyModuleNest(nn.Module): in_keys = [("a", "c"), "d"] diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 27b53b686..63ebc8935 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -9865,7 +9865,7 @@ def check_weakref_count(weakref_list, expected): @pytest.mark.skipif( not torch.cuda.is_available(), - # and not torch.backends.mps.is_available(), + # and not torch.backends.mps.is_available(), reason="a device is required.", ) def test_cached_data_lock_device(self): From e2444ed97ec5cacac75c253797034099302893df Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 11:32:29 +0000 Subject: [PATCH 8/8] [BugFix] Fix from_any tests ghstack-source-id: 8c3b3d825555c727c7c18c7e8a87311f718a94b6 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1110 --- test/test_tensordict.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 63ebc8935..f7fe9a9ff 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -959,7 +959,7 @@ class MyClass: a: int pytree = ( - [torch.randint(10, (3,)), torch.zeros(2)], + [[-1, 0, 1], [2, 3, 4]], { "tensor": torch.randn( 2, @@ -974,8 +974,7 @@ class MyClass: pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) td = TensorDict.from_any(pytree) expected = { - ("0", "0"), - ("0", "1"), + "0", ("1", "td", "one"), ("1", "tensor"), ("1", "tuple", "0"), @@ -1001,8 +1000,8 @@ def test_from_any_list(self): t = torch.randn(3, 4, 5) t = t.tolist() assert isinstance(TensorDict.from_any(t), torch.Tensor) - t[0].extend([0, 2]) - assert isinstance(TensorDict.from_any(t), TensorDict) + t[0][1].extend([0, 2]) + assert isinstance(TensorDict.from_any(t), NonTensorStack) def test_from_any_userdict(self): class D(UserDict): ...