diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 4ed50c476..b6e397ccf 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1764,7 +1764,10 @@ def map( @cache # noqa: B019 def _add_batch_dim(self, *, in_dim, vmap_level): if self.is_memmap(): - td = self.cpu() + if self.device.type != "cpu": + td = self.cpu() + else: + td = self.as_tensor() else: td = self out = TensorDict(