Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 18, 2024
1 parent 8647e74 commit 0b9b375
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 17 deletions.
10 changes: 2 additions & 8 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T:
stack_dim=stack_dim,
)

def unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
if dim < 0:
dim = self.batch_dims + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(
f"Cannot unbind along dimension {dim} with batch size {self.batch_size}."
)
def _unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
if dim == self.stack_dim:
return tuple(self.tensordicts)
else:
Expand Down Expand Up @@ -2869,7 +2863,7 @@ def _unsqueeze(self, dim):
all = TensorDict.all
any = TensorDict.any
expand = TensorDict.expand
unbind = TensorDict.unbind
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx


Expand Down
14 changes: 8 additions & 6 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numbers
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from numbers import Number
from pathlib import Path
Expand Down Expand Up @@ -828,9 +827,7 @@ def _expand(tensor):
_expand, batch_size=shape, call_on_nested=True, names=names
)

def unbind(self, dim: int) -> tuple[T, ...]:
if dim < 0:
dim = self.batch_dims + dim
def _unbind(self, dim: int):
batch_size = torch.Size([s for i, s in enumerate(self.batch_size) if i != dim])
names = None
if self._has_names():
Expand All @@ -847,7 +844,12 @@ def empty():
tds = tuple(empty() for _ in range(self.batch_size[dim]))

def unbind(key, val, tds=tds):
for td, _val in zip(tds, val.unbind(dim)):
unbound = (
val.unbind(dim)
if not _is_tensor_collection(type(val))
else val._unbind(dim)
)
for td, _val in zip(tds, unbound):
td._set_str(key, _val, validated=True, inplace=False)

for key, val in self.items():
Expand Down Expand Up @@ -2754,7 +2756,7 @@ def _create_nested_str(self, key):
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
unbind = TensorDict.unbind
_unbind = TensorDict._unbind

def _view(self, *args, **kwargs):
raise RuntimeError(
Expand Down
16 changes: 15 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ def expand(self, *args: int | torch.Size) -> T:
"""
...

@abc.abstractmethod
def unbind(self, dim: int) -> tuple[T, ...]:
"""Returns a tuple of indexed tensordicts, unbound along the indicated dimension.
Expand All @@ -558,6 +557,21 @@ def unbind(self, dim: int) -> tuple[T, ...]:
tensor([4, 5, 6, 7])
"""
batch_dims = self.batch_dims
if dim < -batch_dims or dim >= batch_dims:
raise RuntimeError(
f"the dimension provided ({dim}) is beyond the tensordict dimensions ({self.ndim})."
)
if dim < 0:
dim = batch_dims + dim
results = self._unbind(dim)
if self._is_memmap or self._is_shared:
for result in results:
result.lock_()
return results

@abc.abstractmethod
def _unbind(self, dim: int) -> tuple[T, ...]:
...

def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
...

@_fallback
def unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
def _unbind(self, dim: int) -> tuple[TensorDictBase, ...]:
...

@_fallback
Expand Down
2 changes: 1 addition & 1 deletion tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ def _unsqueeze(self, dim):
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
unbind = TensorDict.unbind
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx


Expand Down

0 comments on commit 0b9b375

Please sign in to comment.