Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 23, 2023
1 parent dc4eb6b commit 79210c2
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 243 deletions.
62 changes: 35 additions & 27 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,36 @@ def is_batchedtensor(tensor: Tensor) -> bool:
class _LazyStackedTensorDictKeysView(_TensorDictKeysView):
tensordict: LazyStackedTensorDict

def __len__(self) -> int:
return len(self._keys())

def _keys(self) -> list[str]:
def _tensor_keys(self):
return self.tensordict._key_list(leaves_only=True)
def _node_keys(self):
return self.tensordict._key_list(nodes_only=True)
def _keys(self):
return self.tensordict._key_list()

def __contains__(self, item):
item = _unravel_key_to_tuple(item)
if item[0] in self.tensordict._iterate_over_keys():
if self.leaves_only:
return not _is_tensor_collection(self.tensordict.entry_class(item[0]))
has_first_key = True
else:
has_first_key = False
if not has_first_key or len(item) == 1:
return has_first_key
# otherwise take the long way
return all(
item[1:]
in tensordict.get(item[0]).keys(self.include_nested, self.leaves_only)
for tensordict in self.tensordict.tensordicts
)
#
# def __len__(self) -> int:
# return len(self._keys())
#
# def _keys(self) -> list[str]:
# return self.tensordict._key_list()
#
# def __contains__(self, item):
# item = _unravel_key_to_tuple(item)
# if item[0] in self.tensordict._iterate_over_keys():
# if self.leaves_only:
# return not _is_tensor_collection(self.tensordict.entry_class(item[0]))
# has_first_key = True
# else:
# has_first_key = False
# if not has_first_key or len(item) == 1:
# return has_first_key
# # otherwise take the long way
# return all(
# item[1:]
# in tensordict.get(item[0]).keys(self.include_nested, self.leaves_only)
# for tensordict in self.tensordict.tensordicts
# )


class LazyStackedTensorDict(TensorDictBase):
Expand Down Expand Up @@ -1053,10 +1061,10 @@ def _change_batch_size(self, new_size: torch.Size) -> None:
self._batch_size = new_size

def keys(
self, include_nested: bool = False, leaves_only: bool = False
self, include_nested: bool = False, leaves_only: bool = False, nodes_only: bool = False,
) -> _LazyStackedTensorDictKeysView:
keys = _LazyStackedTensorDictKeysView(
self, include_nested=include_nested, leaves_only=leaves_only
self, include_nested=include_nested, leaves_only=leaves_only, nodes_only=nodes_only,
)
return keys

Expand All @@ -1071,10 +1079,10 @@ def _iterate_over_keys(self) -> None:
yield from self._key_list()

@cache # noqa: B019
def _key_list(self):
keys = set(self.tensordicts[0].keys())
def _key_list(self, leaves_only=False, nodes_only=False):
keys = set(self.tensordicts[0].keys(leaves_only=leaves_only, nodes_only=nodes_only))
for td in self.tensordicts[1:]:
keys = keys.intersection(td.keys())
keys = keys.intersection(td.keys(leaves_only=leaves_only, nodes_only=nodes_only))
return sorted(keys, key=str)

def entry_class(self, key: NestedKey) -> type:
Expand Down Expand Up @@ -2152,9 +2160,9 @@ def __repr__(self) -> str:

# @cache # noqa: B019
def keys(
self, include_nested: bool = False, leaves_only: bool = False
self, include_nested: bool = False, leaves_only: bool = False, nodes_only:bool = False,
) -> _TensorDictKeysView:
return self._source.keys(include_nested=include_nested, leaves_only=leaves_only)
return self._source.keys(include_nested=include_nested, leaves_only=leaves_only, nodes_only=nodes_only)

def select(
self, *keys: str, inplace: bool = False, strict: bool = True
Expand Down
Loading

0 comments on commit 79210c2

Please sign in to comment.