diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index df0e2150b..8615704a7 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -2587,6 +2587,7 @@ def memmap_like(self, prefix: str | None = None) -> T: prefix / f"{key}.meta.pt", ) tensordict._is_memmap = True + self._is_shared = False tensordict._device = torch.device("cpu") tensordict.lock_() return tensordict @@ -4423,8 +4424,9 @@ def _set_str( inplace: bool, validated: bool, ) -> T: - best_attempt = inplace is BEST_ATTEMPT_INPLACE - inplace = self._convert_inplace(inplace, key) + if inplace is not False: + best_attempt = inplace is BEST_ATTEMPT_INPLACE + inplace = self._convert_inplace(inplace, key) if not validated: value = self._validate_value(value, check_shape=True) if not inplace: @@ -4702,6 +4704,7 @@ def memmap_( prefix / f"{key}.meta.pt", ) self._is_memmap = True + self._is_shared = False self._device = torch.device("cpu") self.lock_() return self @@ -7707,6 +7710,7 @@ def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T: copy_existing=copy_existing, ) self._is_memmap = True + self._is_shared = False self._device = torch.device("cpu") self.lock_() return self @@ -7728,6 +7732,7 @@ def memmap_like( tds.append(td_like) td_out = torch.stack(tds, self.stack_dim) self._is_memmap = True + self._is_shared = False self._device = torch.device("cpu") td_out.lock_() return td_out @@ -8447,6 +8452,7 @@ def memmap_( torch.save(metadata, prefix / "meta.pt") self._is_memmap = True + self._is_shared = False self._device = torch.device("cpu") self.lock_() return self diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 669174a4e..aea83fa2b 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -31,6 +31,8 @@ def get_available_devices(): if n_cuda > 0: for i in range(n_cuda): devices += [torch.device(f"cuda:{i}")] + if i == 1: + break return devices diff --git a/test/test_memmap.py b/test/test_memmap.py index 081309aba..4cef90d71 100644 --- a/test/test_memmap.py +++ b/test/test_memmap.py @@ -73,17 +73,7 @@ def test_memmap_same_device_as_tensor(device): """ t = torch.tensor([1], device=device) m = MemoryMappedTensor.from_tensor(t) - assert m.device == torch.device(device) - for other_device in get_available_devices(): - if other_device != device: - with pytest.raises( - RuntimeError, - match="Expected all tensors to be on the same device, " - + "but found at least two devices", - ): - assert torch.all(m + torch.ones([3, 4], device=other_device) == 1) - m = m.to(other_device) - assert m.device == torch.device(other_device) + assert m.device == torch.device("cpu") @pytest.mark.parametrize("device", get_available_devices())