From 4c0eb1de9d0b635fb08b3d117e2cb966804f251f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 4 Oct 2023 09:36:18 -0400
Subject: [PATCH] [Feature] First class dim compatibility (#525)

---
 tensordict/nn/params.py  |   7 +-
 tensordict/tensordict.py |  89 ++++++++++---------------
 tensordict/utils.py      |   5 +-
 test/test_tensordict.py  | 136 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 177 insertions(+), 60 deletions(-)

diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 5bc4fb723..3685661c7 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -13,6 +13,7 @@ from typing import Any, Callable, Iterator, OrderedDict, Sequence
 
 import torch
+from functorch import dim as ftdim
 
 from tensordict import TensorDictBase
 from tensordict.nn.utils import Buffer
@@ -72,7 +73,7 @@ def _get_args_dict(func, args, kwargs):
 
 def _maybe_make_param(tensor):
     if (
-        isinstance(tensor, Tensor)
+        isinstance(tensor, (Tensor, ftdim.Tensor))
         and not isinstance(tensor, nn.Parameter)
         and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
@@ -82,7 +83,7 @@ def _maybe_make_param(tensor):
 
 def _maybe_make_param_or_buffer(tensor):
     if (
-        isinstance(tensor, Tensor)
+        isinstance(tensor, (Tensor, ftdim.Tensor))
        and not isinstance(tensor, nn.Parameter)
         and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
@@ -319,7 +320,7 @@ def __torch_function__(
         if kwargs is None:
             kwargs = {}
         if func not in TDPARAM_HANDLED_FUNCTIONS or not all(
-            issubclass(t, (Tensor, TensorDictBase)) for t in types
+            issubclass(t, (Tensor, ftdim.Tensor, TensorDictBase)) for t in types
         ):
             return NotImplemented
         return TDPARAM_HANDLED_FUNCTIONS[func](*args, **kwargs)
diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 82f23c5c8..ae52b7ad4 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -32,11 +32,13 @@
     TypeVar,
     Union,
 )
+
 from warnings import warn
 
 import numpy as np
 import torch
+from functorch import dim as ftdim
 
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor
 from tensordict.utils import (
@@ -69,7 +71,7 @@
     NestedKey,
     prod,
 )
-from torch import distributed as dist, multiprocessing as mp, Tensor
+from torch import distributed as dist, multiprocessing as mp, nn, Tensor
 from torch.utils._pytree import tree_map
 
 try:
@@ -408,6 +410,24 @@ def from_module(module, as_module: bool = False):
             return TensorDictParams(td, no_convert=True)
         return td
 
+    def to_module(self, module):
+        from tensordict.nn.functional_modules import set_tensor_dict
+
+        __base__setattr__ = nn.Module.__setattr__
+        # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
+        __dict__ = module.__dict__
+
+        for key, value in self.items():
+            cls = value.__class__
+            if _is_tensor_collection(cls) or issubclass(cls, dict):
+                value.to_module(__dict__["_modules"][key])
+            else:
+                if module.__class__.__setattr__ is __base__setattr__:
+                    set_tensor_dict(__dict__, module, key, value)
+                else:
+                    # use specialized __setattr__ if needed
+                    setattr(module, key, value)
+
     @property
     def shape(self) -> torch.Size:
         """See :obj:`TensorDictBase.batch_size`."""
@@ -3515,6 +3535,8 @@ def _get_names_idx(self, idx):
         else:
 
             def is_boolean(idx):
+                if isinstance(idx, ftdim.Dim):
+                    return None
                 if isinstance(idx, tuple) and len(idx) == 1:
                     return is_boolean(idx[0])
                 if hasattr(idx, "dtype") and idx.dtype is torch.bool:
@@ -3886,6 +3908,7 @@ def type(self, dst_type):
     Tensor,
     MemmapTensor,
     TensorDictBase,
+    ftdim.Tensor,
 ]
 if _has_torchrec:
     _ACCEPTED_CLASSES += [KeyedJaggedTensor]
@@ -4584,11 +4607,6 @@ def memmap_(
             raise RuntimeError(
                 "memmap and shared memory are mutually exclusive features."
             )
-        # if not self._tensordict.keys():
-        #     raise Exception(
-        #         "memmap_() must be called when the TensorDict is (partially) "
-        #         "populated. Set a tensor first."
-        #     )
         for key, value in self.items():
             if value.requires_grad:
                 raise Exception(
@@ -6527,7 +6545,13 @@ def _split_index(self, index):
                 continue
             if cursor == self.stack_dim:
                 # we need to check which tds need to be indexed
-                if isinstance(idx, slice) or _is_number(idx):
+                if isinstance(idx, ftdim.Dim):
+                    raise ValueError(
+                        "Cannot index a lazy stacked tensordict along the stack dimension with "
+                        "a first-class dimension index. Consider consolidating the tensordict first "
+                        "using `tensordict.contiguous()`."
+                    )
+                elif isinstance(idx, slice) or _is_number(idx):
                     selected_td_idx = range(len(self.tensordicts))[idx]
                     if not isinstance(selected_td_idx, range):
                         isinteger = True
@@ -6559,6 +6583,7 @@
                 idx,
                 (
                     int,
+                    ftdim.Dim,
                     slice,
                     list,
                     range,
@@ -7372,54 +7397,6 @@ def __getitem__(self, index: IndexType) -> T:
             out._td_dim_name = self._td_dim_name
             return out
 
-        # index_dict = _convert_index_lazystack(index, self.stack_dim, self.batch_size)
-        # if index_dict is None:
-        #     # then we use a sub-tensordict
-        #     return self.get_sub_tensordict(index)
-        # td_index = index_dict["remaining_index"]
-        # stack_index = index_dict["stack_index"]
-        # new_stack_dim = index_dict["new_stack_dim"]
-        # if new_stack_dim is not None:
-        #     if isinstance(stack_index, slice):
-        #         # we can't iterate but we can index the list directly
-        #         out = LazyStackedTensorDict(
-        #             *[td[td_index] for td in self.tensordicts[stack_index]],
-        #             stack_dim=new_stack_dim,
-        #         )
-        #     elif isinstance(stack_index, (list, range)):
-        #         # then we can iterate
-        #         out = LazyStackedTensorDict(
-        #             *[self.tensordicts[idx][td_index] for idx in stack_index],
-        #             stack_dim=new_stack_dim,
-        #         )
-        #     elif isinstance(stack_index, Tensor):
-        #         # td_index is a nested tuple that mimics the shape of stack_index
-        #         def _nested_stack(t: list, stack_idx: Tensor, td_index):
-        #             if stack_idx.ndim:
-        #                 out = LazyStackedTensorDict(
-        #                     *[
-        #                         _nested_stack(t, _idx, td_index[i])
-        #                         for i, _idx in enumerate(stack_idx.unbind(0))
-        #                     ],
-        #                     stack_dim=new_stack_dim,
-        #                 )
-        #                 return out
-        #             return t[stack_idx][td_index]
-        #
-        #         # print(index, td_index, stack_index)
-        #         out = _nested_stack(self.tensordicts, stack_index, td_index)
-        #     else:
-        #         raise TypeError("Invalid index used for stack dimension.")
-        #     out._td_dim_name = self._td_dim_name
-        #     return out
-        # out = self.tensordicts[stack_index]
-        # if td_index:
-        #     return out[td_index]
-        # return out
-
-    # def __hash__(self):
-    #     return hash(self.tensordicts)
-
     def __eq__(self, other):
         if is_tensorclass(other):
             return other == self
@@ -9084,7 +9061,7 @@ def _clone_value(value: CompatibleType, recurse: bool) -> CompatibleType:
 
 
 def _is_number(item):
-    if isinstance(item, Number):
+    if isinstance(item, (Number, ftdim.Dim)):
         return True
     if isinstance(item, Tensor) and item.ndim == 0:
         return True
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 37a34b86a..78da3281f 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 import torch
+from functorch import dim as ftdim
 
 from packaging.version import parse
 from tensordict._tensordict import (  # noqa: F401
@@ -150,7 +151,7 @@ def _getitem_batch_size(batch_size, index):
             out.extend(bs_shape)
             bs_shape = None
             continue
-        elif isinstance(idx, int):
+        elif isinstance(idx, (int, ftdim.Dim)):
             # could be spared for efficiency
             continue
         elif isinstance(idx, slice):
@@ -761,9 +762,11 @@ def _is_shared(tensor: torch.Tensor) -> bool:
         if torch._C._functorch.is_batchedtensor(tensor):
             return None
         return tensor.is_shared()
+    if isinstance(tensor, ftdim.Tensor):
+        return None
     elif isinstance(tensor, KeyedJaggedTensor):
         return False
     else:
         return tensor.is_shared()
 
 
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index abad0dd81..8ec9edd9f 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -12,6 +12,7 @@
 
 import pytest
 import torch
+from tensordict.nn import TensorDictParams
 
 try:
     import torchsnapshot
@@ -30,6 +31,7 @@
     _has_h5py = False
 
 from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase
+from functorch import dim as ftdim
 from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
 from tensordict.tensordict import (
@@ -6239,6 +6241,140 @@ def _pool_fixt():
     yield pool
 
 
+class TestFCD(TestTensorDictsBase):
+    """Test stack for first-class dimension."""
+
+    @pytest.mark.parametrize(
+        "td_name",
+        [
+            "td",
+            "stacked_td",
+            "sub_td",
+            "sub_td2",
+            "idx_td",
+            "memmap_td",
+            "unsqueezed_td",
+            "squeezed_td",
+            "td_reset_bs",
+            "nested_td",
+            "nested_tensorclass",
+            "permute_td",
+            "nested_stacked_td",
+            "td_params",
+            pytest.param(
+                "td_h5",
+                marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_fcd(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0]
+        else:
+            assert td[d0].shape == td.shape[1:]
+        d0, d1 = ftdim.dims(2)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1]
+        else:
+            assert td[d0, d1].shape == td.shape[2:]
+        d0, d1, d2 = ftdim.dims(3)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1, d2]
+        else:
+            assert td[d0, d1, d2].shape == td.shape[3:]
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[:, d0]
+        else:
+            assert td[:, d0].shape == torch.Size((td.shape[0], *td.shape[2:]))
+
+    @pytest.mark.parametrize(
+        "td_name",
+        [
+            "td",
+            "stacked_td",
+            "idx_td",
+            "memmap_td",
+            "td_reset_bs",
+            "nested_td",
+            "nested_tensorclass",
+            "nested_stacked_td",
+            "td_params",
+            pytest.param(
+                "td_h5",
+                marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
+            ),
+            # these tds cannot see their dim names edited:
+            # "sub_td",
+            # "sub_td2",
+            # "unsqueezed_td",
+            # "squeezed_td",
+            # "permute_td",
+        ],
+    )
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_fcd_names(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        td.names = ["a", "b", "c", "d"]
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0]
+        else:
+            assert td[d0].names == ["b", "c", "d"]
+        d0, d1 = ftdim.dims(2)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1]
+        else:
+            assert td[d0, d1].names == ["c", "d"]
+        d0, d1, d2 = ftdim.dims(3)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[d0, d1, d2]
+        else:
+            assert td[d0, d1, d2].names == ["d"]
+        d0 = ftdim.dims(1)
+        if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
+            with pytest.raises(ValueError, match="Cannot index"):
+                td[:, d0]
+        else:
+            assert td[:, d0].names == ["a", "c", "d"]
+
+    @pytest.mark.parametrize("as_module", [False, True])
+    def test_modules(self, as_module):
+        modules = [
+            lambda: nn.Linear(3, 4),
+            lambda: nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)),
+            lambda: nn.Transformer(16, 4, 2, 2, 8),
+            lambda: nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 4, 3)),
+        ]
+        inputs = [
+            lambda: (torch.randn(2, 3),),
+            lambda: (torch.randn(2, 3),),
+            lambda: (torch.randn(2, 3, 16), torch.randn(2, 3, 16)),
+            lambda: (torch.randn(2, 3, 16, 16),),
+        ]
+        param_batch = 5
+        for make_module, make_input in zip(modules, inputs):
+            module = make_module()
+            td = TensorDict.from_module(module, as_module=as_module)
+            td = td.expand(param_batch).clone()
+            d0 = ftdim.dims(1)
+            td = TensorDictParams(td)[d0]
+            td.to_module(module)
+            y = module(*make_input())
+            assert y.dims == (d0,)
+            assert y._tensor.shape[0] == param_batch
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
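
A minimal usage sketch of the feature this patch adds, mirroring `TestFCD.test_modules` above. It assumes the patch is applied on top of tensordict with `functorch.dim` available; the `nn.Linear` module and the batch size of 5 are illustrative choices, not part of the change.

    import torch
    from torch import nn
    from functorch import dim as ftdim

    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    module = nn.Linear(3, 4)

    # Extract the module's parameters as a TensorDict and build a batch of 5 copies.
    params = TensorDict.from_module(module)
    params = params.expand(5).clone()

    # Index the parameter batch with a first-class dimension and write it back
    # into the module with the new `to_module`.
    d0 = ftdim.dims(1)
    params = TensorDictParams(params)[d0]
    params.to_module(module)

    # The module output now carries the first-class dimension d0.
    y = module(torch.randn(2, 3))
    assert y.dims == (d0,)
    assert y._tensor.shape[0] == 5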