Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 16, 2024
1 parent 3b9ed90 commit 7620392
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class TensorDict(TensorDictBase):
"""

_td_dim_names = None
_is_shared = False
_is_memmap = False

def __init__(
self,
Expand All @@ -190,11 +192,7 @@ def __init__(
device: DeviceType | None = None,
names: Sequence[str] | None = None,
_run_checks: bool = True,
_is_shared: bool | None = False,
_is_memmap: bool | None = False,
) -> None:
self._is_shared = _is_shared
self._is_memmap = _is_memmap
if device is not None and isinstance(device, (int, str)):
device = torch.device(device)
self._device = device
Expand All @@ -210,8 +208,6 @@ def __init__(
batch_size=self._batch_size,
device=self._device,
_run_checks=_run_checks,
_is_shared=_is_shared,
_is_memmap=_is_memmap,
)
_tensordict[key] = value
self._td_dim_names = names
Expand Down Expand Up @@ -1771,13 +1767,13 @@ def is_contiguous(self) -> bool:
return all([value.is_contiguous() for _, value in self.items()])

def clone(self, recurse: bool = True) -> T:
if recurse:
def func(x):
return x.clone()
else:
def func(x):
return x
return self._fast_apply(func, call_on_nested=False)
return TensorDict(
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
device=self.device,
names=copy(self._td_dim_names),
_run_checks=False,
)

def contiguous(self) -> T:
if not self.is_contiguous():
Expand All @@ -1792,8 +1788,6 @@ def empty(self, recurse=False) -> T:
source={},
names=self._td_dim_names,
_run_checks=False,
_is_memmap=False,
_is_shared=False,
)
return super().empty(recurse=recurse)

Expand Down Expand Up @@ -1834,8 +1828,6 @@ def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -
# names=self.names if self._has_names() else None,
names=self._td_dim_names,
_run_checks=False,
_is_memmap=self._is_memmap,
_is_shared=self._is_shared,
)
if inplace:
self._tensordict = out._tensordict
Expand Down

0 comments on commit 7620392

Please sign in to comment.