From e66935aac6efe77a20504442dc83d4217067eb97 Mon Sep 17 00:00:00 2001
From: vmoens <vincentmoens@gmail.com>
Date: Wed, 11 Oct 2023 10:13:10 +0100
Subject: [PATCH 1/2] init

---
 tensordict/tensordict.py | 28 ++++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 8d9c11552..4ea5beb66 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -2546,22 +2546,34 @@ def memmap_like(self, prefix: str | None = None) -> T:
                 if prefix is not None:
                     # ensure subdirectory exists
                     os.makedirs(prefix / key, exist_ok=True)
-                    tensordict[key] = value.memmap_like(
-                        prefix=prefix / key,
+                    tensordict._set_str(
+                        key,
+                        value.memmap_like(
+                            prefix=prefix / key,
+                        ),
+                        inplace=False,
+                        validated=True,
                     )
                     torch.save(
                         {"batch_size": value.batch_size, "device": value.device},
                         prefix / key / "meta.pt",
                     )
                 else:
-                    tensordict[key] = value.memmap_like()
+                    tensordict._set_str(
+                        key, value.memmap_like(), inplace=False, validated=True
+                    )
                 continue
             else:
-                tensordict[key] = MemmapTensor.empty_like(
-                    value,
-                    filename=str(prefix / f"{key}.memmap")
-                    if prefix is not None
-                    else None,
+                tensordict._set_str(
+                    key,
+                    MemmapTensor.empty_like(
+                        value,
+                        filename=str(prefix / f"{key}.memmap")
+                        if prefix is not None
+                        else None,
+                    ),
+                    inplace=False,
+                    validated=True,
                 )
             if prefix is not None:
                 torch.save(

From 3ffdd866f75a370b71b2c66f1c184c8fe654c644 Mon Sep 17 00:00:00 2001
From: vmoens <vincentmoens@gmail.com>
Date: Wed, 11 Oct 2023 10:29:37 +0100
Subject: [PATCH 2/2] amend

---
 tensordict/tensordict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 4ea5beb66..d04256008 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -4818,7 +4818,7 @@ def to(tensor):
 
         apply_kwargs = {}
         if device is not None or dtype is not None:
-            apply_kwargs["device"] = device
+            apply_kwargs["device"] = device if device is not None else self.device
             apply_kwargs["batch_size"] = batch_size
             result = result._fast_apply(to, **apply_kwargs)
         elif batch_size is not None:
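
Below is a minimal usage sketch (not part of the diff) of the two code paths this series touches, assuming the tensordict 0.2-era API; the key names and the /tmp/td_demo prefix are illustrative only. Patch 1 routes the entries created by memmap_like() through the internal _set_str(..., inplace=False, validated=True) fast path instead of __setitem__, and patch 2 keeps self.device when to() is called with a dtype but no device.

    import torch
    from tensordict import TensorDict

    # Build a small TensorDict with a nested entry, then create an empty
    # memory-mapped copy of it on disk under the given prefix.
    td = TensorDict(
        {"obs": torch.zeros(3, 4), "nested": {"reward": torch.zeros(3, 1)}},
        batch_size=[3],
        device="cpu",
    )
    mm = td.memmap_like(prefix="/tmp/td_demo")  # illustrative path
    assert mm.is_memmap()

    # Patch 2: a dtype-only conversion no longer passes device=None down
    # to _fast_apply, so the original device is preserved.
    td_double = td.to(torch.double)
    assert td_double.device == torch.device("cpu")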