Merge remote-tracking branch 'origin/main' into disable_compile_get_set
vmoens committed Jan 23, 2024
2 parents 25e87d3 + 0f75ac9 commit d336423
Showing 9 changed files with 167 additions and 50 deletions.
49 changes: 24 additions & 25 deletions tensordict/_lazy.py
@@ -746,13 +746,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T:
stack_dim=stack_dim,
)

def unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
if dim < 0:
dim = self.batch_dims + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(
f"Cannot unbind along dimension {dim} with batch size {self.batch_size}."
)
def _unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
if dim == self.stack_dim:
return tuple(self.tensordicts)
else:
@@ -763,7 +757,7 @@ def unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
self.stack_dim if dim > self.stack_dim else self.stack_dim - 1
)
for td in self.tensordicts:
out.append(td.unbind(new_dim))
out.append(td._unbind(new_dim))

return tuple(self.lazy_stack(vals, new_stack_dim) for vals in zip(*out))

@@ -1521,7 +1515,12 @@ def __getitem__(self, index: IndexType) -> T:
if isinstance(index, (tuple, str)):
index_key = _unravel_key_to_tuple(index)
if index_key:
return self._get_tuple(index_key, NO_DEFAULT)
result = self._get_tuple(index_key, NO_DEFAULT)
from .tensorclass import NonTensorData

if isinstance(result, NonTensorData):
return result.data
return result
split_index = self._split_index(index)
converted_idx = split_index["index_dict"]
isinteger = split_index["isinteger"]
@@ -1533,22 +1532,22 @@ def __getitem__(self, index: IndexType) -> T:
if has_bool:
mask_unbind = split_index["individual_masks"]
cat_dim = split_index["mask_loc"] - num_single
out = []
result = []
if mask_unbind[0].ndim == 0:
# we can return a stack
for (i, _idx), mask in zip(converted_idx.items(), mask_unbind):
if mask.any():
if mask.all() and self.tensordicts[i].ndim == 0:
out.append(self.tensordicts[i])
result.append(self.tensordicts[i])
else:
out.append(self.tensordicts[i][_idx])
out[-1] = out[-1].squeeze(cat_dim)
return LazyStackedTensorDict.lazy_stack(out, cat_dim)
result.append(self.tensordicts[i][_idx])
result[-1] = result[-1].squeeze(cat_dim)
return LazyStackedTensorDict.lazy_stack(result, cat_dim)
else:
for i, _idx in converted_idx.items():
self_idx = (slice(None),) * split_index["mask_loc"] + (i,)
out.append(self[self_idx][_idx])
return torch.cat(out, cat_dim)
result.append(self[self_idx][_idx])
return torch.cat(result, cat_dim)
elif is_nd_tensor:
new_stack_dim = self.stack_dim - num_single + num_none
return LazyStackedTensorDict.lazy_stack(
@@ -1562,18 +1561,18 @@ def __getitem__(self, index: IndexType) -> T:
) in (
converted_idx.items()
): # for convenience but there's only one element
out = self.tensordicts[i]
result = self.tensordicts[i]
if _idx is not None and _idx != ():
out = out[_idx]
return out
result = result[_idx]
return result
else:
out = []
result = []
new_stack_dim = self.stack_dim - num_single + num_none - num_squash
for i, _idx in converted_idx.items():
out.append(self.tensordicts[i][_idx])
out = LazyStackedTensorDict.lazy_stack(out, new_stack_dim)
out._td_dim_name = self._td_dim_name
return out
result.append(self.tensordicts[i][_idx])
result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim)
result._td_dim_name = self._td_dim_name
return result

def __eq__(self, other):
if is_tensorclass(other):
@@ -2869,7 +2868,7 @@ def _unsqueeze(self, dim):
all = TensorDict.all
any = TensorDict.any
expand = TensorDict.expand
unbind = TensorDict.unbind
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx


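For context, a minimal sketch of the behaviour this refactor keeps for lazy stacks (illustrative shapes; it assumes LazyStackedTensorDict is exported at the package top level). The public unbind entry point now lives on the base class, while the lazy _unbind still returns the stored members directly when unbinding along the stack dimension and recurses into each member otherwise.

import torch
from tensordict import LazyStackedTensorDict, TensorDict

# three tensordicts of batch size [4], lazily stacked along dim 0 -> batch size [3, 4]
td = LazyStackedTensorDict.lazy_stack(
    [TensorDict({"a": torch.zeros(4)}, [4]) for _ in range(3)], 0
)
a, b, c = td.unbind(0)   # along the stack dim: the stored members are returned as-is
parts = td.unbind(1)     # along another dim: each member is unbound, then re-stacked lazily
print(len(parts), parts[0].batch_size)   # 4 torch.Size([3])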
41 changes: 29 additions & 12 deletions tensordict/_td.py
@@ -831,23 +831,40 @@ def _expand(tensor):
_expand, batch_size=shape, call_on_nested=True, names=names
)

def unbind(self, dim: int) -> tuple[T, ...]:
if dim < 0:
dim = self.batch_dims + dim
def _unbind(self, dim: int):
batch_size = torch.Size([s for i, s in enumerate(self.batch_size) if i != dim])
names = None
if self._has_names():
names = copy(self.names)
names = [name for i, name in enumerate(names) if i != dim]
out = []
# unbind_self_dict = {key: tensor.unbind(dim) for key, tensor in self.items()}
prefix = (slice(None),) * dim
device = self.device

is_shared = self._is_shared
is_memmap = self._is_memmap

def empty():
result = TensorDict(
{}, batch_size=batch_size, names=names, _run_checks=False, device=device
)
result._is_shared = is_shared
result._is_memmap = is_memmap
return result

tds = tuple(empty() for _ in range(self.batch_size[dim]))

def unbind(key, val, tds=tds):
unbound = (
val.unbind(dim)
if not isinstance(val, TensorDictBase)
# tensorclass is also unbound using plain unbind
else val._unbind(dim)
)
for td, _val in zip(tds, unbound):
td._set_str(key, _val, validated=True, inplace=False)

for _idx in range(self.batch_size[dim]):
_idx = prefix + (_idx,)
td = self._index_tensordict(_idx, new_batch_size=batch_size, names=names)
out.append(td)
return tuple(out)
for key, val in self.items():
unbind(key, val)
return tds

def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
# we must use slices to keep the storage of the tensors
@@ -2749,7 +2766,7 @@ def _create_nested_str(self, key):
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
unbind = TensorDict.unbind
_unbind = TensorDict._unbind

def _view(self, *args, **kwargs):
raise RuntimeError(
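As a rough illustration of what the rewritten _unbind produces (a sketch, not part of the diff): each leaf is unbound once along the requested dimension and the pieces are scattered into pre-allocated tensordicts, so, as with torch.unbind, the returned entries are views of the parent's storage.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.arange(6).view(3, 2), "b": {"c": torch.ones(3, 2, 5)}}, [3, 2])
u0, u1, u2 = td.unbind(0)
print(u0["a"])              # tensor([0, 1])
print(u1["b", "c"].shape)   # torch.Size([2, 5])
u0["a"].add_(10)            # in-place writes reach the parent, since the slices are views
print(td["a"][0])           # tensor([10, 11])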
12 changes: 9 additions & 3 deletions tensordict/_torch_func.py
@@ -13,7 +13,7 @@
from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.base import NO_DEFAULT, TensorDictBase
from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase
from tensordict.persistent import PersistentTensorDict
from tensordict.utils import (
_check_keys,
@@ -95,12 +95,18 @@ def _gather_tensor(tensor, dest=None):
names = input.names if input._has_names() else None

return TensorDict(
{key: _gather_tensor(value) for key, value in input.items()},
{
key: _gather_tensor(value)
for key, value in input.items(is_leaf=_is_leaf_nontensor)
},
batch_size=index.shape,
names=names,
)
TensorDict(
{key: _gather_tensor(value, out[key]) for key, value in input.items()},
{
key: _gather_tensor(value, out.get(key))
for key, value in input.items(is_leaf=_is_leaf_nontensor)
},
batch_size=index.shape,
)
return out
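For reference, a small sketch of the operation this hunk touches (illustrative values; it assumes torch.gather dispatches to the tensordict implementation in _torch_func.py, as the surrounding code suggests). Every leaf is gathered along dim and the result takes the index's shape as batch size; the change makes the iteration treat NonTensorData entries as leaves rather than recursing into them.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.arange(4).view(2, 2)}, [2, 2])
index = torch.tensor([[1, 0], [0, 1]])
gathered = torch.gather(td, 1, index)
print(gathered["a"])         # tensor([[1, 0], [2, 3]])
print(gathered.batch_size)   # torch.Size([2, 2])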
53 changes: 51 additions & 2 deletions tensordict/base.py
@@ -42,6 +42,7 @@
_KEY_ERROR,
_proc_init,
_prune_selected_keys,
_set_max_batch_size,
_shape,
_split_tensordict,
_td_fields,
@@ -207,6 +208,9 @@ def __getitem__(self, index: IndexType) -> T:
The index can be a (nested) key or any valid shape index given the
tensordict batch size.
If the index is a nested key and the result is a :class:`~tensordict.NonTensorData`
object, the content of the non-tensor is returned.
Examples:
>>> td = TensorDict({"root": torch.arange(2), ("nested", "entry"): torch.arange(2)}, [2])
>>> td["root"]
@@ -232,7 +236,13 @@
# _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
idx_unravel = _unravel_key_to_tuple(index)
if idx_unravel:
return self._get_tuple(idx_unravel, NO_DEFAULT)
result = self._get_tuple(idx_unravel, NO_DEFAULT)
from .tensorclass import NonTensorData

if isinstance(result, NonTensorData):
return result.data
return result

if (istuple and not index) or (not istuple and index is Ellipsis):
# empty tuple returns self
return self
@@ -327,6 +337,31 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
"""
...

def auto_batch_size_(self, batch_dims: int | None = None) -> T:
"""Sets the maximum batch-size for the tensordict, up to an optional batch_dims.
Args:
batch_dims (int, optional): if provided, the batch-size will be at
most ``batch_dims`` long.
Returns:
self
Examples:
>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({"a": torch.randn(3, 4, 5), "b": {"c": torch.randn(3, 4, 6)}}, batch_size=[])
>>> td.auto_batch_size_()
>>> print(td.batch_size)
torch.Size([3, 4])
>>> td.auto_batch_size_(batch_dims=1)
>>> print(td.batch_size)
torch.Size([3])
"""
_set_max_batch_size(self, batch_dims)
return self

# Module interaction
@classmethod
def from_module(
@@ -544,7 +579,6 @@ def expand(self, *args: int | torch.Size) -> T:
"""
...

@abc.abstractmethod
def unbind(self, dim: int) -> tuple[T, ...]:
"""Returns a tuple of indexed tensordicts, unbound along the indicated dimension.
@@ -559,6 +593,21 @@ def unbind(self, dim: int) -> tuple[T, ...]:
tensor([4, 5, 6, 7])
"""
batch_dims = self.batch_dims
if dim < -batch_dims or dim >= batch_dims:
raise RuntimeError(
f"the dimension provided ({dim}) is beyond the tensordict dimensions ({self.ndim})."
)
if dim < 0:
dim = batch_dims + dim
results = self._unbind(dim)
if self._is_memmap or self._is_shared:
for result in results:
result.lock_()
return results

@abc.abstractmethod
def _unbind(self, dim: int) -> tuple[T, ...]:
...

def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
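A brief sketch of the user-facing behaviour added in this file (the NonTensorData construction and assignment are assumptions on my part; the error message is copied from the hunk): the public unbind wrapper normalises negative dims and rejects out-of-range ones before delegating to _unbind, and key-based indexing now unwraps NonTensorData into its payload.

import torch
from tensordict import NonTensorData, TensorDict

td = TensorDict({"x": torch.zeros(2, 3)}, [2, 3])
td.unbind(-1)        # negative dims are normalised: same as td.unbind(1)
try:
    td.unbind(2)     # out of range for a 2-dimensional batch size
except RuntimeError as err:
    print(err)       # the dimension provided (2) is beyond the tensordict dimensions (2).

td["meta"] = NonTensorData(data="raw", batch_size=[2, 3])
print(td["meta"])    # prints raw -- __getitem__ returns the wrapped content, not the NonTensorData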
2 changes: 1 addition & 1 deletion tensordict/nn/params.py
@@ -564,7 +564,7 @@ def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
...

@_fallback
def unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
def _unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
...

@_fallback
2 changes: 1 addition & 1 deletion tensordict/persistent.py
@@ -1079,7 +1079,7 @@ def _unsqueeze(self, dim):
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
unbind = TensorDict.unbind
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx


14 changes: 12 additions & 2 deletions tensordict/tensorclass.py
@@ -27,7 +27,7 @@
from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict._torch_func import TD_HANDLED_FUNCTIONS
from tensordict.base import _register_tensor_class
from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class
from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor

from tensordict.utils import (
@@ -1305,7 +1305,17 @@ def to_dict(self):
def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
# checks have been performed previously, so we're sure the list is non-empty
first = list_of_non_tensor[0]
if all(data.data == first.data for data in list_of_non_tensor[1:]):

def _check_equal(a, b):
if isinstance(a, _ACCEPTED_CLASSES) or isinstance(b, _ACCEPTED_CLASSES):
return (a == b).all()
try:
iseq = a == b
except Exception:
iseq = False
return iseq

if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]):
batch_size = list(first.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
return NonTensorData(
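A hedged sketch of what the relaxed equality check enables (it calls the private classmethod directly for illustration, and the tensor payload is an assumption suggested by the new _ACCEPTED_CLASSES branch): comparing tensor payloads now reduces with .all() instead of truth-testing a multi-element boolean tensor, so identical payloads can still collapse into a single stacked NonTensorData.

import torch
from tensordict import NonTensorData

a = NonTensorData(data=torch.ones(3), batch_size=[])
b = NonTensorData(data=torch.ones(3), batch_size=[])
# the old check `data.data == first.data` produced a 3-element boolean tensor whose
# truth value is ambiguous; _check_equal now reduces the comparison with .all()
stacked = NonTensorData._stack_non_tensor([a, b])
print(stacked.batch_size)   # torch.Size([2]) -- equal payloads collapse into one NonTensorData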
13 changes: 10 additions & 3 deletions tensordict/utils.py
@@ -1441,17 +1441,24 @@ def _expand_to_match_shape(

def _set_max_batch_size(source: T, batch_dims=None):
"""Updates a tensordict with its maximium batch size."""
tensor_data = list(source.values())
from tensordict import NonTensorData

tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)]

for val in tensor_data:
from tensordict.base import _is_tensor_collection

if _is_tensor_collection(val.__class__):
_set_max_batch_size(val, batch_dims=batch_dims)

batch_size = []
if not tensor_data: # when source is empty
source.batch_size = batch_size
return
if batch_dims:
source.batch_size = source.batch_size[:batch_dims]
return source
else:
return source

curr_dim = 0
while True:
if tensor_data[0].dim() > curr_dim: