From 4949efa80102a9af61130f316657a7b9e14f433f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 11:45:41 +0000 Subject: [PATCH] [BugFix] Do not unlock td if it's not locked in TDParams (for compile compat) ghstack-source-id: 9b6923f9c219e12af5560c97c1c6c58ed7870a8a Pull Request resolved: https://github.com/pytorch/tensordict/pull/1125 --- tensordict/nn/params.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 125e0e9c5..03c9632a1 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -9,6 +9,7 @@ import re import weakref from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import nullcontext from copy import copy from functools import wraps from typing import Any, Callable, Dict, Iterator, List, OrderedDict, Sequence, Type @@ -171,7 +172,11 @@ def new_func(_self, *args, **kwargs): if _self.is_locked: # if the root (TensorDictParams) is locked, we still want to raise an exception raise RuntimeError(_LOCK_ERROR) - with _self._param_td.unlock_(): + with ( + _self._param_td.unlock_() + if _self._param_td.is_locked + else nullcontext() + ): meth = getattr(_self._param_td, name) out = meth(*args, **kwargs) _self._reset_params()