From 6a9f8e354848455054a74f61e74187bb0368677a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 11 Oct 2023 05:40:09 -0400 Subject: [PATCH] [BugFix] Better tensor allocation in memmap_like (#543) --- tensordict/tensordict.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 8d9c11552..d04256008 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -2546,22 +2546,34 @@ def memmap_like(self, prefix: str | None = None) -> T: if prefix is not None: # ensure subdirectory exists os.makedirs(prefix / key, exist_ok=True) - tensordict[key] = value.memmap_like( - prefix=prefix / key, + tensordict._set_str( + key, + value.memmap_like( + prefix=prefix / key, + ), + inplace=False, + validated=True, ) torch.save( {"batch_size": value.batch_size, "device": value.device}, prefix / key / "meta.pt", ) else: - tensordict[key] = value.memmap_like() + tensordict._set_str( + key, value.memmap_like(), inplace=False, validated=True + ) continue else: - tensordict[key] = MemmapTensor.empty_like( - value, - filename=str(prefix / f"{key}.memmap") - if prefix is not None - else None, + tensordict._set_str( + key, + MemmapTensor.empty_like( + value, + filename=str(prefix / f"{key}.memmap") + if prefix is not None + else None, + ), + inplace=False, + validated=True, ) if prefix is not None: torch.save( @@ -4806,7 +4818,7 @@ def to(tensor): apply_kwargs = {} if device is not None or dtype is not None: - apply_kwargs["device"] = device + apply_kwargs["device"] = device if device is not None else self.device apply_kwargs["batch_size"] = batch_size result = result._fast_apply(to, **apply_kwargs) elif batch_size is not None: