diff --git a/tensordict/_td.py b/tensordict/_td.py index 17cb8839e..3d62a19cd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1679,6 +1679,11 @@ def _populate( self._is_memmap = True self._is_shared = False # since they are mutually exclusive self._device = torch.device("cpu") + else: + dest._is_memmap = True + dest._is_shared = False # since they are mutually exclusive + dest._device = torch.device("cpu") + dest._is_locked = True return dest diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 1783999c1..764f59dea 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1735,6 +1735,7 @@ class MyClass: assert isinstance(cmemmap.x, MemoryMappedTensor) assert isinstance(cmemmap.y.x, MemoryMappedTensor) assert cmemmap.z == "foo" + assert cmemmap.is_memmap() class TestNesting: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 774e59f52..06567a137 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2857,6 +2857,7 @@ def test_memmap_like(self, td_name, device, use_dir, tmpdir, num_threads): for key in td.keys(True): assert td[key] is not tdmemmap[key] assert (tdmemmap == 0).all() + assert tdmemmap.is_memmap() def test_memmap_prefix(self, td_name, device, tmp_path): if td_name == "memmap_td": @@ -6167,6 +6168,7 @@ def test_memmap_like(self, tmpdir): ) tdm = td.memmap_like(prefix=tmpdir) assert tdm.names == ["a", "b", "c", "d"] + assert tdm.is_memmap() def test_memmap_td(self): td = self.memmap_td("cpu")