diff --git a/tensordict/_contextlib.py b/tensordict/_contextlib.py
index d934a47d9..ac5f74b18 100644
--- a/tensordict/_contextlib.py
+++ b/tensordict/_contextlib.py
@@ -20,9 +20,9 @@
 import numpy as np
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 
 # Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@@ -330,7 +330,7 @@ def _reverse_squeeze(self, args, kwargs, out):
 
 def _reverse_to_module(self, args, kwargs, out):
     try:
-        with out.unlock_() if not is_dynamo_compiling() else contextlib.nullcontext():
+        with out.unlock_() if not is_compiling() else contextlib.nullcontext():
             return self.to_module(*args, **kwargs, swap_dest=out)
     except AttributeError:
         # This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 07e355746..f305fcddd 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -58,6 +58,12 @@ from tensordict.utils import Buffer
 
 
+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
+
 def _apply_leaves(data, fn):
     if isinstance(data, TensorDict):
         with data.unlock_():
@@ -101,12 +107,18 @@ def _get_args_dict(func, args, kwargs):
 
 
 def _maybe_make_param(tensor):
-    if (
-        isinstance(tensor, (Tensor, ftdim.Tensor))
-        and not isinstance(tensor, nn.Parameter)
-        and tensor.dtype in (torch.float, torch.double, torch.half)
+    if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance(
+        tensor, (nn.Parameter, Buffer, BufferLegacy)
     ):
-        tensor = nn.Parameter(tensor)
+        if tensor.dtype in (torch.float, torch.double, torch.half):
+            tensor = nn.Parameter(tensor)
+        elif not is_batchedtensor(tensor):
+            # convert all non-parameter tensors to buffers
+            # dataptr = tensor.data.data_ptr()
+            tensor = Buffer(tensor)
+        else:
+            # keep the grad_fn of tensors, e.g. param.expand(10) should point to the original param
+            tensor = BufferLegacy(tensor)
     return tensor
 
 
@@ -250,37 +262,43 @@ def new_func(self, *args, **kwargs):
 
 
 class TensorDictParams(TensorDictBase, nn.Module):
-    r"""Holds a TensorDictBase instance full of parameters.
+    r"""A wrapper around ``TensorDictBase`` that exposes its leaves as parameters.
 
-    This class exposes the contained parameters to a parent nn.Module
-    such that iterating over the parameters of the module also iterates over
-    the leaves of the tensordict.
+    This class holds a ``TensorDictBase`` instance containing parameters and makes them accessible to a
+    parent :class:`~torch.nn.Module`. This allows for seamless integration of tensordict parameters into
+    PyTorch modules, enabling operations such as parameter iteration and optimization.
 
-    Indexing works exactly as the indexing of the wrapped tensordict.
-    The parameter names will be registered within this module using :meth:`~.TensorDict.flatten_keys("_")`.
-    Therefore, the result of :meth:`~.named_parameters()` and the content of the
-    tensordict will differ slightly in term of key names.
+    Key features:
 
-    Any operation that sets a tensor in the tensordict will be augmented by
-    a :class:`torch.nn.Parameter` conversion.
+    - Parameter exposure: parameters within the tensordict are exposed to the parent module, allowing them
+      to be included in operations like ``named_parameters()``.
+    - Indexing: indexing works exactly like indexing the wrapped tensordict.
+      However, parameter names (in :meth:`~.named_parameters`) are registered using
+      ``TensorDict.flatten_keys("_")``, so key names may differ from those of the tensordict content.
+    - Automatic conversion: any tensor set in the tensordict is automatically converted to a
+      :class:`torch.nn.Parameter` or :class:`~torch.nn.Buffer`, unless specified otherwise through the
+      ``no_convert`` keyword argument.
 
-    Args:
-        parameters (TensorDictBase): a tensordict to represent as parameters.
-            Values will be converted to parameters unless ``no_convert=True``.
+    Args:
+        parameters (TensorDictBase or dict): the tensordict to represent as parameters. Values are converted
+            to parameters unless ``no_convert=True``. If a ``dict`` is provided, it is first wrapped in a
+            ``TensorDict`` instance. Keyword arguments can be used instead.
+        **kwargs: key-value pairs to populate the ``TensorDictParams``. Exclusive with the ``parameters``
+            argument.
 
     Keyword Args:
-        no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will
-            occur at construction and after (unless the ``no_convert`` attribute is changed).
-            If ``no_convert`` is ``True`` and if non-parameters are present, they
-            will be registered as buffers.
-            Defaults to ``False``.
-        lock (bool): if ``True``, the tensordict hosted by TensorDictParams will
-            be locked. This can be useful to avoid unwanted modifications, but
-            also restricts the operations that can be done over the object (and
-            can have significant performance impact when `unlock_()` is required).
-            Defaults to ``False``.
-
-    Examples:
+        no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` occurs and every non-parameter,
+            non-buffer tensor is converted to a :class:`~torch.nn.Buffer` instance. If ``False``, tensors
+            with a floating-point dtype (``torch.float``, ``torch.double`` or ``torch.half``) are converted
+            to :class:`~torch.nn.Parameter`, whereas all other dtypes are converted to
+            :class:`~torch.nn.Buffer` instances.
+            Defaults to ``False``.
+        lock (bool): if ``True``, the tensordict hosted by ``TensorDictParams`` is locked. This prevents
+            unwanted modifications, but can have a significant performance impact whenever ``unlock_()``
+            is required.
+            Defaults to ``False``.
+
+    .. warning:: Because the inner tensordict isn't copied or locked by default, registering the tensordict
+        in a ``TensorDictParams`` and modifying its content afterwards will **not** update the values within
+        the ``TensorDictParams`` :meth:`~.parameters` and :meth:`~.buffers` sequences.
+
+    Examples:
         >>> from torch import nn
         >>> from tensordict import TensorDict
         >>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
@@ -312,39 +330,100 @@ class TensorDictParams(TensorDictBase, nn.Module):
         ...         super().__init__()
         ...         self.params = params
         >>> m = CustomModule(p)
-        >>> # the wrapper supports assignment and values are turned in Parameter
+        >>> # the wrapper supports assignment, and values are converted to Parameters
         >>> m.params['other'] = torch.randn(3)
         >>> assert isinstance(m.params['other'], nn.Parameter)
 
     """
 
     def __init__(
-        self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+        self,
+        parameters: TensorDictBase | dict | None = None,
+        *,
+        no_convert=False,
+        lock: bool = False,
+        **kwargs,
     ):
         super().__init__()
-        if isinstance(parameters, TensorDictParams):
-            parameters = parameters._param_td
-        self._param_td = parameters
+        if parameters is None:
+            parameters = kwargs
+        elif kwargs:
+            raise TypeError(
+                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+            )
+
+        params = None
+        buffers = None
+        if isinstance(parameters, dict):
+            parameters = TensorDict(parameters)
+        elif isinstance(parameters, TensorDictParams):
+            params = dict(parameters._parameters)
+            buffers = dict(parameters._buffers)
+            parameters = parameters._param_td.copy().lock_()
+            no_convert = "skip"
+
         self.no_convert = no_convert
         if no_convert != "skip":
             if not no_convert:
                 func = _maybe_make_param
             else:
                 func = _maybe_make_param_or_buffer
-            self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
+            self._param_td = _apply_leaves(parameters, lambda x: func(x))
+        else:
+            self._param_td = parameters
+
         self._lock_content = lock
         if lock:
             self._param_td.lock_()
-        self._reset_params()
+        self._reset_params(params=params, buffers=buffers)
         self._is_locked = False
         self._locked_tensordicts = []
         self._get_post_hook = []
 
     @classmethod
     def _new_unsafe(
-        cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+        cls,
+        parameters: TensorDictBase,
+        *,
+        no_convert=False,
+        lock: bool = False,
+        params: dict | None = None,
+        buffers: dict | None = None,
+        **kwargs,
     ):
-        return TensorDictParams(parameters, no_convert="skip", lock=lock)
+        if is_compiling():
+            return TensorDictParams(parameters, no_convert="skip", lock=lock)
+
+        self = TensorDictParams.__new__(cls)
+        nn.Module.__init__(self)
+
+        if parameters is None:
+            parameters = kwargs
+        elif kwargs:
+            raise TypeError(
+                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+            )
+
+        if isinstance(parameters, dict):
+            parameters = TensorDict._new_unsafe(parameters)
+        elif isinstance(parameters, TensorDictParams):
+            params = dict(parameters._parameters)
+            buffers = dict(parameters._buffers)
+            parameters = parameters._param_td
+            no_convert = "skip"
+
+        self._param_td = parameters
+        self.no_convert = no_convert
+        if no_convert != "skip":
+            raise RuntimeError("_new_unsafe requires no_convert to be set to 'skip'")
+        self._lock_content = lock
+        if lock:
+            self._param_td.lock_()
+        self._reset_params(params=params, buffers=buffers)
+        self._is_locked = False
+        self._locked_tensordicts = []
+        self._get_post_hook = []
+        return self
 
     def __iter__(self):
         yield from self._param_td.__iter__()
@@ -363,24 +442,35 @@ def _apply_get_post_hook(self, val):
             val = new_val
         return val
 
-    def _reset_params(self):
+    def _reset_params(self, params: dict | None = None, buffers: dict | None = None):
         parameters = self._param_td
-        param_keys = []
-        params = []
-        buffer_keys = []
-        buffers = []
-        for key, value in parameters.items(True, True):
-            # flatten key
-            if isinstance(key, tuple):
-                key = ".".join(key)
-            if isinstance(value, nn.Parameter):
-                param_keys.append(key)
-                params.append(value)
-            else:
-                buffer_keys.append(key)
-                buffers.append(value)
-        self.__dict__["_parameters"] = dict(zip(param_keys, params))
-        self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers))
+
+        self._parameters.clear()
+        self._buffers.clear()
+
+        if (params is not None) ^ (buffers is not None):
+            raise RuntimeError("params and buffers must either both be None or both be provided.")
+        elif params is None:
+            param_keys = []
+            params = []
+            buffer_keys = []
+            buffers = []
+            for key, value in parameters.items(True, True):
+                # flatten key
+                if isinstance(key, tuple):
+                    key = ".".join(key)
+                if isinstance(value, nn.Parameter):
+                    param_keys.append(key)
+                    params.append(value)
+                else:
+                    buffer_keys.append(key)
+                    buffers.append(value)
+
+            self._parameters.update(dict(zip(param_keys, params)))
+            self._buffers.update(dict(zip(buffer_keys, buffers)))
else: + self._parameters.update(params) + self._buffers.update(buffers) @classmethod def __torch_function__( @@ -618,7 +708,12 @@ def _clone(self, recurse: bool = True) -> TensorDictBase: """ if not recurse: - return TensorDictParams(self._param_td._clone(False), no_convert="skip") + return TensorDictParams._new_unsafe( + self._param_td._clone(False), + no_convert="skip", + params=dict(self._parameters), + buffers=dict(self._buffers), + ) memo = {} @@ -897,6 +992,7 @@ def _propagate_unlock(self): if not self._lock_content: return self._param_td._propagate_unlock() + return [] unlock_ = TensorDict.unlock_ lock_ = TensorDict.lock_ diff --git a/test/test_nn.py b/test/test_nn.py index 5c32ac6e3..ba9faf93d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1896,6 +1896,51 @@ def test_td_params(self): assert (m.params == params).all() assert (params == m.params).all() + def test_constructors(self): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.register_parameter( + "param", nn.Parameter(torch.randn(3, requires_grad=True)) + ) + self.register_buffer("buf", torch.randn(3)) + self.register_buffer("buf_int", torch.randint(3, ())) + + td = TensorDict.from_module(MyModule()) + assert not isinstance(td, TensorDictParams) + td = TensorDictParams(td) + assert isinstance(td, TensorDictParams) + assert isinstance(td["param"], nn.Parameter) + assert isinstance(td["buf"], nn.Parameter) + assert isinstance(td["buf_int"], Buffer) + td = TensorDict.from_module(MyModule()) + assert not isinstance(td, TensorDictParams) + td = TensorDictParams(td, no_convert=True) + assert isinstance(td, TensorDictParams) + assert isinstance(td["param"], nn.Parameter) + assert isinstance(td["buf"], Buffer) + assert isinstance(td["buf_int"], Buffer) + + td = TensorDict.from_module(MyModule(), as_module=True) + assert isinstance(td, TensorDictParams) + assert isinstance(td["param"], nn.Parameter) + assert isinstance(td["buf"], Buffer) + assert isinstance(td["buf_int"], Buffer) + + tdparams = TensorDictParams(a=0, b=1.0) + assert isinstance(tdparams["a"], Buffer) + assert isinstance(tdparams["b"], nn.Parameter) + + tdparams = TensorDictParams({"a": 0, "b": 1.0}) + assert isinstance(tdparams["a"], Buffer) + assert isinstance(tdparams["b"], nn.Parameter) + tdparams_copy = tdparams.copy() + + def assert_is_identical(a, b): + assert a is b + + tdparams.apply(assert_is_identical, tdparams_copy, filter_empty=True) + def test_td_params_cast(self): params = self._get_params() p = TensorDictParams(params)