diff --git a/tensordict/_contextlib.py b/tensordict/_contextlib.py
index d934a47d9..ac5f74b18 100644
--- a/tensordict/_contextlib.py
+++ b/tensordict/_contextlib.py
@@ -20,9 +20,9 @@
 import numpy as np
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 
 # Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@@ -330,7 +330,7 @@ def _reverse_squeeze(self, args, kwargs, out):
 
 def _reverse_to_module(self, args, kwargs, out):
     try:
-        with out.unlock_() if not is_dynamo_compiling() else contextlib.nullcontext():
+        with out.unlock_() if not is_compiling() else contextlib.nullcontext():
             return self.to_module(*args, **kwargs, swap_dest=out)
     except AttributeError:
         # This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD
diff --git a/tensordict/base.py b/tensordict/base.py
index 39729eba4..31cd9ae1e 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3070,6 +3070,8 @@ def _legacy_permute(
 
     # Cache functionality
     def _erase_cache(self):
+        if is_dynamo_compiling():
+            return
         self._cache = None
 
     # Dim names functionality
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 00d984330..151080538 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -58,6 +58,12 @@
 from tensordict.utils import Buffer
 
 
+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
+
 def _apply_leaves(data, fn):
     if isinstance(data, TensorDict):
         with data.unlock_():
@@ -114,7 +120,7 @@ def _maybe_make_param_or_buffer(tensor):
     if (
         isinstance(tensor, (Tensor, ftdim.Tensor))
         and not isinstance(tensor, (nn.Parameter, Buffer))
-        and tensor.dtype in (torch.float, torch.double, torch.half)
+        # and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
         if not tensor.requires_grad and not is_batchedtensor(tensor):
             # convert all non-parameters to buffers
@@ -267,8 +273,10 @@ class TensorDictParams(TensorDictBase, nn.Module):
     a :class:`torch.nn.Parameter` conversion.
 
     Args:
-        parameters (TensorDictBase): a tensordict to represent as parameters.
+        parameters (TensorDictBase or dict): a tensordict to represent as parameters.
             Values will be converted to parameters unless ``no_convert=True``.
+            If a dict is provided, it will first be wrapped in a :class:`~tensordict.TensorDict`
+            instance. Keyword arguments can be used instead.
 
     Keyword Args:
         no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will
@@ -281,6 +289,8 @@
         also restricts the operations that can be done over the object (and
         can have significant performance impact when `unlock_()` is required).
         Defaults to ``False``.
+        **kwargs: Key-value pairs to populate the ``TensorDictParams``. Exclusive with
+            the :attr:`parameters` input.
 
     Examples:
         >>> from torch import nn
@@ -321,32 +331,93 @@ class TensorDictParams(TensorDictBase, nn.Module):
     """
 
     def __init__(
-        self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+        self,
+        parameters: TensorDictBase | dict | None = None,
+        *,
+        no_convert=False,
+        lock: bool = False,
+        **kwargs,
     ):
         super().__init__()
-        if isinstance(parameters, TensorDictParams):
-            parameters = parameters._param_td
-        self._param_td = parameters
+        if parameters is None:
+            parameters = kwargs
+        elif kwargs:
+            raise TypeError(
+                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+            )
+
+        params = None
+        buffers = None
+        if isinstance(parameters, dict):
+            parameters = TensorDict(parameters)
+        elif isinstance(parameters, TensorDictParams):
+            params = dict(parameters._parameters)
+            buffers = dict(parameters._buffers)
+            parameters = parameters._param_td.copy().lock_()
+            no_convert = "skip"
+
         self.no_convert = no_convert
         if no_convert != "skip":
             if not no_convert:
                 func = _maybe_make_param
             else:
                 func = _maybe_make_param_or_buffer
-            self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
+            self._param_td = _apply_leaves(parameters, lambda x: func(x))
+        else:
+            self._param_td = parameters
+
         self._lock_content = lock
         if lock:
             self._param_td.lock_()
-        self._reset_params()
+        self._reset_params(params=params, buffers=buffers)
         self._is_locked = False
         self._locked_tensordicts = []
         self._get_post_hook = []
 
     @classmethod
     def _new_unsafe(
-        cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+        cls,
+        parameters: TensorDictBase,
+        *,
+        no_convert=False,
+        lock: bool = False,
+        params: dict | None = None,
+        buffers: dict | None = None,
+        **kwargs,
     ):
-        return TensorDictParams(parameters, no_convert="skip", lock=lock)
+        if is_compiling():
+            return TensorDictParams(parameters, no_convert="skip", lock=lock)
+
+        self = TensorDictParams.__new__(cls)
+        nn.Module.__init__(self)
+
+        if parameters is None:
+            parameters = kwargs
+        elif kwargs:
+            raise TypeError(
+                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+            )
+
+        if isinstance(parameters, dict):
+            parameters = TensorDict._new_unsafe(parameters)
+        elif isinstance(parameters, TensorDictParams):
+            params = dict(parameters._parameters)
+            buffers = dict(parameters._buffers)
+            parameters = parameters._param_td
+            no_convert = "skip"
+
+        self._param_td = parameters
+        self.no_convert = no_convert
+        if no_convert != "skip":
+            raise RuntimeError("_new_unsafe requires no_convert to be set to 'skip'")
+        self._lock_content = lock
+        if lock:
+            self._param_td.lock_()
+        self._reset_params(params=params, buffers=buffers)
+        self._is_locked = False
+        self._locked_tensordicts = []
+        self._get_post_hook = []
+        return self
 
     def __iter__(self):
         yield from self._param_td.__iter__()
@@ -365,24 +436,35 @@ def _apply_get_post_hook(self, val):
                 val = new_val
         return val
 
-    def _reset_params(self):
+    def _reset_params(self, params: dict | None = None, buffers: dict | None = None):
         parameters = self._param_td
-        param_keys = []
-        params = []
-        buffer_keys = []
-        buffers = []
-        for key, value in parameters.items(True, True):
-            # flatten key
-            if isinstance(key, tuple):
-                key = ".".join(key)
-            if isinstance(value, nn.Parameter):
-                param_keys.append(key)
-                params.append(value)
-            else:
-                buffer_keys.append(key)
-                buffers.append(value)
-        self.__dict__["_parameters"] = dict(zip(param_keys, params))
-        self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers))
+
+        self._parameters.clear()
+        self._buffers.clear()
+
+        if (params is not None) ^ (buffers is not None):
+            raise RuntimeError("both params and buffers must either be None or not.")
+        elif params is None:
+            param_keys = []
+            params = []
+            buffer_keys = []
+            buffers = []
+            for key, value in parameters.items(True, True):
+                # flatten key
+                if isinstance(key, tuple):
+                    key = ".".join(key)
+                if isinstance(value, nn.Parameter):
+                    param_keys.append(key)
+                    params.append(value)
+                else:
+                    buffer_keys.append(key)
+                    buffers.append(value)
+
+            self._parameters.update(dict(zip(param_keys, params)))
+            self._buffers.update(dict(zip(buffer_keys, buffers)))
+        else:
+            self._parameters.update(params)
+            self._buffers.update(buffers)
 
     @classmethod
     def __torch_function__(
@@ -620,7 +702,12 @@ def _clone(self, recurse: bool = True) -> TensorDictBase:
 
         """
        if not recurse:
-            return TensorDictParams(self._param_td._clone(False), no_convert="skip")
+            return TensorDictParams._new_unsafe(
+                self._param_td._clone(False),
+                no_convert="skip",
+                params=dict(self._parameters),
+                buffers=dict(self._buffers),
+            )
 
         memo = {}
@@ -899,6 +986,7 @@ def _propagate_unlock(self):
         if not self._lock_content:
             return self._param_td._propagate_unlock()
+        return []
 
     unlock_ = TensorDict.unlock_
     lock_ = TensorDict.lock_
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 1e00d57a3..416f0d538 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -67,11 +67,11 @@
 except ImportError:
     _has_funcdim = False
 try:
-    from torch.compiler import assume_constant_result, is_dynamo_compiling
+    from torch.compiler import assume_constant_result, is_compiling
 except ImportError:  # torch 2.0
     from torch._dynamo import (
         assume_constant_result,
-        is_compiling as is_dynamo_compiling,
+        is_compiling,
     )
 
 if TYPE_CHECKING:
@@ -863,7 +863,7 @@ def _is_tensorclass(cls: type) -> bool:
     out = _TENSORCLASS_MEMO.get(cls, None)
     if out is None:
         out = getattr(cls, "_is_tensorclass", False)
-        if not is_dynamo_compiling():
+        if not is_compiling():
            _TENSORCLASS_MEMO[cls] = out
    return out
@@ -1119,7 +1119,7 @@ def cache(fun):
 
    @wraps(fun)
    def newfun(_self: "TensorDictBase", *args, **kwargs):
-        if not _self.is_locked or is_dynamo_compiling():
+        if not _self.is_locked or is_compiling():
             return fun(_self, *args, **kwargs)
         cache = _self._cache
         if cache is None:
@@ -1359,7 +1359,7 @@ def _parse_to(*args, **kwargs):
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
     inplace = kwargs.pop("inplace", False)
-    if not is_dynamo_compiling():
+    if not is_compiling():
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
             *args, **kwargs
         )
@@ -1733,7 +1733,7 @@ def _check_keys(
         is_leaf=_is_leaf_nontensor,
     )
     # TODO: compile doesn't like set() over an arbitrary object
-    if is_dynamo_compiling():
+    if is_compiling():
         keys = {k for k in keys}  # noqa: C416
     else:
         keys: set[str] = set(keys)
@@ -1746,7 +1746,7 @@
         if not strict:
             keys = keys.intersection(k)
         else:
-            if is_dynamo_compiling():
+            if is_compiling():
                 k = {v for v in k}  # noqa: C416
             else:
                 k = set(k)
@@ -2015,7 +2015,7 @@ def _getitem_batch_size(batch_size, index):
             continue
         elif isinstance(idx, slice):
             batch = batch_size[count]
-            if is_dynamo_compiling():
+            if is_compiling():
                 out.append(len(range(*_slice_indices(idx, batch))))
             else:
                 out.append(len(range(*idx.indices(batch))))
@@ -2447,7 +2447,7 @@ def is_non_tensor(data):
 
 def _is_non_tensor(cls: type):
     out = None
-    is_dynamo = is_dynamo_compiling()
+    is_dynamo = is_compiling()
     if not is_dynamo:
         out = _NON_TENSOR_MEMO.get(cls)
     if out is None:
@@ -2503,7 +2503,7 @@ def new_func(self):
 
 
 def _unravel_key_to_tuple(key):
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return _unravel_key_to_tuple_cpp(key)
     if isinstance(key, str):
         return (key,)
@@ -2524,7 +2524,7 @@ def unravel_key(key):
         ("a", "b")
 
     """
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_cpp(key)
     if isinstance(key, str):
         return key
@@ -2537,14 +2537,14 @@
 
 def unravel_keys(*keys):
     """Unravels a sequence of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_keys_cpp(*keys)
     return tuple(unravel_key(key) for key in keys)
 
 
 def unravel_key_list(keys):
     """Unravels a list of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_list_cpp(keys)
     return [unravel_key(key) for key in keys]
@@ -2823,11 +2823,11 @@ def __init__(self, default=None):
         self._lock = threading.Lock()
 
     def get_mode(self) -> Any | None:
-        cm = self._lock if not is_dynamo_compiling() else nullcontext()
+        cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             return self._mode
 
     def set_mode(self, type: Any | None) -> None:
-        cm = self._lock if not is_dynamo_compiling() else nullcontext()
+        cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             self._mode = type
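
Usage note (illustrative sketch, not part of the patch): with this change applied, the ``TensorDictParams`` constructor is expected to accept a plain dict or keyword arguments in addition to a ``TensorDictBase``, per the docstring added above. The names and shapes below are made up for the example.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    # A TensorDict input keeps working as before.
    p0 = TensorDictParams(TensorDict({"weight": torch.randn(3, 4)}, []))

    # New: a plain dict is first wrapped in a TensorDict.
    p1 = TensorDictParams({"weight": torch.randn(3, 4)})

    # New: keyword arguments can populate the object instead.
    p2 = TensorDictParams(weight=torch.randn(3, 4))

    # The two entry points are exclusive: mixing them raises a TypeError.
    try:
        TensorDictParams(TensorDict({}, []), weight=torch.randn(3, 4))
    except TypeError:
        pass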
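A second sketch, grounded in the ``unravel_key`` docstring shown in the diff: in eager mode these helpers dispatch to the C++ implementations (``unravel_key_cpp``, ``unravel_key_list_cpp``), while under compilation the pure-Python fallbacks guarded by ``is_compiling()`` are expected to return the same results.

    from tensordict.utils import unravel_key, unravel_key_list

    assert unravel_key(("a", ("b",))) == ("a", "b")  # nested keys are flattened
    assert unravel_key("a") == "a"  # strings pass through unchanged
    assert unravel_key_list([("a", ("b",)), "c"]) == [("a", "b"), "c"]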