From c55e28afe4e00c8e3e8f32d92300c312622935dd Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 23 Nov 2023 08:57:31 +0000 Subject: [PATCH] amend --- tensordict/_td.py | 2 +- tensordict/base.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index c6b8e8653..5c34d8fbd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -272,7 +272,6 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) # is passed with to_module) but for the device it could be a problem. if swap_dest is None: swap = TensorDict({}, batch_size=[]) - swap.clear_device_() else: swap = swap_dest memo[id(module)] = swap @@ -324,6 +323,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) _swap[key] = local_out if return_swap: if isinstance(swap, TensorDict): + # this is very ad-hoc but faster than calling _set_str every time swap._tensordict.update(_swap) else: swap.update(_swap) diff --git a/tensordict/base.py b/tensordict/base.py index e54d77d22..e9f0b7b5a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3125,7 +3125,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.lock_() if last_op == self.__class__.to_module.__name__: if is_tensor_collection(out): - # with out.unlock_(): return self.to_module(*args, **kwargs, swap_dest=out) else: raise RuntimeError(