diff --git a/tensordict/_pytree.py b/tensordict/_pytree.py index ad043a807..0ae93682c 100644 --- a/tensordict/_pytree.py +++ b/tensordict/_pytree.py @@ -11,7 +11,7 @@ TensorDict, ) -from torch.utils._pytree import _register_pytree_node, Context +from torch.utils._pytree import Context, register_pytree_node PYTREE_REGISTERED_TDS = ( LazyStackedTensorDict, @@ -99,7 +99,7 @@ def _tensordictdict_unflatten(values: List[Any], context: Context) -> Dict[Any, for cls in PYTREE_REGISTERED_TDS: - _register_pytree_node( + register_pytree_node( cls, _tensordict_flatten, _tensordictdict_unflatten,