Skip to content

Commit

Permalink
[BugFix] Faster empty_like for MemoryMappedTensor (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 30, 2023
1 parent c80078c commit 795e39a
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,15 @@ def __init__(
__torch_function__ = torch._C._disabled_torch_function_impl

@classmethod
def from_tensor(cls, input, *, filename=None, existsok=False, copy_existing=False):
def from_tensor(
cls,
input,
*,
filename=None,
existsok=False,
copy_existing=False,
copy_data=True,
):
"""Creates a MemoryMappedTensor with the same content as another tensor.
If the tensor is already a MemoryMappedTensor the original tensor is
Expand All @@ -118,6 +126,8 @@ def from_tensor(cls, input, *, filename=None, existsok=False, copy_existing=Fals
the content to the new location is permitted. Otherwise an
exception is thown. This behaviour exists to prevent
unadvertedly duplicating data on disk.
copy_data (bool, optional): if ``True``, the content of the tensor
will be copied on the storage. Defaults to ``True``.
"""
if isinstance(input, MemoryMappedTensor):
Expand Down Expand Up @@ -205,6 +215,7 @@ def empty_like(cls, input, *, filename=None):
return cls.from_tensor(
torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
copy_data=False,
)

@classmethod
Expand All @@ -221,11 +232,10 @@ def full_like(cls, input, fill_value, *, filename=None):
is provided, a handler is used.
"""
return cls.from_tensor(
torch.zeros((), dtype=input.dtype, device=input.device)
.fill_(fill_value)
.expand_as(input),
torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
)
copy_data=False,
).fill_(fill_value)

@classmethod
def zeros_like(cls, input, *, filename=None):
Expand All @@ -242,7 +252,8 @@ def zeros_like(cls, input, *, filename=None):
return cls.from_tensor(
torch.zeros((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
)
copy_data=False,
).fill_(0.0)

@classmethod
def ones_like(cls, input, *, filename=None):
Expand All @@ -259,7 +270,8 @@ def ones_like(cls, input, *, filename=None):
return cls.from_tensor(
torch.ones((), dtype=input.dtype, device=input.device).expand_as(input),
filename=filename,
)
copy_data=False,
).fill_(1.0)

@classmethod
@overload
Expand Down Expand Up @@ -537,7 +549,7 @@ def __getitem__(self, item):
"isn't supported at the moment."
) from err
raise
if out.storage().data_ptr() == self.storage().data_ptr():
if out.untyped_storage().data_ptr() == self.untyped_storage().data_ptr():
out = MemoryMappedTensor(out)
out._handler = self._handler
out._filename = self._filename
Expand Down

0 comments on commit 795e39a

Please sign in to comment.