[Refactor] Avoid TDParams parameters and buffers construction when obvious

ghstack-source-id: bd701ecfaf68605801a215d3cd9d49268b888bb3
Pull Request resolved: #1100
vmoens committed Nov 25, 2024
1 parent 270d7ba commit df870ef
Showing 3 changed files with 201 additions and 60 deletions.
6 changes: 3 additions & 3 deletions tensordict/_contextlib.py
@@ -20,9 +20,9 @@
import numpy as np

try:
from torch.compiler import is_dynamo_compiling
from torch.compiler import is_compiling
except ImportError: # torch 2.0
from torch._dynamo import is_compiling as is_dynamo_compiling
from torch._dynamo import is_compiling


# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@@ -330,7 +330,7 @@ def _reverse_squeeze(self, args, kwargs, out):

def _reverse_to_module(self, args, kwargs, out):
try:
with out.unlock_() if not is_dynamo_compiling() else contextlib.nullcontext():
with out.unlock_() if not is_compiling() else contextlib.nullcontext():
return self.to_module(*args, **kwargs, swap_dest=out)
except AttributeError:
# This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD
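The rename standardizes on the public `torch.compiler.is_compiling` entry point, with a fallback for torch 2.0, and call sites branch on it to skip context managers that Dynamo cannot trace. A minimal sketch of the resulting pattern (the `maybe_unlocked` helper is hypothetical, for illustration only):

```python
import contextlib

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0 only exposes the private name
    from torch._dynamo import is_compiling


def maybe_unlocked(td):
    # Under torch.compile, skip the lock/unlock bookkeeping, which Dynamo
    # cannot trace; otherwise unlock the tensordict within a context.
    return td.unlock_() if not is_compiling() else contextlib.nullcontext()
```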
210 changes: 153 additions & 57 deletions tensordict/nn/params.py
@@ -58,6 +58,12 @@
from tensordict.utils import Buffer


try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling


def _apply_leaves(data, fn):
if isinstance(data, TensorDict):
with data.unlock_():
@@ -101,12 +107,18 @@ def _get_args_dict(func, args, kwargs):


def _maybe_make_param(tensor):
if (
isinstance(tensor, (Tensor, ftdim.Tensor))
and not isinstance(tensor, nn.Parameter)
and tensor.dtype in (torch.float, torch.double, torch.half)
if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance(
tensor, (nn.Parameter, Buffer, BufferLegacy)
):
tensor = nn.Parameter(tensor)
if tensor.dtype in (torch.float, torch.double, torch.half):
tensor = nn.Parameter(tensor)
elif not is_batchedtensor(tensor):
# convert all non-parameters to buffers
# dataptr = tensor.data.data_ptr()
tensor = Buffer(tensor)
else:
# We want to keep the grad_fn of tensors, e.g. param.expand(10) should point to the original param
tensor = BufferLegacy(tensor)
return tensor
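
The conversion in `_maybe_make_param` now distinguishes parameters from buffers by dtype, instead of leaving non-float leaves untouched: floating-point tensors are promoted to parameters, while everything else becomes a buffer. A short sketch of the observable behavior, mirroring the assertions in `test_constructors` further down (key names are illustrative):

```python
import torch
from torch import nn
from tensordict.nn import TensorDictParams

# Floating-point leaves are promoted to nn.Parameter; integer leaves cannot
# carry gradients and are registered as buffers instead.
td = TensorDictParams({"w": torch.randn(3), "step": torch.zeros((), dtype=torch.int64)})
assert isinstance(td["w"], nn.Parameter)
assert not isinstance(td["step"], nn.Parameter)  # a buffer, not a parameter
```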


@@ -250,37 +262,43 @@ def new_func(self, *args, **kwargs):


class TensorDictParams(TensorDictBase, nn.Module):
r"""Holds a TensorDictBase instance full of parameters.
r"""A Wrapper for TensorDictBase with Parameter Exposure.
This class exposes the contained parameters to a parent nn.Module
such that iterating over the parameters of the module also iterates over
the leaves of the tensordict.
This class is designed to hold a `TensorDictBase` instance that contains parameters, making them accessible to a
parent :class:`~torch.nn.Module`. This allows for seamless integration of tensordict parameters into PyTorch modules,
enabling operations like parameter iteration and optimization.
Indexing works exactly as the indexing of the wrapped tensordict.
The parameter names will be registered within this module using :meth:`~.TensorDict.flatten_keys("_")`.
Therefore, the result of :meth:`~.named_parameters()` and the content of the
tensordict will differ slightly in terms of key names.
Key Features:
Any operation that sets a tensor in the tensordict will be augmented by
a :class:`torch.nn.Parameter` conversion.
- Parameter Exposure: Parameters within the tensordict are exposed to the parent module, allowing them to be included
in operations like `named_parameters()`.
- Indexing: Indexing works similarly to the wrapped tensordict. However, parameter names (in :meth:`~.named_parameters`) are registered using
`TensorDict.flatten_keys("_")`, which may result in different key names compared to the tensordict content.
- Automatic Conversion: Any tensor set in the tensordict is automatically converted to a :class:`torch.nn.Parameter`,
unless specified otherwise through the :attr:`no_convert` keyword argument.
Args:
parameters (TensorDictBase): a tensordict to represent as parameters.
Values will be converted to parameters unless ``no_convert=True``.
Args:
parameters (TensorDictBase or dict): The tensordict to represent as parameters. Values are converted to
parameters unless `no_convert=True`. If a `dict` is provided, it is wrapped in a `TensorDict` instance.
Keyword arguments can also be used.
Keyword Args:
no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will
occur at construction and after (unless the ``no_convert`` attribute is changed).
If ``no_convert`` is ``True`` and if non-parameters are present, they
will be registered as buffers.
Defaults to ``False``.
lock (bool): if ``True``, the tensordict hosted by TensorDictParams will
be locked. This can be useful to avoid unwanted modifications, but
also restricts the operations that can be done over the object (and
can have significant performance impact when `unlock_()` is required).
Defaults to ``False``.
Examples:
no_convert (bool): If `True`, no conversion to `nn.Parameter` occurs and all non-parameter, non-buffer tensors
will be converted to a :class:`~torch.nn.Buffer` instance.
If ``False``, all tensors with non-integer dtypes will be converted to :class:`~torch.nn.Parameter`
whereas integer dtypes will be converted to :class:`~torch.nn.Buffer` instances.
Defaults to `False`.
lock (bool): If `True`, the tensordict hosted by `TensorDictParams` is locked, preventing modifications and
potentially impacting performance when `unlock_()` is required.
Defaults to `False`.
.. warning:: Because the inner tensordict isn't copied or locked by default, registering the tensordict
in a ``TensorDictParams`` and modifying its content afterwards will __not__ update the values within
the ``TensorDictParams`` :meth:`.parameters` and :meth:`~.buffers` sequences.
**kwargs: Key-value pairs to populate the `TensorDictParams`. Exclusive with the `parameters` input.
Examples:
>>> from torch import nn
>>> from tensordict import TensorDict
>>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
@@ -312,39 +330,100 @@ class TensorDictParams(TensorDictBase, nn.Module):
... super().__init__()
... self.params = params
>>> m = CustomModule(p)
>>> # the wrapper supports assignment and values are turned in Parameter
>>> # The wrapper supports assignment, and values are converted to Parameters
>>> m.params['other'] = torch.randn(3)
>>> assert isinstance(m.params['other'], nn.Parameter)
"""

def __init__(
self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
self,
parameters: TensorDictBase | dict | None = None,
*,
no_convert=False,
lock: bool = False,
**kwargs,
):
super().__init__()
if isinstance(parameters, TensorDictParams):
parameters = parameters._param_td
self._param_td = parameters
if parameters is None:
parameters = kwargs
elif kwargs:
raise TypeError(
f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
)

params = None
buffers = None
if isinstance(parameters, dict):
parameters = TensorDict(parameters)
elif isinstance(parameters, TensorDictParams):
params = dict(parameters._parameters)
buffers = dict(parameters._buffers)
parameters = parameters._param_td.copy().lock_()
no_convert = "skip"

self.no_convert = no_convert
if no_convert != "skip":
if not no_convert:
func = _maybe_make_param
else:
func = _maybe_make_param_or_buffer
self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
self._param_td = _apply_leaves(parameters, lambda x: func(x))
else:
self._param_td = parameters

self._lock_content = lock
if lock:
self._param_td.lock_()
self._reset_params()
self._reset_params(params=params, buffers=buffers)
self._is_locked = False
self._locked_tensordicts = []
self._get_post_hook = []
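
The constructor now also accepts a plain `dict` or keyword arguments in place of a `TensorDictBase`, but refuses to mix the positional input with extra keyword arguments. A hedged sketch of the three call forms:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

p1 = TensorDictParams(TensorDict({"a": torch.zeros(())}))  # tensordict input
p2 = TensorDictParams({"a": torch.zeros(())})  # a dict is wrapped in a TensorDict
p3 = TensorDictParams(a=torch.zeros(()))  # keyword form

# Passing both a positional input and extra keyword arguments raises a TypeError.
try:
    TensorDictParams(TensorDict({"a": torch.zeros(())}), b=torch.ones(3))
except TypeError:
    pass
```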

@classmethod
def _new_unsafe(
cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
cls,
parameters: TensorDictBase,
*,
no_convert=False,
lock: bool = False,
params: dict | None = None,
buffers: dict | None = None,
**kwargs,
):
return TensorDictParams(parameters, no_convert="skip", lock=lock)
if is_compiling():
return TensorDictParams(parameters, no_convert="skip", lock=lock)

self = TensorDictParams.__new__(cls)
nn.Module.__init__(self)

if parameters is None:
parameters = kwargs
elif kwargs:
raise TypeError(
f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
)

if isinstance(parameters, dict):
parameters = TensorDict._new_unsafe(parameters)
elif isinstance(parameters, TensorDictParams):
params = dict(parameters._parameters)
buffers = dict(parameters._buffers)
parameters = parameters._param_td
no_convert = "skip"

self._param_td = parameters
self.no_convert = no_convert
if no_convert != "skip":
raise RuntimeError("_new_unsafe requires no_convert to be set to 'skip'")
self._lock_content = lock
if lock:
self._param_td.lock_()
self._reset_params(params=params, buffers=buffers)
self._is_locked = False
self._locked_tensordicts = []
self._get_post_hook = []
return self
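
This fast path is the point of the refactor: when the parameter and buffer maps are already known, for instance because the source is another `TensorDictParams`, `_new_unsafe` reuses them instead of re-walking every leaf. A minimal sketch using the private classmethod above (internal API, subject to change):

```python
import torch
from tensordict.nn import TensorDictParams

p = TensorDictParams({"w": torch.randn(3)})
# Rebuilding from an existing TensorDictParams reuses its _parameters and
# _buffers dicts, so the leaves are shared rather than re-converted.
p2 = TensorDictParams._new_unsafe(p)
assert p2["w"] is p["w"]
```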

def __iter__(self):
yield from self._param_td.__iter__()
Expand All @@ -363,24 +442,35 @@ def _apply_get_post_hook(self, val):
val = new_val
return val

def _reset_params(self):
def _reset_params(self, params: dict | None = None, buffers: dict | None = None):
parameters = self._param_td
param_keys = []
params = []
buffer_keys = []
buffers = []
for key, value in parameters.items(True, True):
# flatten key
if isinstance(key, tuple):
key = ".".join(key)
if isinstance(value, nn.Parameter):
param_keys.append(key)
params.append(value)
else:
buffer_keys.append(key)
buffers.append(value)
self.__dict__["_parameters"] = dict(zip(param_keys, params))
self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers))

self._parameters.clear()
self._buffers.clear()

if (params is not None) ^ (buffers is not None):
raise RuntimeError("both params and buffers must either be None or not.")
elif params is None:
param_keys = []
params = []
buffer_keys = []
buffers = []
for key, value in parameters.items(True, True):
# flatten key
if isinstance(key, tuple):
key = ".".join(key)
if isinstance(value, nn.Parameter):
param_keys.append(key)
params.append(value)
else:
buffer_keys.append(key)
buffers.append(value)

self._parameters.update(dict(zip(param_keys, params)))
self._buffers.update(dict(zip(buffer_keys, buffers)))
else:
self._parameters.update(params)
self._buffers.update(buffers)

@classmethod
def __torch_function__(
@@ -618,7 +708,12 @@ def _clone(self, recurse: bool = True) -> TensorDictBase:
"""
if not recurse:
return TensorDictParams(self._param_td._clone(False), no_convert="skip")
return TensorDictParams._new_unsafe(
self._param_td._clone(False),
no_convert="skip",
params=dict(self._parameters),
buffers=dict(self._buffers),
)

memo = {}

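Seen from the caller's side, the same shortcut applies to shallow copies: a non-recursive clone shares its leaf tensors with the source, so the existing parameter and buffer maps stay valid and are handed over as-is. A condensed version of what the new test below checks via `apply`:

```python
import torch
from tensordict.nn import TensorDictParams

p = TensorDictParams({"w": torch.randn(3)})
p_copy = p.copy()  # non-recursive clone: new structure, shared leaves
assert p_copy["w"] is p["w"]
```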
@@ -897,6 +992,7 @@ def _propagate_unlock(self):

if not self._lock_content:
return self._param_td._propagate_unlock()
return []

unlock_ = TensorDict.unlock_
lock_ = TensorDict.lock_
45 changes: 45 additions & 0 deletions test/test_nn.py
@@ -1896,6 +1896,51 @@ def test_td_params(self):
assert (m.params == params).all()
assert (params == m.params).all()

def test_constructors(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.register_parameter(
"param", nn.Parameter(torch.randn(3, requires_grad=True))
)
self.register_buffer("buf", torch.randn(3))
self.register_buffer("buf_int", torch.randint(3, ()))

td = TensorDict.from_module(MyModule())
assert not isinstance(td, TensorDictParams)
td = TensorDictParams(td)
assert isinstance(td, TensorDictParams)
assert isinstance(td["param"], nn.Parameter)
assert isinstance(td["buf"], nn.Parameter)
assert isinstance(td["buf_int"], Buffer)
td = TensorDict.from_module(MyModule())
assert not isinstance(td, TensorDictParams)
td = TensorDictParams(td, no_convert=True)
assert isinstance(td, TensorDictParams)
assert isinstance(td["param"], nn.Parameter)
assert isinstance(td["buf"], Buffer)
assert isinstance(td["buf_int"], Buffer)

td = TensorDict.from_module(MyModule(), as_module=True)
assert isinstance(td, TensorDictParams)
assert isinstance(td["param"], nn.Parameter)
assert isinstance(td["buf"], Buffer)
assert isinstance(td["buf_int"], Buffer)

tdparams = TensorDictParams(a=0, b=1.0)
assert isinstance(tdparams["a"], Buffer)
assert isinstance(tdparams["b"], nn.Parameter)

tdparams = TensorDictParams({"a": 0, "b": 1.0})
assert isinstance(tdparams["a"], Buffer)
assert isinstance(tdparams["b"], nn.Parameter)
tdparams_copy = tdparams.copy()

def assert_is_identical(a, b):
assert a is b

tdparams.apply(assert_is_identical, tdparams_copy, filter_empty=True)

def test_td_params_cast(self):
params = self._get_params()
p = TensorDictParams(params)
