Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent b78eccf commit a581e5a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,7 +1765,9 @@ def map(
def _add_batch_dim(self, *, in_dim, vmap_level):
if self.is_memmap():
if self.device.type != "cpu":
td = self.cpu()
raise RuntimeError(
"MemmapTensor with non-cpu device are not supported in vmap ops."
)
else:
td = self.as_tensor()
else:
Expand Down
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,13 @@ def test_add_batch_dim_cache(self, td_name, device, nested):
):
fun(td)
return
if td_name == "memmap_td" and device.type != "cpu":
with pytest.raises(
RuntimeError,
match="MemmapTensor with non-cpu device are not supported in vmap ops",
):
fun(td)
return
fun(td)

if td_name == "td_params":
Expand Down

0 comments on commit a581e5a

Please sign in to comment.