From 4f2a602f6a969d85ecebf7b9e31ecfb5b067491b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Jan 2024 10:06:29 +0000 Subject: [PATCH 1/3] [Refactor] Make unbind call tensor.unbind (#628) --- tensordict/_lazy.py | 12 +++--------- tensordict/_td.py | 41 ++++++++++++++++++++++++++++------------ tensordict/base.py | 16 +++++++++++++++- tensordict/nn/params.py | 2 +- tensordict/persistent.py | 2 +- 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 8d431a54f..1f0bbf171 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -746,13 +746,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T: stack_dim=stack_dim, ) - def unbind(self, dim: int) -> tuple[TensorDictBase, ...]: - if dim < 0: - dim = self.batch_dims + dim - if dim < 0 or dim >= self.ndim: - raise ValueError( - f"Cannot unbind along dimension {dim} with batch size {self.batch_size}." - ) + def _unbind(self, dim: int) -> tuple[TensorDictBase, ...]: if dim == self.stack_dim: return tuple(self.tensordicts) else: @@ -763,7 +757,7 @@ def unbind(self, dim: int) -> tuple[TensorDictBase, ...]: self.stack_dim if dim > self.stack_dim else self.stack_dim - 1 ) for td in self.tensordicts: - out.append(td.unbind(new_dim)) + out.append(td._unbind(new_dim)) return tuple(self.lazy_stack(vals, new_stack_dim) for vals in zip(*out)) @@ -2869,7 +2863,7 @@ def _unsqueeze(self, dim): all = TensorDict.all any = TensorDict.any expand = TensorDict.expand - unbind = TensorDict.unbind + _unbind = TensorDict._unbind _get_names_idx = TensorDict._get_names_idx diff --git a/tensordict/_td.py b/tensordict/_td.py index 3d62a19cd..1ad770f13 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -827,23 +827,40 @@ def _expand(tensor): _expand, batch_size=shape, call_on_nested=True, names=names ) - def unbind(self, dim: int) -> tuple[T, ...]: - if dim < 0: - dim = self.batch_dims + dim + def _unbind(self, dim: int): batch_size = torch.Size([s for i, s in enumerate(self.batch_size) if i != dim]) names = None if self._has_names(): names = copy(self.names) names = [name for i, name in enumerate(names) if i != dim] - out = [] - # unbind_self_dict = {key: tensor.unbind(dim) for key, tensor in self.items()} - prefix = (slice(None),) * dim + device = self.device + + is_shared = self._is_shared + is_memmap = self._is_memmap + + def empty(): + result = TensorDict( + {}, batch_size=batch_size, names=names, _run_checks=False, device=device + ) + result._is_shared = is_shared + result._is_memmap = is_memmap + return result + + tds = tuple(empty() for _ in range(self.batch_size[dim])) + + def unbind(key, val, tds=tds): + unbound = ( + val.unbind(dim) + if not isinstance(val, TensorDictBase) + # tensorclass is also unbound using plain unbind + else val._unbind(dim) + ) + for td, _val in zip(tds, unbound): + td._set_str(key, _val, validated=True, inplace=False) - for _idx in range(self.batch_size[dim]): - _idx = prefix + (_idx,) - td = self._index_tensordict(_idx, new_batch_size=batch_size, names=names) - out.append(td) - return tuple(out) + for key, val in self.items(): + unbind(key, val) + return tds def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]: # we must use slices to keep the storage of the tensors @@ -2745,7 +2762,7 @@ def _create_nested_str(self, key): reshape = TensorDict.reshape split = TensorDict.split to_module = TensorDict.to_module - unbind = TensorDict.unbind + _unbind = TensorDict._unbind def _view(self, *args, **kwargs): raise RuntimeError( diff 
--git a/tensordict/base.py b/tensordict/base.py index 86384dfa4..a966cbb7f 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -543,7 +543,6 @@ def expand(self, *args: int | torch.Size) -> T: """ ... - @abc.abstractmethod def unbind(self, dim: int) -> tuple[T, ...]: """Returns a tuple of indexed tensordicts, unbound along the indicated dimension. @@ -558,6 +557,21 @@ def unbind(self, dim: int) -> tuple[T, ...]: tensor([4, 5, 6, 7]) """ + batch_dims = self.batch_dims + if dim < -batch_dims or dim >= batch_dims: + raise RuntimeError( + f"the dimension provided ({dim}) is beyond the tensordict dimensions ({self.ndim})." + ) + if dim < 0: + dim = batch_dims + dim + results = self._unbind(dim) + if self._is_memmap or self._is_shared: + for result in results: + result.lock_() + return results + + @abc.abstractmethod + def _unbind(self, dim: int) -> tuple[T, ...]: ... def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]: diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 27983d2c9..2f24a1c8a 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -564,7 +564,7 @@ def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]: ... @_fallback - def unbind(self, dim: int) -> tuple[TensorDictBase, ...]: + def _unbind(self, dim: int) -> tuple[TensorDictBase, ...]: ... @_fallback diff --git a/tensordict/persistent.py b/tensordict/persistent.py index e87d48819..3af928ffa 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -1079,7 +1079,7 @@ def _unsqueeze(self, dim): reshape = TensorDict.reshape split = TensorDict.split to_module = TensorDict.to_module - unbind = TensorDict.unbind + _unbind = TensorDict._unbind _get_names_idx = TensorDict._get_names_idx From 89348f1987bfc551f912835484f687d6f19c950a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 18 Jan 2024 15:48:50 +0000 Subject: [PATCH 2/3] [Feature] `auto_batch_size_` (#630) --- tensordict/base.py | 26 ++++++++++++++++++++++++++ tensordict/utils.py | 4 +++- test/test_tensordict.py | 17 +++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index a966cbb7f..9609c928e 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -42,6 +42,7 @@ _KEY_ERROR, _proc_init, _prune_selected_keys, + _set_max_batch_size, _shape, _split_tensordict, _td_fields, @@ -326,6 +327,31 @@ def any(self, dim: int = None) -> bool | TensorDictBase: """ ... + def auto_batch_size_(self, batch_dims: int | None = None) -> T: + """Sets the maximum batch-size for the tensordict, up to an optional batch_dims. + + Args: + batch_dims (int, optional): if provided, the batch-size will be at + most ``batch_dims`` long. 
+ + Returns: + self + + Examples: + >>> from tensordict import TensorDict + >>> import torch + >>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[]) + >>> td.auto_batch_size_() + >>> print(td.batch_size) + torch.Size([3, 4]) + >>> td.auto_batch_size_(batch_dims=1) + >>> print(td.batch_size) + torch.Size([3]) + + """ + _set_max_batch_size(self, batch_dims) + return self + # Module interaction @classmethod def from_module( diff --git a/tensordict/utils.py b/tensordict/utils.py index a68664bd7..0e7c05e51 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1441,12 +1441,14 @@ def _expand_to_match_shape( def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximium batch size.""" + from tensordict import NonTensorData + tensor_data = list(source.values()) for val in tensor_data: from tensordict.base import _is_tensor_collection - if _is_tensor_collection(val.__class__): + if _is_tensor_collection(val.__class__) and not isinstance(val, NonTensorData): _set_max_batch_size(val, batch_dims=batch_dims) batch_size = [] if not tensor_data: # when source is empty diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 4c94358e3..41177460c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1784,6 +1784,23 @@ def test_assert(self, td_name, device): ): assert td + def test_auto_batch_size_(self, td_name, device): + td = getattr(self, td_name)(device) + batch_size = td.batch_size + error = None + try: + td.batch_size = [] + except Exception as err: + error = err + if error is not None: + with pytest.raises(type(error)): + td.auto_batch_size_() + return + td.auto_batch_size_() + assert td.batch_size[: len(batch_size)] == batch_size + td.auto_batch_size_(1) + assert len(td.batch_size) == 1 + def test_broadcast(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) From 0f75ac9feb7262a925ee78606e730ac59b3d93b8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 22 Jan 2024 21:17:53 +0000 Subject: [PATCH 3/3] [BugFix] Fix NonTensorData interaction (#631) --- tensordict/_lazy.py | 37 +++++++++++++++++++++---------------- tensordict/_torch_func.py | 12 +++++++++--- tensordict/base.py | 11 ++++++++++- tensordict/tensorclass.py | 14 ++++++++++++-- tensordict/utils.py | 13 +++++++++---- test/test_tensordict.py | 14 +++++++++++++- 6 files changed, 74 insertions(+), 27 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 1f0bbf171..a4806dc93 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1515,7 +1515,12 @@ def __getitem__(self, index: IndexType) -> T: if isinstance(index, (tuple, str)): index_key = _unravel_key_to_tuple(index) if index_key: - return self._get_tuple(index_key, NO_DEFAULT) + result = self._get_tuple(index_key, NO_DEFAULT) + from .tensorclass import NonTensorData + + if isinstance(result, NonTensorData): + return result.data + return result split_index = self._split_index(index) converted_idx = split_index["index_dict"] isinteger = split_index["isinteger"] @@ -1527,22 +1532,22 @@ def __getitem__(self, index: IndexType) -> T: if has_bool: mask_unbind = split_index["individual_masks"] cat_dim = split_index["mask_loc"] - num_single - out = [] + result = [] if mask_unbind[0].ndim == 0: # we can return a stack for (i, _idx), mask in zip(converted_idx.items(), mask_unbind): if mask.any(): if mask.all() and self.tensordicts[i].ndim == 0: - out.append(self.tensordicts[i]) + result.append(self.tensordicts[i]) else: - 
out.append(self.tensordicts[i][_idx]) - out[-1] = out[-1].squeeze(cat_dim) - return LazyStackedTensorDict.lazy_stack(out, cat_dim) + result.append(self.tensordicts[i][_idx]) + result[-1] = result[-1].squeeze(cat_dim) + return LazyStackedTensorDict.lazy_stack(result, cat_dim) else: for i, _idx in converted_idx.items(): self_idx = (slice(None),) * split_index["mask_loc"] + (i,) - out.append(self[self_idx][_idx]) - return torch.cat(out, cat_dim) + result.append(self[self_idx][_idx]) + return torch.cat(result, cat_dim) elif is_nd_tensor: new_stack_dim = self.stack_dim - num_single + num_none return LazyStackedTensorDict.lazy_stack( @@ -1556,18 +1561,18 @@ def __getitem__(self, index: IndexType) -> T: ) in ( converted_idx.items() ): # for convenience but there's only one element - out = self.tensordicts[i] + result = self.tensordicts[i] if _idx is not None and _idx != (): - out = out[_idx] - return out + result = result[_idx] + return result else: - out = [] + result = [] new_stack_dim = self.stack_dim - num_single + num_none - num_squash for i, _idx in converted_idx.items(): - out.append(self.tensordicts[i][_idx]) - out = LazyStackedTensorDict.lazy_stack(out, new_stack_dim) - out._td_dim_name = self._td_dim_name - return out + result.append(self.tensordicts[i][_idx]) + result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) + result._td_dim_name = self._td_dim_name + return result def __eq__(self, other): if is_tensorclass(other): diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 67d06a972..b267fbc86 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -13,7 +13,7 @@ from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.base import NO_DEFAULT, TensorDictBase +from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase from tensordict.persistent import PersistentTensorDict from tensordict.utils import ( _check_keys, @@ -95,12 +95,18 @@ def _gather_tensor(tensor, dest=None): names = input.names if input._has_names() else None return TensorDict( - {key: _gather_tensor(value) for key, value in input.items()}, + { + key: _gather_tensor(value) + for key, value in input.items(is_leaf=_is_leaf_nontensor) + }, batch_size=index.shape, names=names, ) TensorDict( - {key: _gather_tensor(value, out[key]) for key, value in input.items()}, + { + key: _gather_tensor(value, out.get(key)) + for key, value in input.items(is_leaf=_is_leaf_nontensor) + }, batch_size=index.shape, ) return out diff --git a/tensordict/base.py b/tensordict/base.py index 9609c928e..71862dae8 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -207,6 +207,9 @@ def __getitem__(self, index: IndexType) -> T: The index can be a (nested) key or any valid shape index given the tensordict batch size. + If the index is a nested key and the result is a :class:`~tensordict.NonTensorData` + object, the content of the non-tensor is returned. 
+ Examples: >>> td = TensorDict({"root": torch.arange(2), ("nested", "entry"): torch.arange(2)}, [2]) >>> td["root"] @@ -232,7 +235,13 @@ def __getitem__(self, index: IndexType) -> T: # _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: - return self._get_tuple(idx_unravel, NO_DEFAULT) + result = self._get_tuple(idx_unravel, NO_DEFAULT) + from .tensorclass import NonTensorData + + if isinstance(result, NonTensorData): + return result.data + return result + if (istuple and not index) or (not istuple and index is Ellipsis): # empty tuple returns self return self diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index c1ce5df1c..d026fe7d3 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -27,7 +27,7 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _register_tensor_class +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -1305,7 +1305,17 @@ def to_dict(self): def _stack_non_tensor(cls, list_of_non_tensor, dim=0): # checks have been performed previously, so we're sure the list is non-empty first = list_of_non_tensor[0] - if all(data.data == first.data for data in list_of_non_tensor[1:]): + + def _check_equal(a, b): + if isinstance(a, _ACCEPTED_CLASSES) or isinstance(b, _ACCEPTED_CLASSES): + return (a == b).all() + try: + iseq = a == b + except Exception: + iseq = False + return iseq + + if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) return NonTensorData( diff --git a/tensordict/utils.py b/tensordict/utils.py index 0e7c05e51..9ff4ef93d 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1443,17 +1443,22 @@ def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximium batch size.""" from tensordict import NonTensorData - tensor_data = list(source.values()) + tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)] for val in tensor_data: from tensordict.base import _is_tensor_collection - if _is_tensor_collection(val.__class__) and not isinstance(val, NonTensorData): + if _is_tensor_collection(val.__class__): _set_max_batch_size(val, batch_dims=batch_dims) + batch_size = [] if not tensor_data: # when source is empty - source.batch_size = batch_size - return + if batch_dims: + source.batch_size = source.batch_size[:batch_dims] + return source + else: + return source + curr_dim = 0 while True: if tensor_data[0].dim() > curr_dim: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 41177460c..cf7fef16f 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2872,7 +2872,13 @@ def test_memmap_like(self, td_name, device, use_dir, tmpdir, num_threads): ) assert tdmemmap is not td for key in td.keys(True): - assert td[key] is not tdmemmap[key] + v1 = td[key] + v2 = tdmemmap[key] + if isinstance(v1, str): + # non-tensor data storing strings share the same id in python + assert v1 is v2 + else: + assert v1 is not v2 assert (tdmemmap == 0).all() assert tdmemmap.is_memmap() @@ -3097,6 +3103,12 @@ def test_non_tensor_data(self, td_name, device): assert 
td.get_non_tensor(("this", "will")) == "succeed" assert isinstance(td.get(("this", "will")), NonTensorData) + with td.unlock_(): + td["this", "other", "tensor"] = "success" + assert td["this", "other", "tensor"] == "success" + assert isinstance(td.get(("this", "other", "tensor")), NonTensorData) + assert td.get_non_tensor(("this", "other", "tensor")) == "success" + def test_non_tensor_data_flatten_keys(self, td_name, device): td = getattr(self, td_name)(device) with td.unlock_():