[Feature] First class dim compatibility (#525)
vmoens authored Oct 4, 2023
1 parent 302c342 commit 4c0eb1d
Showing 4 changed files with 178 additions and 60 deletions.
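
For readers new to first-class dims, here is a minimal sketch of what this commit enables, assuming functorch is installed; the shapes mirror the assertions in the new TestFCD tests added below. A first-class dim used as an index consumes one batch dimension of the TensorDict, much like an integer index:

    import torch
    from functorch import dim as ftdim
    from tensordict import TensorDict

    td = TensorDict({"x": torch.randn(3, 4, 5)}, batch_size=[3, 4])

    # a first-class dim consumes the leading batch dimension, like an int index
    d0 = ftdim.dims(1)
    assert td[d0].shape == td.shape[1:]

    # first-class dims compose with regular indices (a fresh dim binds to dim 1 here)
    d1 = ftdim.dims(1)
    assert td[:, d1].shape == torch.Size([td.shape[0]])
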
7 changes: 4 additions & 3 deletions tensordict/nn/params.py
@@ -13,6 +13,7 @@
from typing import Any, Callable, Iterator, OrderedDict, Sequence

import torch
from functorch import dim as ftdim

from tensordict import TensorDictBase
from tensordict.nn.utils import Buffer
@@ -72,7 +73,7 @@ def _get_args_dict(func, args, kwargs):

def _maybe_make_param(tensor):
if (
isinstance(tensor, Tensor)
isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
):
@@ -82,7 +83,7 @@ def _maybe_make_param(tensor):

def _maybe_make_param_or_buffer(tensor):
if (
isinstance(tensor, Tensor)
isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
):
@@ -319,7 +320,7 @@ def __torch_function__(
if kwargs is None:
kwargs = {}
if func not in TDPARAM_HANDLED_FUNCTIONS or not all(
issubclass(t, (Tensor, TensorDictBase)) for t in types
issubclass(t, (Tensor, ftdim.Tensor, TensorDictBase)) for t in types
):
return NotImplemented
return TDPARAM_HANDLED_FUNCTIONS[func](*args, **kwargs)
89 changes: 33 additions & 56 deletions tensordict/tensordict.py
@@ -32,11 +32,13 @@
TypeVar,
Union,
)

from warnings import warn

import numpy as np

import torch
from functorch import dim as ftdim
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict.memmap import memmap_tensor_as_tensor, MemmapTensor
from tensordict.utils import (
@@ -69,7 +71,7 @@
NestedKey,
prod,
)
from torch import distributed as dist, multiprocessing as mp, Tensor
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch.utils._pytree import tree_map

try:
@@ -408,6 +410,24 @@ def from_module(module, as_module: bool = False):
return TensorDictParams(td, no_convert=True)
return td

def to_module(self, module):
from tensordict.nn.functional_modules import set_tensor_dict

__base__setattr__ = nn.Module.__setattr__
# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
__dict__ = module.__dict__

for key, value in self.items():
cls = value.__class__
if _is_tensor_collection(cls) or issubclass(cls, dict):
value.to_module(__dict__["_modules"][key])
else:
if module.__class__.__setattr__ is __base__setattr__:
set_tensor_dict(__dict__, module, key, value)
else:
# use specialized __setattr__ if needed
setattr(module, key, value)

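The to_module method added above is the write-back half of the functional workflow exercised by the new test_modules test: parameters gathered with from_module can be expanded along a first-class dim and written back into the module, which then broadcasts over that dim. A minimal sketch mirroring that test:

    import torch
    from functorch import dim as ftdim
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    module = nn.Linear(3, 4)
    params = TensorDict.from_module(module)

    # five copies of the parameters, batched along a first-class dim
    params = params.expand(5).clone()
    d0 = ftdim.dims(1)
    params = TensorDictParams(params)[d0]

    # write the dim-indexed parameters back and call the module
    params.to_module(module)
    y = module(torch.randn(2, 3))
    assert y.dims == (d0,)          # the output carries the parameter dim
    assert y._tensor.shape[0] == 5  # underlying storage keeps the batch of 5
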
@property
def shape(self) -> torch.Size:
"""See :obj:`TensorDictBase.batch_size`."""
@@ -3515,6 +3535,8 @@ def _get_names_idx(self, idx):
else:

def is_boolean(idx):
if isinstance(idx, ftdim.Dim):
return None
if isinstance(idx, tuple) and len(idx) == 1:
return is_boolean(idx[0])
if hasattr(idx, "dtype") and idx.dtype is torch.bool:
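
Returning None for a first-class dim lets the name handling treat it like a scalar index: the name of the indexed dimension is simply dropped, as the new test_fcd_names test asserts. A minimal sketch of the expected behaviour (dim names and shapes are illustrative):

    import torch
    from functorch import dim as ftdim
    from tensordict import TensorDict

    td = TensorDict({"x": torch.randn(3, 4, 5, 6, 2)}, batch_size=[3, 4, 5, 6])
    td.names = ["a", "b", "c", "d"]

    d0 = ftdim.dims(1)
    assert td[d0].names == ["b", "c", "d"]      # the indexed name is consumed
    d1 = ftdim.dims(1)
    assert td[:, d1].names == ["a", "c", "d"]   # only the indexed name is dropped
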
@@ -3886,6 +3908,7 @@ def type(self, dst_type):
Tensor,
MemmapTensor,
TensorDictBase,
ftdim.Tensor,
]
if _has_torchrec:
_ACCEPTED_CLASSES += [KeyedJaggedTensor]
@@ -4584,11 +4607,6 @@ def memmap_(
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
# if not self._tensordict.keys():
# raise Exception(
# "memmap_() must be called when the TensorDict is (partially) "
# "populated. Set a tensor first."
# )
for key, value in self.items():
if value.requires_grad:
raise Exception(
@@ -6527,7 +6545,13 @@ def _split_index(self, index):
continue
if cursor == self.stack_dim:
# we need to check which tds need to be indexed
if isinstance(idx, slice) or _is_number(idx):
if isinstance(idx, ftdim.Dim):
raise ValueError(
"Cannot index a lazy stacked tensordict along the stack dimension with "
"a first-class dimension index. Consider consolidating the tensordict first "
"using `tensordict.contiguous()`."
)
elif isinstance(idx, slice) or _is_number(idx):
selected_td_idx = range(len(self.tensordicts))[idx]
if not isinstance(selected_td_idx, range):
isinteger = True
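
Indexing a lazy stack along its stack dimension with a first-class dim is rejected explicitly rather than handled incorrectly; the error message points to contiguous() as the workaround. A minimal sketch, assuming a stack along dim 0 (the constructor call mirrors how lazy stacks are built elsewhere in this file):

    import torch
    from functorch import dim as ftdim
    from tensordict import LazyStackedTensorDict, TensorDict

    tds = [TensorDict({"x": torch.randn(4, 2)}, batch_size=[4]) for _ in range(3)]
    stacked = LazyStackedTensorDict(*tds, stack_dim=0)  # batch_size == [3, 4]

    d0 = ftdim.dims(1)
    try:
        stacked[d0]                   # indexes the stack dim -> ValueError
    except ValueError:
        dense = stacked.contiguous()  # consolidate first, as the error suggests
        assert dense[d0].shape == torch.Size([4])
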
@@ -6559,6 +6583,7 @@ def _split_index(self, index):
idx,
(
int,
ftdim.Dim,
slice,
list,
range,
@@ -7372,54 +7397,6 @@ def __getitem__(self, index: IndexType) -> T:
out._td_dim_name = self._td_dim_name
return out

# index_dict = _convert_index_lazystack(index, self.stack_dim, self.batch_size)
# if index_dict is None:
# # then we use a sub-tensordict
# return self.get_sub_tensordict(index)
# td_index = index_dict["remaining_index"]
# stack_index = index_dict["stack_index"]
# new_stack_dim = index_dict["new_stack_dim"]
# if new_stack_dim is not None:
# if isinstance(stack_index, slice):
# # we can't iterate but we can index the list directly
# out = LazyStackedTensorDict(
# *[td[td_index] for td in self.tensordicts[stack_index]],
# stack_dim=new_stack_dim,
# )
# elif isinstance(stack_index, (list, range)):
# # then we can iterate
# out = LazyStackedTensorDict(
# *[self.tensordicts[idx][td_index] for idx in stack_index],
# stack_dim=new_stack_dim,
# )
# elif isinstance(stack_index, Tensor):
# # td_index is a nested tuple that mimics the shape of stack_index
# def _nested_stack(t: list, stack_idx: Tensor, td_index):
# if stack_idx.ndim:
# out = LazyStackedTensorDict(
# *[
# _nested_stack(t, _idx, td_index[i])
# for i, _idx in enumerate(stack_idx.unbind(0))
# ],
# stack_dim=new_stack_dim,
# )
# return out
# return t[stack_idx][td_index]
#
# # print(index, td_index, stack_index)
# out = _nested_stack(self.tensordicts, stack_index, td_index)
# else:
# raise TypeError("Invalid index used for stack dimension.")
# out._td_dim_name = self._td_dim_name
# return out
# out = self.tensordicts[stack_index]
# if td_index:
# return out[td_index]
# return out

# def __hash__(self):
# return hash(self.tensordicts)

def __eq__(self, other):
if is_tensorclass(other):
return other == self
@@ -9084,7 +9061,7 @@ def _clone_value(value: CompatibleType, recurse: bool) -> CompatibleType:


def _is_number(item):
if isinstance(item, Number):
if isinstance(item, (Number, ftdim.Dim)):
return True
if isinstance(item, Tensor) and item.ndim == 0:
return True
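
Because _is_number now returns True for ftdim.Dim, the indexing machinery takes the same branch as for a plain integer, so each first-class dim in an index consumes exactly one dimension. A minimal sketch on a plain tensor (the .dims attribute is the same one the new tests rely on):

    import torch
    from functorch import dim as ftdim

    x = torch.randn(3, 4)
    d0 = ftdim.dims(1)
    y = x[d0]               # consumed like an integer index along dim 0
    assert y.dims == (d0,)  # but the dimension stays addressable as a first-class dim
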
6 changes: 5 additions & 1 deletion tensordict/utils.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
from functorch import dim as ftdim

from packaging.version import parse
from tensordict._tensordict import ( # noqa: F401
@@ -150,7 +151,7 @@ def _getitem_batch_size(batch_size, index):
out.extend(bs_shape)
bs_shape = None
continue
elif isinstance(idx, int):
elif isinstance(idx, (int, ftdim.Dim)):
# could be spared for efficiency
continue
elif isinstance(idx, slice):
@@ -761,9 +762,12 @@ def _is_shared(tensor: torch.Tensor) -> bool:
if torch._C._functorch.is_batchedtensor(tensor):
return None
return tensor.is_shared()
if isinstance(tensor, ftdim.Tensor):
return None
elif isinstance(tensor, KeyedJaggedTensor):
return False
else:
print(type(tensor))
return tensor.is_shared()


136 changes: 136 additions & 0 deletions test/test_tensordict.py
@@ -12,6 +12,7 @@
import pytest
import torch

from tensordict.nn import TensorDictParams

try:
import torchsnapshot
@@ -30,6 +31,7 @@
_has_h5py = False

from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase
from functorch import dim as ftdim

from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
from tensordict.tensordict import (
@@ -6239,6 +6241,140 @@ def _pool_fixt():
yield pool


class TestFCD(TestTensorDictsBase):
"""Test stack for first-class dimension."""

@pytest.mark.parametrize(
"td_name",
[
"td",
"stacked_td",
"sub_td",
"sub_td2",
"idx_td",
"memmap_td",
"unsqueezed_td",
"squeezed_td",
"td_reset_bs",
"nested_td",
"nested_tensorclass",
"permute_td",
"nested_stacked_td",
"td_params",
pytest.param(
"td_h5",
marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
),
],
)
@pytest.mark.parametrize("device", get_available_devices())
def test_fcd(self, td_name, device):
td = getattr(self, td_name)(device)
d0 = ftdim.dims(1)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
with pytest.raises(ValueError, match="Cannot index"):
td[d0]
else:
assert td[d0].shape == td.shape[1:]
d0, d1 = ftdim.dims(2)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
with pytest.raises(ValueError, match="Cannot index"):
td[d0, d1]
else:
assert td[d0, d1].shape == td.shape[2:]
d0, d1, d2 = ftdim.dims(3)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
with pytest.raises(ValueError, match="Cannot index"):
td[d0, d1, d2]
else:
assert td[d0, d1, d2].shape == td.shape[3:]
d0 = ftdim.dims(1)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
with pytest.raises(ValueError, match="Cannot index"):
td[:, d0]
else:
assert td[:, d0].shape == torch.Size((td.shape[0], *td.shape[2:]))

@pytest.mark.parametrize(
"td_name",
[
"td",
"stacked_td",
"idx_td",
"memmap_td",
"td_reset_bs",
"nested_td",
"nested_tensorclass",
"nested_stacked_td",
"td_params",
pytest.param(
"td_h5",
marks=pytest.mark.skipif(not _has_h5py, reason="h5py not found."),
),
# these tds cannot see their dim names edited:
# "sub_td",
# "sub_td2",
# "unsqueezed_td",
# "squeezed_td",
# "permute_td",
],
)
@pytest.mark.parametrize("device", get_available_devices())
def test_fcd_names(self, td_name, device):
td = getattr(self, td_name)(device)
td.names = ["a", "b", "c", "d"]
d0 = ftdim.dims(1)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 0:
with pytest.raises(ValueError, match="Cannot index"):
td[d0]
else:
assert td[d0].names == ["b", "c", "d"]
d0, d1 = ftdim.dims(2)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1):
with pytest.raises(ValueError, match="Cannot index"):
td[d0, d1]
else:
assert td[d0, d1].names == ["c", "d"]
d0, d1, d2 = ftdim.dims(3)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim in (0, 1, 2):
with pytest.raises(ValueError, match="Cannot index"):
td[d0, d1, d2]
else:
assert td[d0, d1, d2].names == ["d"]
d0 = ftdim.dims(1)
if isinstance(td, LazyStackedTensorDict) and td.stack_dim == 1:
with pytest.raises(ValueError, match="Cannot index"):
td[:, d0]
else:
assert td[:, d0].names == ["a", "c", "d"]

@pytest.mark.parametrize("as_module", [False, True])
def test_modules(self, as_module):
modules = [
lambda: nn.Linear(3, 4),
lambda: nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)),
lambda: nn.Transformer(16, 4, 2, 2, 8),
lambda: nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 4, 3)),
]
inputs = [
lambda: (torch.randn(2, 3),),
lambda: (torch.randn(2, 3),),
lambda: (torch.randn(2, 3, 16), torch.randn(2, 3, 16)),
lambda: (torch.randn(2, 3, 16, 16),),
]
param_batch = 5
for make_module, make_input in zip(modules, inputs):
module = make_module()
td = TensorDict.from_module(module, as_module=as_module)
td = td.expand(param_batch).clone()
d0 = ftdim.dims(1)
td = TensorDictParams(td)[d0]
td.to_module(module)
y = module(*make_input())
assert y.dims == (d0,)
assert y._tensor.shape[0] == param_batch


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
