diff --git a/tensordict/_td.py b/tensordict/_td.py index dc1994824..4b5ff73c5 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -196,33 +196,29 @@ def __init__( _is_shared: bool | None = False, _is_memmap: bool | None = False, ) -> None: - self._lock_id = set() - self._locked_tensordicts = [] - self._is_shared = _is_shared self._is_memmap = _is_memmap if device is not None and isinstance(device, (int, str)): device = torch.device(device) self._device = device + self._tensordict = _tensordict = _StringOnlyDict() if not _run_checks: - _tensordict: dict = _StringOnlyDict() self._batch_size = batch_size - for key, value in source.items(): - if isinstance(value, dict): - value = TensorDict( - value, - batch_size=self._batch_size, - device=self._device, - _run_checks=_run_checks, - _is_shared=_is_shared, - _is_memmap=_is_memmap, - ) - _tensordict[key] = value - self._tensordict = _tensordict + if source: # faster than calling items + for key, value in source.items(): + if isinstance(value, dict): + value = TensorDict( + value, + batch_size=self._batch_size, + device=self._device, + _run_checks=_run_checks, + _is_shared=_is_shared, + _is_memmap=_is_memmap, + ) + _tensordict[key] = value self._td_dim_names = names else: - self._tensordict = _StringOnlyDict() if not isinstance(source, (TensorDictBase, dict)): raise ValueError( "A TensorDict source is expected to be a TensorDictBase " diff --git a/tensordict/base.py b/tensordict/base.py index 5a467ec60..00646e236 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3758,6 +3758,28 @@ def _propagate_lock(self, lock_ids=None): else: self._locked_tensordicts += _locked_tensordicts + @property + def _lock_id(self): + _lock_id = self.__dict__.get("__lock_id", None) + if _lock_id is None: + _lock_id = self.__dict__["__lock_id"] = set() + return _lock_id + + @_lock_id.setter + def _lock_id(self, value): + self.__dict__["__lock_id"] = value + + @property + def _locked_tensordicts(self): + _locked_tensordicts = self.__dict__.get("__locked_tensordicts", None) + if _locked_tensordicts is None: + _locked_tensordicts = self.__dict__["__locked_tensordicts"] = [] + return _locked_tensordicts + + @_locked_tensordicts.setter + def _locked_tensordicts(self, value): + self.__dict__["__locked_tensordicts"] = value + @as_decorator("is_locked") def lock_(self) -> T: if self.is_locked: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index c0893ca9d..4ee3298f1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6123,10 +6123,10 @@ def test_empty_tensordict_list_stack(self): a0 = td0["a"] b0 = td0["a", "b"] c0 = td0["a", "b", "c"] - assert not hasattr(td, "_locked_tensordicts") - assert not hasattr(a, "_locked_tensordicts") - assert not hasattr(b, "_locked_tensordicts") - assert not hasattr(c, "_locked_tensordicts") + assert not td._locked_tensordicts + assert not a._locked_tensordicts + assert not b._locked_tensordicts + assert not c._locked_tensordicts assert len(a0._locked_tensordicts) assert len(b0._locked_tensordicts) td.unlock_()