From a3f9500e1e4204c77ef143b14845c5f427a3681d Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 18:12:46 +0000 Subject: [PATCH] amend --- tensordict/_torch_func.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 67d06a972..b267fbc86 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -13,7 +13,7 @@ from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.base import NO_DEFAULT, TensorDictBase +from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase from tensordict.persistent import PersistentTensorDict from tensordict.utils import ( _check_keys, @@ -95,12 +95,18 @@ def _gather_tensor(tensor, dest=None): names = input.names if input._has_names() else None return TensorDict( - {key: _gather_tensor(value) for key, value in input.items()}, + { + key: _gather_tensor(value) + for key, value in input.items(is_leaf=_is_leaf_nontensor) + }, batch_size=index.shape, names=names, ) TensorDict( - {key: _gather_tensor(value, out[key]) for key, value in input.items()}, + { + key: _gather_tensor(value, out.get(key)) + for key, value in input.items(is_leaf=_is_leaf_nontensor) + }, batch_size=index.shape, ) return out