Commit

amend
vmoens committed Nov 14, 2023
1 parent 7413381 commit 029de18
Showing 3 changed files with 11 additions and 13 deletions.
tensordict/tensordict.py: 10 changes (8 additions, 2 deletions)
@@ -2587,6 +2587,7 @@ def memmap_like(self, prefix: str | None = None) -> T:
                     prefix / f"{key}.meta.pt",
                 )
         tensordict._is_memmap = True
+        self._is_shared = False
         tensordict._device = torch.device("cpu")
         tensordict.lock_()
         return tensordict
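The line added above clears the shared-memory flag alongside the memmap bookkeeping, so the memory-mapped and shared-memory states stay mutually exclusive for tensordicts produced by memmap_like. A minimal sketch of the intended post-condition, assuming the public is_memmap()/is_shared() accessors reflect these private flags (the prefix path is illustrative):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
mm = td.memmap_like(prefix="/tmp/td_like")  # new tensordict backed by memory-mapped files
assert mm.is_memmap() and not mm.is_shared()
assert mm.device == torch.device("cpu")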
@@ -4423,8 +4424,9 @@ def _set_str(
         inplace: bool,
         validated: bool,
     ) -> T:
-        best_attempt = inplace is BEST_ATTEMPT_INPLACE
-        inplace = self._convert_inplace(inplace, key)
+        if inplace is not False:
+            best_attempt = inplace is BEST_ATTEMPT_INPLACE
+            inplace = self._convert_inplace(inplace, key)
         if not validated:
             value = self._validate_value(value, check_shape=True)
         if not inplace:
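With this guard, _convert_inplace (and the best_attempt bookkeeping) only runs when the caller has not already ruled out an in-place write; when inplace is False the key-existence lookup is skipped entirely. A rough, hypothetical sketch of the resolution logic, where BEST_ATTEMPT_INPLACE is reproduced from the diff and the helper below is a stand-in rather than the library's _convert_inplace:

BEST_ATTEMPT_INPLACE = object()  # sentinel: write in place only if the key already exists

def resolve_inplace(inplace, key_exists: bool):
    # Hypothetical stand-in for self._convert_inplace(inplace, key).
    if inplace is False:
        return False  # the new guard: nothing to resolve, no key lookup needed
    if inplace is BEST_ATTEMPT_INPLACE:
        return key_exists  # best effort: in place only when the entry is present
    if inplace and not key_exists:
        raise KeyError("cannot set in-place a key that does not exist")
    return bool(inplace)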
@@ -4702,6 +4704,7 @@ def memmap_(
                     prefix / f"{key}.meta.pt",
                 )
         self._is_memmap = True
+        self._is_shared = False
         self._device = torch.device("cpu")
         self.lock_()
         return self
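The same fix for the in-place memmap_ path, and it is repeated for the other memmap_/memmap_like implementations in the hunks below: once the tensordict is converted to memory-mapped storage it should no longer advertise shared memory. A short sketch, assuming is_memmap(), is_shared() and is_locked expose the flags set here, with an illustrative prefix path:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
td.memmap_(prefix="/tmp/td_memmap")  # converts the tensors in place to memory-mapped files
assert td.is_memmap() and not td.is_shared()
assert td.device == torch.device("cpu") and td.is_locked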
@@ -7707,6 +7710,7 @@ def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T:
                 copy_existing=copy_existing,
             )
         self._is_memmap = True
+        self._is_shared = False
         self._device = torch.device("cpu")
         self.lock_()
         return self
@@ -7728,6 +7732,7 @@ def memmap_like(
             tds.append(td_like)
         td_out = torch.stack(tds, self.stack_dim)
         self._is_memmap = True
+        self._is_shared = False
         self._device = torch.device("cpu")
         td_out.lock_()
         return td_out
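The two hunks above apply the same flag reset to the lazily stacked variant. A sketch of the stacked case, assuming torch.stack over tensordicts yields the lazy-stacked class whose memmap_like is patched here (prefix path illustrative):

import torch
from tensordict import TensorDict

tds = torch.stack([TensorDict({"x": torch.ones(3)}, batch_size=[]) for _ in range(2)], 0)
mm = tds.memmap_like(prefix="/tmp/stacked_td")  # stacked tensordict written to memory-mapped files
assert mm.is_memmap() and not mm.is_shared()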
@@ -8447,6 +8452,7 @@ def memmap_(
             torch.save(metadata, prefix / "meta.pt")

         self._is_memmap = True
+        self._is_shared = False
         self._device = torch.device("cpu")
         self.lock_()
         return self
test/_utils_internal.py: 2 changes (2 additions, 0 deletions)
@@ -31,6 +31,8 @@ def get_available_devices():
     if n_cuda > 0:
         for i in range(n_cuda):
             devices += [torch.device(f"cuda:{i}")]
+            if i == 1:
+                break
     return devices
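The early break caps the device list at CPU plus at most two CUDA devices, which keeps @pytest.mark.parametrize("device", get_available_devices()) matrices bounded on multi-GPU runners. For context, the helper as it likely reads after this change; only the lines visible in the hunk are certain, the CPU branch at the top is an assumption based on the function's name and its use in the tests:

import torch

def get_available_devices():
    devices = [torch.device("cpu")]
    n_cuda = torch.cuda.device_count()
    if n_cuda > 0:
        for i in range(n_cuda):
            devices += [torch.device(f"cuda:{i}")]
            if i == 1:  # stop after two CUDA devices to keep the test matrix small
                break
    return devices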


test/test_memmap.py: 12 changes (1 addition, 11 deletions)
@@ -73,17 +73,7 @@ def test_memmap_same_device_as_tensor(device):
     """
     t = torch.tensor([1], device=device)
     m = MemoryMappedTensor.from_tensor(t)
-    assert m.device == torch.device(device)
-    for other_device in get_available_devices():
-        if other_device != device:
-            with pytest.raises(
-                RuntimeError,
-                match="Expected all tensors to be on the same device, "
-                + "but found at least two devices",
-            ):
-                assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
-            m = m.to(other_device)
-            assert m.device == torch.device(other_device)
+    assert m.device == torch.device("cpu")


@pytest.mark.parametrize("device", get_available_devices())
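The rewritten assertion suggests, inferred from the test change alone rather than from the MemoryMappedTensor implementation, that MemoryMappedTensor.from_tensor now always yields a CPU-backed tensor whatever the source device, so the old cross-device error-matching loop no longer applies. A short sketch of that expectation:

import torch
from tensordict import MemoryMappedTensor  # import path assumed

src = torch.tensor([1], device="cuda:0" if torch.cuda.is_available() else "cpu")
m = MemoryMappedTensor.from_tensor(src)
assert m.device == torch.device("cpu")  # memory-mapped storage lives on CPU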
