Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent 9ae41b8 commit 85f596d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,8 @@ def to(
out.device = device
return out

self.device = device
out = self.clone()
out.device = device
return self

def unbind(self, dim: int) -> tuple[torch.Tensor, ...]:
Expand Down Expand Up @@ -845,7 +846,7 @@ def _cat(
def _where(condition, input, other):
device = input.device
if device != torch.device("cpu"):
input = input.to("cpu").as_tensor().to(device, non_blocking=True)
input = input.to("cpu").to(device, non_blocking=True)
else:
input = input.as_tensor()
if condition.device != device or (isinstance(other, (MemmapTensor, torch.Tensor)) and other.device != device):
Expand Down

0 comments on commit 85f596d

Please sign in to comment.