Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent 2fb64be commit 7665531
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,12 @@ def _cat(


def _where(condition, input, other):
return torch.where(condition=condition, input=input.as_tensor(), other=other)
device = input.device
if device != torch.device("cpu"):
input = input.to("cpu").as_tensor().to(device)
else:
input = input.as_tensor()
return torch.where(condition=condition, input=input, other=other)


implements_for_memmap(torch.where)(_where)
Expand Down

0 comments on commit 7665531

Please sign in to comment.