diff --git a/tensordict/_td.py b/tensordict/_td.py index 4faafa031..aed415f36 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -182,6 +182,8 @@ class TensorDict(TensorDictBase): """ _td_dim_names = None + _is_shared = False + _is_memmap = False def __init__( self, @@ -190,11 +192,7 @@ def __init__( device: DeviceType | None = None, names: Sequence[str] | None = None, _run_checks: bool = True, - _is_shared: bool | None = False, - _is_memmap: bool | None = False, ) -> None: - self._is_shared = _is_shared - self._is_memmap = _is_memmap if device is not None and isinstance(device, (int, str)): device = torch.device(device) self._device = device @@ -210,8 +208,6 @@ def __init__( batch_size=self._batch_size, device=self._device, _run_checks=_run_checks, - _is_shared=_is_shared, - _is_memmap=_is_memmap, ) _tensordict[key] = value self._td_dim_names = names @@ -1771,13 +1767,13 @@ def is_contiguous(self) -> bool: return all([value.is_contiguous() for _, value in self.items()]) def clone(self, recurse: bool = True) -> T: - if recurse: - def func(x): - return x.clone() - else: - def func(x): - return x - return self._fast_apply(func, call_on_nested=False) + return TensorDict( + source={key: _clone_value(value, recurse) for key, value in self.items()}, + batch_size=self.batch_size, + device=self.device, + names=copy(self._td_dim_names), + _run_checks=False, + ) def contiguous(self) -> T: if not self.is_contiguous(): @@ -1792,8 +1788,6 @@ def empty(self, recurse=False) -> T: source={}, names=self._td_dim_names, _run_checks=False, - _is_memmap=False, - _is_shared=False, ) return super().empty(recurse=recurse) @@ -1834,8 +1828,6 @@ def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) - # names=self.names if self._has_names() else None, names=self._td_dim_names, _run_checks=False, - _is_memmap=self._is_memmap, - _is_shared=self._is_shared, ) if inplace: self._tensordict = out._tensordict