Skip to content

Commit

Permalink
[Refactor] Upgrade pytree import (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 24, 2023
1 parent 42bf143 commit df1ee89
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensordict/_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
TensorDict,
)

from torch.utils._pytree import _register_pytree_node, Context
from torch.utils._pytree import Context, register_pytree_node

PYTREE_REGISTERED_TDS = (
LazyStackedTensorDict,
Expand Down Expand Up @@ -99,7 +99,7 @@ def _tensordictdict_unflatten(values: List[Any], context: Context) -> Dict[Any,


for cls in PYTREE_REGISTERED_TDS:
_register_pytree_node(
register_pytree_node(
cls,
_tensordict_flatten,
_tensordictdict_unflatten,
Expand Down

0 comments on commit df1ee89

Please sign in to comment.