diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 6cedbf41d..781f6351b 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -529,7 +529,7 @@ def to_h5( ... def __hash__(self): - return hash((id(self), id(self._param_td))) + return hash((id(self), id(self.__dict__.get("_param_td", None)))) @_fallback def __eq__(self, other: object) -> TensorDictBase: @@ -541,7 +541,7 @@ def __ne__(self, other: object) -> TensorDictBase: def __getattr__(self, item: str) -> Any: try: - return getattr(self._param_td, item) + return getattr(self.__dict__["_param_td"], item) except AttributeError: return super().__getattr__(item)