diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 125e0e9c5..03c9632a1 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -9,6 +9,7 @@ import re import weakref from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import nullcontext from copy import copy from functools import wraps from typing import Any, Callable, Dict, Iterator, List, OrderedDict, Sequence, Type @@ -171,7 +172,11 @@ def new_func(_self, *args, **kwargs): if _self.is_locked: # if the root (TensorDictParams) is locked, we still want to raise an exception raise RuntimeError(_LOCK_ERROR) - with _self._param_td.unlock_(): + with ( + _self._param_td.unlock_() + if _self._param_td.is_locked + else nullcontext() + ): meth = getattr(_self._param_td, name) out = meth(*args, **kwargs) _self._reset_params()