Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 22, 2024
1 parent 92beac1 commit a3f9500
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.base import NO_DEFAULT, TensorDictBase
from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase
from tensordict.persistent import PersistentTensorDict
from tensordict.utils import (
_check_keys,
Expand Down Expand Up @@ -95,12 +95,18 @@ def _gather_tensor(tensor, dest=None):
names = input.names if input._has_names() else None

return TensorDict(
{key: _gather_tensor(value) for key, value in input.items()},
{
key: _gather_tensor(value)
for key, value in input.items(is_leaf=_is_leaf_nontensor)
},
batch_size=index.shape,
names=names,
)
TensorDict(
{key: _gather_tensor(value, out[key]) for key, value in input.items()},
{
key: _gather_tensor(value, out.get(key))
for key, value in input.items(is_leaf=_is_leaf_nontensor)
},
batch_size=index.shape,
)
return out
Expand Down

0 comments on commit a3f9500

Please sign in to comment.