Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Dec 8, 2023
1 parent d347869 commit adddbfb
Showing 1 changed file with 22 additions and 51 deletions.
73 changes: 22 additions & 51 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch.utils._pytree import tree_map
import weakref


# NO_DEFAULT is used as a placeholder whenever the default is not provided.
Expand Down Expand Up @@ -142,10 +143,6 @@ def __eq__(self, other: object) -> T:
"""
...

def __del__(self):
    # Finalizer: when this tensordict is collected, release the lock it
    # placed on every sub-tensordict. ``getattr`` with an empty-tuple
    # default keeps this safe when the attribute was never created.
    my_id = id(self)
    for child in getattr(self, "_locked_tensordicts", ()):
        child._remove_lock(my_id)

def __repr__(self) -> str:
fields = _td_fields(self)
field_str = indent(f"fields={{{fields}}}", 4 * " ")
Expand Down Expand Up @@ -3950,84 +3947,58 @@ def _propagate_lock(self, lock_ids=None):
self._is_locked = True
is_root = lock_ids is None
if is_root:
lock_ids = set()
self._lock_id = self._lock_id.union(lock_ids)
lock_ids = lock_ids.union({id(self)})
_locked_tensordicts = []
lock_ids = []
self._lock_id += lock_ids
lock_ids.append(weakref.ref(self))
for value in self.values():
if _is_tensor_collection(type(value)):
value._propagate_lock(lock_ids)
_locked_tensordicts.append(value)
if is_root:
self._locked_tensordicts = _locked_tensordicts
else:
self._locked_tensordicts += _locked_tensordicts

@property
def _lock_id(self):
    # Lazily-created container of lock identifiers for this tensordict,
    # cached directly in the instance __dict__ under "__lock_id".
    # NOTE(review): the original assigned ``set()`` and then immediately
    # overwrote it with ``[]`` (a dead store left over from an edit);
    # only the surviving list assignment is kept — behavior is unchanged.
    _lock_id = self.__dict__.get("__lock_id", None)
    if _lock_id is None:
        _lock_id = self.__dict__["__lock_id"] = []
    return _lock_id

@_lock_id.setter
def _lock_id(self, value: list):
    # Store the lock-id container directly in the instance __dict__ so the
    # lazy getter finds it on the next access.
    # NOTE(review): the original span carried two consecutive ``def`` lines
    # (an unannotated and an annotated duplicate), which is a syntax error;
    # only the single annotated signature is kept.
    self.__dict__["__lock_id"] = value

@property
def _locked_tensordicts(self):
    # Lazily-created list of sub-tensordicts locked by this node, cached
    # in the instance __dict__ under "__locked_tensordicts".
    cached = self.__dict__.get("__locked_tensordicts")
    if cached is None:
        cached = []
        self.__dict__["__locked_tensordicts"] = cached
    return cached

@_locked_tensordicts.setter
def _locked_tensordicts(self, value):
    # Overwrite the cached list of sub-tensordicts locked by this node;
    # stored in the instance __dict__ so the lazy getter sees it directly.
    self.__dict__["__locked_tensordicts"] = value

@as_decorator("is_locked")
def lock_(self) -> T:
    """Lock the tensordict in-place and return it.

    A no-op when the tensordict is already locked; otherwise the lock
    is propagated through the structure via ``_propagate_lock``.
    """
    if not self.is_locked:
        self._propagate_lock()
    return self

def _remove_lock(self, lock_id):
    # Discard *lock_id* from this node's lock-id set, then recurse into
    # every sub-tensordict this node had locked.
    self._lock_id.discard(lock_id)
    children = self._locked_tensordicts
    if children:
        for child in children:
            child._remove_lock(lock_id)

@erase_cache
def _propagate_unlock(self, lock_ids=None):
if lock_ids is not None:
self._lock_id.difference_update(lock_ids)
else:
lock_ids = set()
def _propagate_unlock(self):
for ref in self._lock_id:
obj = ref()
if obj is not None and obj.is_locked:
raise RuntimeError(
"Cannot unlock a tensordict that is part of a locked graph. "
"Unlock the root tensordict first. If the tensordict is part of multiple graphs, "
"group the graphs under a common tensordict an unlock this root. "
)
self._lock_id = []
self._is_locked = False

unlocked_tds = [self]
lock_ids.add(id(self))
for value in self.values():
if _is_tensor_collection(type(value)):
unlocked_tds.extend(value._propagate_unlock(lock_ids))
self._locked_tensordicts = []
value._propagate_unlock()

self._is_shared = False
self._is_memmap = False
return unlocked_tds

@as_decorator("is_locked")
def unlock_(self) -> T:
unlock_tds = self._propagate_unlock()
for td in unlock_tds:
if len(td._lock_id):
self.lock_()
raise RuntimeError(
"Cannot unlock a tensordict that is part of a locked graph. "
"Unlock the root tensordict first. If the tensordict is part of multiple graphs, "
"group the graphs under a common tensordict an unlock this root. "
)
try:
self._propagate_unlock()
except RuntimeError as err:
self.lock_()
raise err
return self

# Conversion (device or dtype)
Expand Down

0 comments on commit adddbfb

Please sign in to comment.