From c8132be4c8e0d432ee2f6bf19f5aaf65701a8c20 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 18 Oct 2023 11:59:50 -0700 Subject: [PATCH] amend --- tensordict/memmap_refact.py | 7 +++++++ test/test_tensordict.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensordict/memmap_refact.py b/tensordict/memmap_refact.py index f42833483..6499c073a 100644 --- a/tensordict/memmap_refact.py +++ b/tensordict/memmap_refact.py @@ -112,6 +112,13 @@ def from_tensor( out.copy_(tensor) return out + @property + def filename(self): + filename = self._filename + if filename is None: + raise RuntimeError("The MemoryMappedTensor has no file associated.") + return filename + @classmethod def empty_like(cls, tensor, *, filename=None): return cls.from_tensor(torch.zeros((), dtype=tensor.dtype, device=tensor.device).expand_as(tensor), filename=filename) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 071c225ab..779707e8a 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3038,7 +3038,7 @@ def test_repr_memmap(self, device, dtype): is_shared_tensor = False expected = f"""TensorDict( fields={{ - a: MemmapTensor(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, + a: MemoryMappedTensor(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, batch_size=torch.Size([4, 3, 2, 1]), device={str(device)}, is_shared={is_shared})"""