[Refactor] Avoid TDParams parameters and buffers construction when obvious

ghstack-source-id: 6c833eb5b6144174e733bc7eedae435a6e9fce18
Pull Request resolved: #1100
vmoens committed Nov 21, 2024
1 parent 55f6b91 commit 0836b6b
Showing 4 changed files with 136 additions and 46 deletions.
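
Illustrative sketch of the optimization this commit introduces, not part of the diff itself: when the parameter/buffer split is already known (for instance when cloning an existing TensorDictParams), the pre-computed dicts are reused instead of re-classifying every leaf. The reset_params helper and registry dict below are hypothetical stand-ins; only the params=/buffers= keyword pattern mirrors the _reset_params and _new_unsafe signatures in the params.py hunks below.

import torch
from torch import nn

def reset_params(leaves, registry, params=None, buffers=None):
    # `registry` stands in for an nn.Module's internal `_parameters`/`_buffers` dicts.
    registry["_parameters"].clear()
    registry["_buffers"].clear()
    if (params is None) ^ (buffers is None):
        raise RuntimeError("both params and buffers must either be None or not.")
    if params is None:
        # Slow path: walk every leaf and sort it into parameters vs buffers.
        for key, value in leaves.items():
            target = "_parameters" if isinstance(value, nn.Parameter) else "_buffers"
            registry[target][key] = value
    else:
        # Fast path: the caller already knows the split, so skip the scan.
        registry["_parameters"].update(params)
        registry["_buffers"].update(buffers)

registry = {"_parameters": {}, "_buffers": {}}
leaves = {"weight": nn.Parameter(torch.randn(3)), "scale": torch.ones(3)}
reset_params(leaves, registry)  # slow path: classifies each leaf
reset_params(
    leaves,
    registry,
    params={"weight": leaves["weight"]},
    buffers={"scale": leaves["scale"]},
)  # fast path: reuses the known split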
6 changes: 3 additions & 3 deletions tensordict/_contextlib.py
@@ -20,9 +20,9 @@
 import numpy as np
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 
 # Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@@ -330,7 +330,7 @@ def _reverse_squeeze(self, args, kwargs, out):
 
     def _reverse_to_module(self, args, kwargs, out):
         try:
-            with out.unlock_() if not is_dynamo_compiling() else contextlib.nullcontext():
+            with out.unlock_() if not is_compiling() else contextlib.nullcontext():
                 return self.to_module(*args, **kwargs, swap_dest=out)
         except AttributeError:
             # This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD
2 changes: 2 additions & 0 deletions tensordict/base.py
@@ -3070,6 +3070,8 @@ def _legacy_permute(
 
     # Cache functionality
    def _erase_cache(self):
+        if is_dynamo_compiling():
+            return
         self._cache = None
 
     # Dim names functionality
144 changes: 116 additions & 28 deletions tensordict/nn/params.py
@@ -58,6 +58,12 @@
 from tensordict.utils import Buffer
 
 
+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
+
 def _apply_leaves(data, fn):
     if isinstance(data, TensorDict):
         with data.unlock_():
@@ -114,7 +120,7 @@ def _maybe_make_param_or_buffer(tensor):
     if (
         isinstance(tensor, (Tensor, ftdim.Tensor))
         and not isinstance(tensor, (nn.Parameter, Buffer))
-        and tensor.dtype in (torch.float, torch.double, torch.half)
+        # and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
         if not tensor.requires_grad and not is_batchedtensor(tensor):
             # convert all non-parameters to buffers
@@ -267,8 +273,10 @@ class TensorDictParams(TensorDictBase, nn.Module):
     a :class:`torch.nn.Parameter` conversion.
 
     Args:
-        parameters (TensorDictBase): a tensordict to represent as parameters.
+        parameters (TensorDictBase or dict): a tensordict to represent as parameters.
             Values will be converted to parameters unless ``no_convert=True``.
+            If a dict is provided, it will be first wrapped to a :class:`~tensordict.TensorDict`
+            instance. Keyword arguments can be used instead.
 
     Keyword Args:
         no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will
@@ -281,6 +289,8 @@ class TensorDictParams(TensorDictBase, nn.Module):
             also restricts the operations that can be done over the object (and
             can have significant performance impact when `unlock_()` is required).
             Defaults to ``False``.
+        **kwargs: Key-value pairs to populate the ``TensorDictParams``. Exclusive with
+            the :attr:`parameters` input.
 
     Examples:
         >>> from torch import nn
@@ -321,32 +331,93 @@ class TensorDictParams(TensorDictBase, nn.Module):
     """
 
     def __init__(
-        self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+        self,
+        parameters: TensorDictBase | dict | None = None,
+        *,
+        no_convert=False,
+        lock: bool = False,
+        **kwargs,
     ):
         super().__init__()
-        if isinstance(parameters, TensorDictParams):
-            parameters = parameters._param_td
-        self._param_td = parameters
+        if parameters is None:
+            parameters = kwargs
+        elif kwargs:
+            raise TypeError(
+                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+            )
+
+        params = None
+        buffers = None
+        if isinstance(parameters, dict):
+            parameters = TensorDict(parameters)
+        elif isinstance(parameters, TensorDictParams):
+            params = dict(parameters._parameters)
+            buffers = dict(parameters._buffers)
+            parameters = parameters._param_td.copy().lock_()
+            no_convert = "skip"
+
         self.no_convert = no_convert
         if no_convert != "skip":
             if not no_convert:
                 func = _maybe_make_param
             else:
                 func = _maybe_make_param_or_buffer
-            self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
+            self._param_td = _apply_leaves(parameters, lambda x: func(x))
+        else:
+            self._param_td = parameters
+
         self._lock_content = lock
         if lock:
             self._param_td.lock_()
-        self._reset_params()
+        self._reset_params(params=params, buffers=buffers)
         self._is_locked = False
         self._locked_tensordicts = []
         self._get_post_hook = []
 
     @classmethod
     def _new_unsafe(
-        cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+        cls,
+        parameters: TensorDictBase,
+        *,
+        no_convert=False,
+        lock: bool = False,
+        params: dict | None = None,
+        buffers: dict | None = None,
+        **kwargs,
     ):
-        return TensorDictParams(parameters, no_convert="skip", lock=lock)
+        if is_compiling():
+            return TensorDictParams(parameters, no_convert="skip", lock=lock)
+
+        self = TensorDictParams.__new__(cls)
+        nn.Module.__init__(self)
+
+        if parameters is None:
+            parameters = kwargs
+        elif kwargs:
+            raise TypeError(
+                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+            )
+
+        if isinstance(parameters, dict):
+            parameters = TensorDict._new_unsafe(parameters)
+        elif isinstance(parameters, TensorDictParams):
+            params = dict(parameters._parameters)
+            buffers = dict(parameters._buffers)
+            parameters = parameters._param_td
+            no_convert = "skip"
+
+        self._param_td = parameters
+        self.no_convert = no_convert
+        if no_convert != "skip":
+            raise RuntimeError("_new_unsafe requires no_convert to be set to 'skip'")
+        self._lock_content = lock
+        if lock:
+            self._param_td.lock_()
+        self._reset_params(params=params, buffers=buffers)
+        self._is_locked = False
+        self._locked_tensordicts = []
+        self._get_post_hook = []
+        return self
 
     def __iter__(self):
         yield from self._param_td.__iter__()
@@ -365,24 +436,35 @@ def _apply_get_post_hook(self, val):
                     val = new_val
         return val
 
-    def _reset_params(self):
+    def _reset_params(self, params: dict | None = None, buffers: dict | None = None):
         parameters = self._param_td
-        param_keys = []
-        params = []
-        buffer_keys = []
-        buffers = []
-        for key, value in parameters.items(True, True):
-            # flatten key
-            if isinstance(key, tuple):
-                key = ".".join(key)
-            if isinstance(value, nn.Parameter):
-                param_keys.append(key)
-                params.append(value)
-            else:
-                buffer_keys.append(key)
-                buffers.append(value)
-        self.__dict__["_parameters"] = dict(zip(param_keys, params))
-        self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers))
+
+        self._parameters.clear()
+        self._buffers.clear()
+
+        if (params is not None) ^ (buffers is not None):
+            raise RuntimeError("both params and buffers must either be None or not.")
+        elif params is None:
+            param_keys = []
+            params = []
+            buffer_keys = []
+            buffers = []
+            for key, value in parameters.items(True, True):
+                # flatten key
+                if isinstance(key, tuple):
+                    key = ".".join(key)
+                if isinstance(value, nn.Parameter):
+                    param_keys.append(key)
+                    params.append(value)
+                else:
+                    buffer_keys.append(key)
+                    buffers.append(value)
+
+            self._parameters.update(dict(zip(param_keys, params)))
+            self._buffers.update(dict(zip(buffer_keys, buffers)))
+        else:
+            self._parameters.update(params)
+            self._buffers.update(buffers)
 
     @classmethod
     def __torch_function__(
@@ -620,7 +702,12 @@ def _clone(self, recurse: bool = True) -> TensorDictBase:
         """
         if not recurse:
-            return TensorDictParams(self._param_td._clone(False), no_convert="skip")
+            return TensorDictParams._new_unsafe(
+                self._param_td._clone(False),
+                no_convert="skip",
+                params=dict(self._parameters),
+                buffers=dict(self._buffers),
+            )
 
         memo = {}
 
@@ -899,6 +986,7 @@ def _propagate_unlock(self):
 
         if not self._lock_content:
             return self._param_td._propagate_unlock()
+        return []
 
     unlock_ = TensorDict.unlock_
     lock_ = TensorDict.lock_
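
Hypothetical usage sketch based on the updated docstring above (not part of the diff): the constructor now also accepts a plain dict, which is wrapped into a TensorDict, or plain keyword arguments, which are exclusive with the parameters argument. Key names and shapes are illustrative only.

import torch
from tensordict.nn import TensorDictParams

# From a dict: leaves are converted to nn.Parameter / buffers as before.
p1 = TensorDictParams({"weight": torch.randn(3, 4), "bias": torch.zeros(3)})

# Equivalent construction from keyword arguments.
p2 = TensorDictParams(weight=torch.randn(3, 4), bias=torch.zeros(3))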
30 changes: 15 additions & 15 deletions tensordict/utils.py
@@ -67,11 +67,11 @@
 except ImportError:
     _has_funcdim = False
 try:
-    from torch.compiler import assume_constant_result, is_dynamo_compiling
+    from torch.compiler import assume_constant_result, is_compiling
 except ImportError:  # torch 2.0
     from torch._dynamo import (
         assume_constant_result,
-        is_compiling as is_dynamo_compiling,
+        is_compiling,
     )
 
 if TYPE_CHECKING:
@@ -863,7 +863,7 @@ def _is_tensorclass(cls: type) -> bool:
     out = _TENSORCLASS_MEMO.get(cls, None)
     if out is None:
         out = getattr(cls, "_is_tensorclass", False)
-        if not is_dynamo_compiling():
+        if not is_compiling():
             _TENSORCLASS_MEMO[cls] = out
     return out
 
@@ -1119,7 +1119,7 @@ def cache(fun):
 
     @wraps(fun)
     def newfun(_self: "TensorDictBase", *args, **kwargs):
-        if not _self.is_locked or is_dynamo_compiling():
+        if not _self.is_locked or is_compiling():
             return fun(_self, *args, **kwargs)
         cache = _self._cache
         if cache is None:
@@ -1359,7 +1359,7 @@ def _parse_to(*args, **kwargs):
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
     inplace = kwargs.pop("inplace", False)
-    if not is_dynamo_compiling():
+    if not is_compiling():
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
             *args, **kwargs
         )
@@ -1733,7 +1733,7 @@ def _check_keys(
         is_leaf=_is_leaf_nontensor,
     )
     # TODO: compile doesn't like set() over an arbitrary object
-    if is_dynamo_compiling():
+    if is_compiling():
         keys = {k for k in keys}  # noqa: C416
     else:
         keys: set[str] = set(keys)
@@ -1746,7 +1746,7 @@
         if not strict:
             keys = keys.intersection(k)
         else:
-            if is_dynamo_compiling():
+            if is_compiling():
                 k = {v for v in k}  # noqa: C416
             else:
                 k = set(k)
@@ -2015,7 +2015,7 @@ def _getitem_batch_size(batch_size, index):
             continue
         elif isinstance(idx, slice):
             batch = batch_size[count]
-            if is_dynamo_compiling():
+            if is_compiling():
                 out.append(len(range(*_slice_indices(idx, batch))))
             else:
                 out.append(len(range(*idx.indices(batch))))
@@ -2447,7 +2447,7 @@ def is_non_tensor(data):
 
 def _is_non_tensor(cls: type):
     out = None
-    is_dynamo = is_dynamo_compiling()
+    is_dynamo = is_compiling()
     if not is_dynamo:
         out = _NON_TENSOR_MEMO.get(cls)
     if out is None:
@@ -2503,7 +2503,7 @@ def new_func(self):
 
 
 def _unravel_key_to_tuple(key):
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return _unravel_key_to_tuple_cpp(key)
     if isinstance(key, str):
         return (key,)
@@ -2524,7 +2524,7 @@ def unravel_key(key):
         ("a", "b")
 
     """
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_cpp(key)
     if isinstance(key, str):
         return key
@@ -2537,14 +2537,14 @@ def unravel_key(key):
 
 
 def unravel_keys(*keys):
     """Unravels a sequence of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_keys_cpp(*keys)
     return tuple(unravel_key(key) for key in keys)
 
 
 def unravel_key_list(keys):
     """Unravels a list of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_list_cpp(keys)
     return [unravel_key(key) for key in keys]
@@ -2823,11 +2823,11 @@ def __init__(self, default=None):
         self._lock = threading.Lock()
 
     def get_mode(self) -> Any | None:
-        cm = self._lock if not is_dynamo_compiling() else nullcontext()
+        cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             return self._mode
 
     def set_mode(self, type: Any | None) -> None:
-        cm = self._lock if not is_dynamo_compiling() else nullcontext()
+        cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             self._mode = type
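
For reference, the version-compatibility shim this commit converges on across the touched files (the fallback targets torch 2.0, per the comments in the hunks above); the call site below is a generic illustration, not library code:

try:
    from torch.compiler import is_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling

if not is_compiling():
    # Eager-only bookkeeping (locking, caching, C++ key helpers) goes here;
    # it is skipped while Dynamo is tracing.
    pass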
