From b78eccf32c92ece08fd556e1ce8efed4e2875bc3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 15:22:05 +0100 Subject: [PATCH] amend --- tensordict/tensordict.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 4ed50c476..b6e397ccf 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1764,7 +1764,10 @@ def map( @cache # noqa: B019 def _add_batch_dim(self, *, in_dim, vmap_level): if self.is_memmap(): - td = self.cpu() + if self.device.type != "cpu": + td = self.cpu() + else: + td = self.as_tensor() else: td = self out = TensorDict(