Skip to content

Commit

Permalink
[Refactor] Faster instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 3, 2023
1 parent 9865dec commit ec304fd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 47 deletions.
4 changes: 1 addition & 3 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ class PersistentTensorDict(TensorDictBase):
"""

def __new__(cls, *args, **kwargs):
cls._td_dim_names = None
return super().__new__(cls, *args, **kwargs)
_td_dim_names = None

def __init__(
self,
Expand Down
72 changes: 28 additions & 44 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,17 +354,14 @@ class TensorDictBase(MutableMapping):
)
KEY_ERROR = 'key "{}" not found in {} with ' "keys {}"

def __new__(cls, *args: Any, **kwargs: Any) -> T:
self = super().__new__(cls)
self._safe = kwargs.get("_safe", False)
self._lazy = kwargs.get("_lazy", False)
self._inplace_set = kwargs.get("_inplace_set", False)
self.is_meta = kwargs.get("is_meta", False)
self._is_locked = kwargs.get("_is_locked", False)
self._cache = None
self._last_op = None
self.__last_op_queue = None
return self
_safe = False
_lazy = False
_inplace_set = False
is_meta = False
_is_locked = False
_cache = None
_last_op = None
__last_op_queue = None

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
Expand Down Expand Up @@ -4077,26 +4074,11 @@ class TensorDict(TensorDictBase):
"""

__slots__ = (
"_tensordict",
"_batch_size",
"_is_shared",
"_is_memmap",
"_device",
"_is_locked",
"_td_dim_names",
"_lock_id",
"_locked_tensordicts",
"_cache",
"_last_op",
"__last_op_queue",
)

def __new__(cls, *args: Any, **kwargs: Any) -> TensorDict:
cls._is_shared = False
cls._is_memmap = False
cls._td_dim_names = None
return super().__new__(cls, *args, _safe=True, _lazy=False, **kwargs)
_is_shared = False
_is_memmap = False
_td_dim_names = None
_safe = True
_lazy = False

def __init__(
self,
Expand Down Expand Up @@ -5001,11 +4983,12 @@ def _nested_keys(
)

def __getstate__(self):
return {
slot: getattr(self, slot)
for slot in self.__slots__
if slot not in ("_last_op", "_cache", "__last_op_queue")
result = {
key: val
for key, val in self.__dict__.items()
if key not in ("_last_op", "_cache", "__last_op_queue")
}
return result

def __setstate__(self, state):
for slot, value in state.items():
Expand Down Expand Up @@ -5790,10 +5773,11 @@ class SubTensorDict(TensorDictBase):
"""

def __new__(cls, *args: Any, **kwargs: Any) -> SubTensorDict:
cls._is_shared = False
cls._is_memmap = False
return super().__new__(cls, _safe=False, _lazy=True, _inplace_set=True)
_is_shared = False
_is_memmap = False
_safe = False
_lazy = True
_inplace_set = True

def __init__(
self,
Expand Down Expand Up @@ -6414,9 +6398,9 @@ def __torch_function__(
else:
return super().__torch_function__(func, types, args, kwargs)

def __new__(cls, *args: Any, **kwargs: Any) -> LazyStackedTensorDict:
cls._td_dim_name = None
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
_td_dim_name = None
_safe = False
_lazy = True

def __init__(
self,
Expand Down Expand Up @@ -8162,8 +8146,8 @@ def _repr_exclusive_fields(self):
class _CustomOpTensorDict(TensorDictBase):
"""Encodes lazy operations on tensors contained in a TensorDict."""

def __new__(cls, *args: Any, **kwargs: Any) -> _CustomOpTensorDict:
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
_safe = False
_lazy = True

def __init__(
self,
Expand Down

0 comments on commit ec304fd

Please sign in to comment.