From 795e39a5352a12e2aaf035c30cf06190948b80a1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 30 Nov 2023 10:12:14 +0000 Subject: [PATCH] [BugFix] Faster empty_like for MemoryMappedTensor (#585) --- tensordict/memmap.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tensordict/memmap.py b/tensordict/memmap.py index e357c9da4..65a21ff4c 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -98,7 +98,15 @@ def __init__( __torch_function__ = torch._C._disabled_torch_function_impl @classmethod - def from_tensor(cls, input, *, filename=None, existsok=False, copy_existing=False): + def from_tensor( + cls, + input, + *, + filename=None, + existsok=False, + copy_existing=False, + copy_data=True, + ): """Creates a MemoryMappedTensor with the same content as another tensor. If the tensor is already a MemoryMappedTensor the original tensor is @@ -118,6 +126,8 @@ def from_tensor(cls, input, *, filename=None, existsok=False, copy_existing=Fals the content to the new location is permitted. Otherwise an exception is thown. This behaviour exists to prevent unadvertedly duplicating data on disk. + copy_data (bool, optional): if ``True``, the content of the tensor + will be copied on the storage. Defaults to ``True``. """ if isinstance(input, MemoryMappedTensor): @@ -205,6 +215,7 @@ def empty_like(cls, input, *, filename=None): return cls.from_tensor( torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input), filename=filename, + copy_data=False, ) @classmethod @@ -221,11 +232,10 @@ def full_like(cls, input, fill_value, *, filename=None): is provided, a handler is used. """ return cls.from_tensor( - torch.zeros((), dtype=input.dtype, device=input.device) - .fill_(fill_value) - .expand_as(input), + torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input), filename=filename, - ) + copy_data=False, + ).fill_(fill_value) @classmethod def zeros_like(cls, input, *, filename=None): @@ -242,7 +252,8 @@ def zeros_like(cls, input, *, filename=None): return cls.from_tensor( torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input), filename=filename, - ) + copy_data=False, + ).fill_(0.0) @classmethod def ones_like(cls, input, *, filename=None): @@ -259,7 +270,8 @@ def ones_like(cls, input, *, filename=None): return cls.from_tensor( torch.ones((), dtype=input.dtype, device=input.device).expand_as(input), filename=filename, - ) + copy_data=False, + ).fill_(1.0) @classmethod @overload @@ -537,7 +549,7 @@ def __getitem__(self, item): "isn't supported at the moment." ) from err raise - if out.storage().data_ptr() == self.storage().data_ptr(): + if out.untyped_storage().data_ptr() == self.untyped_storage().data_ptr(): out = MemoryMappedTensor(out) out._handler = self._handler out._filename = self._filename