From 3d718a0655a042b47e40706e3072321cb8b2996e Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Nov 2023 12:11:56 +0000
Subject: [PATCH 1/5] init

---
 tensordict/_td.py   | 30 ++++++++++++++++++------------
 tensordict/utils.py | 29 +++++++++++++++++++----------
 2 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index 9ddeae2a5..dc1994824 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -259,20 +259,31 @@ def from_module(
         td_struct.lock_()
         return td_struct
 
+    def is_empty(self):
+        for _ in self._tensordict:
+            return False
+        return True
+
     @as_decorator()
     def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None):
         # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
         __dict__ = module.__dict__
+        swap = None
         has_set_device = False
         if memo is None:
-            memo = {}
+            hooks = getattr(
+                torch.nn.modules.module, "_global_parameter_registration_hooks", {}
+            )
+            memo = {"hooks": tuple(hooks.values())}
+        else:
+            hooks = memo["hooks"]
         if return_swap:
             # this could break if the device and batch-size are not congruent.
             # For batch-size it is a minor issue (unlikely that a td with batch-size
             # is passed with to_module) but for the device it could be a problem.
             if swap_dest is None:
-                swap = TensorDict({}, batch_size=[])
+                swap = TensorDict({}, batch_size=torch.Size(()), _run_checks=False)
             else:
                 swap = swap_dest
             memo[id(module)] = swap
 
@@ -282,18 +293,16 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
             if isinstance(value, (Tensor, ftdim.Tensor)):
                 if module.__class__.__setattr__ is __base__setattr__:
                     # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict
-                    local_out = _set_tensor_dict(__dict__, module, key, value)
+                    local_out = _set_tensor_dict(__dict__, hooks, module, key, value)
                 else:
                     if return_swap:
                         local_out = getattr(module, key)
                     # use specialized __setattr__ if needed
                     setattr(module, key, value)
             else:
-                for _ in value.keys():
+                if value.is_empty():
                     # if there is at least one key, we must populate the module.
                     # Otherwise we just go to the next key
-                    break
-                else:
                     continue
                 if swap_dest is not None:
                     local_dest = swap_dest._get_str(key, default=NO_DEFAULT)
@@ -2571,7 +2580,7 @@ def __repr__(self):
 
 
 def _set_tensor_dict(  # noqa: F811
-    module_dict, module, name: str, tensor: torch.Tensor
+    module_dict, hooks, module, name: str, tensor: torch.Tensor
 ) -> None:
     """Simplified version of torch.nn.utils._named_member_accessor."""
     was_buffer = False
@@ -2580,13 +2589,10 @@ def _set_tensor_dict(  # noqa: F811
         out = module_dict["_buffers"].pop(name, None)
         was_buffer = out is not None
     if out is None:
-        out = module_dict.pop(name, None)
+        out = module_dict.pop(name)
 
     if isinstance(tensor, torch.nn.Parameter):
-        # module.register_parameter(name, tensor)
-        for hook in getattr(
-            torch.nn.modules.module, "_global_parameter_registration_hooks", {}
-        ).values():
+        for hook in hooks:
             output = hook(module, name, tensor)
             if output is not None:
                 tensor = output
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 6015c8f79..b9e5b4b95 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1116,19 +1116,28 @@ def __init__(self, attr=None):
         self.attr = attr
 
     def __call__(self, func):
-        @wraps(func)
-        def new_func(_self, *args, **kwargs):
-            if self.attr is not None:
+        if self.attr is not None:
+
+            @wraps(func)
+            def new_func(_self, *args, **kwargs):
                 _attr_pre = getattr(_self, self.attr)
-            out = func(_self, *args, **kwargs)
-            if self.attr is not None:
+                out = func(_self, *args, **kwargs)
                 _attr_post = getattr(_self, self.attr)
-            if out is not None:
-                if self.attr is None or (_attr_post is not _attr_pre):
+                if out is not None:
+                    if _attr_post is not _attr_pre:
+                        out._last_op = (new_func.__name__, (args, kwargs, _self))
+                    else:
+                        out._last_op = None
+                return out
+
+        else:
+
+            @wraps(func)
+            def new_func(_self, *args, **kwargs):
+                out = func(_self, *args, **kwargs)
+                if out is not None:
                     out._last_op = (new_func.__name__, (args, kwargs, _self))
-                else:
-                    out._last_op = None
-            return out
+                return out
 
         return new_func
 
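Note on PATCH 1/5: `to_module` now resolves torch's global parameter-registration hooks once per top-level call and threads the tuple through `memo`, instead of re-reading the registry inside `_set_tensor_dict` for every key. A minimal standalone sketch of that caching pattern, assuming nothing about tensordict internals (`install` and `_swap_one` are illustrative names, not tensordict APIs):

    import torch
    import torch.nn as nn

    def install(module: nn.Module, values: dict, hooks=None):
        if hooks is None:
            # Resolved once at the top-level call; getattr with a default
            # guards against torch versions that predate the hook registry.
            registry = getattr(
                torch.nn.modules.module, "_global_parameter_registration_hooks", {}
            )
            hooks = tuple(registry.values())
        for name, value in values.items():
            if isinstance(value, dict):
                # Recurse into submodules, reusing the cached tuple.
                install(getattr(module, name), value, hooks=hooks)
            else:
                _swap_one(module, name, value, hooks)

    def _swap_one(module: nn.Module, name: str, tensor, hooks):
        if isinstance(tensor, nn.Parameter):
            # Same contract as nn.Module.register_parameter: a hook may
            # return a replacement for the parameter being registered.
            for hook in hooks:
                out = hook(module, name, tensor)
                if out is not None:
                    tensor = out
            module._parameters[name] = tensor
        else:
            module._buffers[name] = tensor

The `as_decorator` rewrite in tensordict/utils.py applies the same idea at decoration time: the `self.attr is not None` branch is decided once when the function is wrapped, so the per-call wrapper no longer re-checks it.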
From c5f29d362aea8ce4a77e9683c7c75decd30b8e0f Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Nov 2023 13:44:59 +0000
Subject: [PATCH 2/5] init

---
 tensordict/_td.py | 53 ++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 17 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index dc1994824..a0cb49f2a 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -196,33 +196,29 @@ def __init__(
         _is_shared: bool | None = False,
         _is_memmap: bool | None = False,
     ) -> None:
-        self._lock_id = set()
-        self._locked_tensordicts = []
-
         self._is_shared = _is_shared
         self._is_memmap = _is_memmap
         if device is not None and isinstance(device, (int, str)):
             device = torch.device(device)
         self._device = device
+        self._tensordict = _tensordict = _StringOnlyDict()
 
         if not _run_checks:
-            _tensordict: dict = _StringOnlyDict()
             self._batch_size = batch_size
-            for key, value in source.items():
-                if isinstance(value, dict):
-                    value = TensorDict(
-                        value,
-                        batch_size=self._batch_size,
-                        device=self._device,
-                        _run_checks=_run_checks,
-                        _is_shared=_is_shared,
-                        _is_memmap=_is_memmap,
-                    )
-                _tensordict[key] = value
-            self._tensordict = _tensordict
+            if source:  # faster than calling items
+                for key, value in source.items():
+                    if isinstance(value, dict):
+                        value = TensorDict(
+                            value,
+                            batch_size=self._batch_size,
+                            device=self._device,
+                            _run_checks=_run_checks,
+                            _is_shared=_is_shared,
+                            _is_memmap=_is_memmap,
+                        )
+                    _tensordict[key] = value
             self._td_dim_names = names
         else:
-            self._tensordict = _StringOnlyDict()
             if not isinstance(source, (TensorDictBase, dict)):
                 raise ValueError(
                     "A TensorDict source is expected to be a TensorDictBase "
@@ -235,6 +231,28 @@ def __init__(
             for key, value in source.items():
                 self.set(key, value)
 
+    @property
+    def _lock_id(self):
+        _lock_id = self.__dict__.get("__lock_id", None)
+        if _lock_id is None:
+            _lock_id = self.__lock_id = set()
+        return _lock_id
+
+    @_lock_id.setter
+    def _lock_id(self, value):
+        self.__lock_id = value
+
+    @property
+    def _locked_tensordicts(self):
+        _locked_tensordicts = self.__dict__.get("__locked_tensordicts", None)
+        if _locked_tensordicts is None:
+            _locked_tensordicts = self.__locked_tensordicts = set()
+        return _locked_tensordicts
+
+    @_locked_tensordicts.setter
+    def _locked_tensordicts(self, value):
+        self.__locked_tensordicts = value
+
     @staticmethod
     def from_module(
         module: torch.nn.Module, as_module: bool = False, lock: bool = False
@@ -278,6 +296,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
             memo = {"hooks": tuple(hooks.values())}
         else:
             hooks = memo["hooks"]
+
         if return_swap:
             # this could break if the device and batch-size are not congruent.
             # For batch-size it is a minor issue (unlikely that a td with batch-size
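Note on PATCH 2/5: two eager allocations (a `set()` and a `list()`) leave `__init__` and become lazily-initialized properties, and the copy loop is skipped outright when the source mapping is empty. The lazy-attribute pattern reduced to its essentials (`EagerNode` and `LazyNode` are illustrative stand-ins, written here with the explicit `__dict__` writes that PATCH 4/5 below settles on):

    class EagerNode:
        def __init__(self):
            self._lock_id = set()            # paid on every construction,
            self._locked_tensordicts = []    # even if locking is never used

    class LazyNode:
        @property
        def _lock_id(self):
            # Materialize the container on first access and cache it in the
            # instance __dict__; later accesses return the cached object.
            cached = self.__dict__.get("__lock_id", None)
            if cached is None:
                cached = self.__dict__["__lock_id"] = set()
            return cached

        @_lock_id.setter
        def _lock_id(self, value):
            self.__dict__["__lock_id"] = value

Instances that never lock anything now pay nothing for this bookkeeping at construction time, which is the hot path this series targets.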
From b7078ccbc003a513843a5632a7dfc52799e07421 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Nov 2023 13:51:54 +0000
Subject: [PATCH 3/5] amend

---
 tensordict/_td.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index a334a0bf5..b999c38fe 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -246,7 +246,7 @@ def _lock_id(self, value):
     def _locked_tensordicts(self):
         _locked_tensordicts = self.__dict__.get("__locked_tensordicts", None)
         if _locked_tensordicts is None:
-            _locked_tensordicts = self.__locked_tensordicts = set()
+            _locked_tensordicts = self.__locked_tensordicts = []
         return _locked_tensordicts
 
     @_locked_tensordicts.setter

From 4a897e9564fe8a1c59e2b9a8303def1748be1e45 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Nov 2023 14:17:42 +0000
Subject: [PATCH 4/5] amend

---
 tensordict/_td.py  | 22 ----------------------
 tensordict/base.py | 22 ++++++++++++++++++++++
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index b999c38fe..4b5ff73c5 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -231,28 +231,6 @@ def __init__(
             for key, value in source.items():
                 self.set(key, value)
 
-    @property
-    def _lock_id(self):
-        _lock_id = self.__dict__.get("__lock_id", None)
-        if _lock_id is None:
-            _lock_id = self.__lock_id = set()
-        return _lock_id
-
-    @_lock_id.setter
-    def _lock_id(self, value):
-        self.__lock_id = value
-
-    @property
-    def _locked_tensordicts(self):
-        _locked_tensordicts = self.__dict__.get("__locked_tensordicts", None)
-        if _locked_tensordicts is None:
-            _locked_tensordicts = self.__locked_tensordicts = []
-        return _locked_tensordicts
-
-    @_locked_tensordicts.setter
-    def _locked_tensordicts(self, value):
-        self.__locked_tensordicts = value
-
     @staticmethod
     def from_module(
         module: torch.nn.Module, as_module: bool = False, lock: bool = False
diff --git a/tensordict/base.py b/tensordict/base.py
index 5a467ec60..00646e236 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3758,6 +3758,28 @@ def _propagate_lock(self, lock_ids=None):
         else:
             self._locked_tensordicts += _locked_tensordicts
 
+    @property
+    def _lock_id(self):
+        _lock_id = self.__dict__.get("__lock_id", None)
+        if _lock_id is None:
+            _lock_id = self.__dict__["__lock_id"] = set()
+        return _lock_id
+
+    @_lock_id.setter
+    def _lock_id(self, value):
+        self.__dict__["__lock_id"] = value
+
+    @property
+    def _locked_tensordicts(self):
+        _locked_tensordicts = self.__dict__.get("__locked_tensordicts", None)
+        if _locked_tensordicts is None:
+            _locked_tensordicts = self.__dict__["__locked_tensordicts"] = []
+        return _locked_tensordicts
+
+    @_locked_tensordicts.setter
+    def _locked_tensordicts(self, value):
+        self.__dict__["__locked_tensordicts"] = value
+
     @as_decorator("is_locked")
     def lock_(self) -> T:
         if self.is_locked:
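Two notes on the pair of amendments above. PATCH 3/5 restores `_locked_tensordicts` to a list: `_propagate_lock` grows it with `+=` and the tests call `len()` on it, and sets do not support `+=`. PATCH 4/5 moves the properties onto `TensorDictBase` so every subclass shares them, and switches from the `self.__lock_id = ...` spelling to explicit `self.__dict__[...]` writes. That switch is not cosmetic: inside a class body, `self.__lock_id` is name-mangled to `self._<ClassName>__lock_id`, so the PATCH 2/5 version stored under a different key than the literal `"__lock_id"` string it looked up, and the getter re-allocated a fresh container on every access. A self-contained demonstration (class names are illustrative):

    class Mangled:
        @property
        def cache(self):
            val = self.__dict__.get("__cache", None)  # looks up literal "__cache"
            if val is None:
                val = self.__cache = set()            # but stores "_Mangled__cache"
            return val

    class Fixed:
        @property
        def cache(self):
            val = self.__dict__.get("__cache", None)
            if val is None:
                # String keys are never mangled, so reads and writes agree.
                val = self.__dict__["__cache"] = set()
            return val

    m, f = Mangled(), Fixed()
    assert m.cache is not m.cache  # lookup never hits: a fresh set per access
    assert f.cache is f.cache      # cached after the first access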
self.__dict__["__lock_id"] = set() + return _lock_id + + @_lock_id.setter + def _lock_id(self, value): + self.__dict__["__lock_id"] = value + + @property + def _locked_tensordicts(self): + _locked_tensordicts = self.__dict__.get("__locked_tensordicts", None) + if _locked_tensordicts is None: + _locked_tensordicts = self.__dict__["__locked_tensordicts"] = [] + return _locked_tensordicts + + @_locked_tensordicts.setter + def _locked_tensordicts(self, value): + self.__dict__["__locked_tensordicts"] = value + @as_decorator("is_locked") def lock_(self) -> T: if self.is_locked: From 9c99c0db309cdfee4eb6bf40b43e3958cc0b2d6c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 24 Nov 2023 14:42:24 +0000 Subject: [PATCH 5/5] amend --- test/test_tensordict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index c0893ca9d..4ee3298f1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6123,10 +6123,10 @@ def test_empty_tensordict_list_stack(self): a0 = td0["a"] b0 = td0["a", "b"] c0 = td0["a", "b", "c"] - assert not hasattr(td, "_locked_tensordicts") - assert not hasattr(a, "_locked_tensordicts") - assert not hasattr(b, "_locked_tensordicts") - assert not hasattr(c, "_locked_tensordicts") + assert not td._locked_tensordicts + assert not a._locked_tensordicts + assert not b._locked_tensordicts + assert not c._locked_tensordicts assert len(a0._locked_tensordicts) assert len(b0._locked_tensordicts) td.unlock_()