diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index df7bad0e6..00d984330 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -235,7 +235,7 @@ def new_func(self, *args, **kwargs): if out is self._param_td: return self if not isinstance(out, TensorDictParams): - out = TensorDictParams(out, no_convert=True) + out = TensorDictParams(out, no_convert="skip") out.no_convert = self.no_convert return out @@ -328,11 +328,12 @@ def __init__( parameters = parameters._param_td self._param_td = parameters self.no_convert = no_convert - if not no_convert: - func = _maybe_make_param - else: - func = _maybe_make_param_or_buffer - self._param_td = _apply_leaves(self._param_td, lambda x: func(x)) + if no_convert != "skip": + if not no_convert: + func = _maybe_make_param + else: + func = _maybe_make_param_or_buffer + self._param_td = _apply_leaves(self._param_td, lambda x: func(x)) self._lock_content = lock if lock: self._param_td.lock_() @@ -341,6 +342,12 @@ def __init__( self._locked_tensordicts = [] self._get_post_hook = [] + @classmethod + def _new_unsafe( + cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False + ): + return TensorDictParams(parameters, no_convert="skip", lock=lock) + def __iter__(self): yield from self._param_td.__iter__() @@ -613,7 +620,7 @@ def _clone(self, recurse: bool = True) -> TensorDictBase: """ if not recurse: - return TensorDictParams(self._param_td._clone(False), no_convert=True) + return TensorDictParams(self._param_td._clone(False), no_convert="skip") memo = {} @@ -631,7 +638,7 @@ def _clone(tensor, memo=memo): memo[tensor] = result return result - return TensorDictParams(self._param_td.apply(_clone), no_convert=True) + return TensorDictParams(self._param_td.apply(_clone), no_convert="skip") @_fallback def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]: ...