[Feature] Robust to lazy_legacy set to false and context managers for reshape ops #634

Merged · 13 commits · Jan 25, 2024
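
In practice, the change can be exercised as in the following minimal sketch, assuming the public `tensordict` API (`TensorDict`, `LazyStackedTensorDict.lazy_stack`, `set_lazy_legacy`) referenced throughout the diff below:

import torch
from tensordict import LazyStackedTensorDict, TensorDict, set_lazy_legacy

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
lazy = LazyStackedTensorDict.lazy_stack([td, td.clone()], dim=0)  # batch_size [2, 3, 4]

# These used to raise RuntimeError on a lazy stack; they now return lazy results.
print(lazy.permute(2, 1, 0).batch_size)  # torch.Size([4, 3, 2])
print(lazy.unsqueeze(0).batch_size)      # torch.Size([1, 2, 3, 4])

# The lazy_legacy switch can be scoped with a context manager rather than set globally:
with set_lazy_legacy(False):
    dense = torch.stack([td, td.clone()], dim=0)  # a dense TensorDict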
66 changes: 59 additions & 7 deletions tensordict/_lazy.py
@@ -38,6 +38,7 @@
from tensordict.utils import (
    _broadcast_tensors,
    _check_keys,
+    _get_shape_from_args,
    _getitem_batch_size,
    _is_number,
    _parse_to,
@@ -2314,19 +2315,70 @@ def _permute(
        *args,
        **kwargs,
    ):
-        raise RuntimeError(
-            "Cannot call `permute` on a lazy stacked tensordict. Make it dense before calling this method by calling `to_tensordict`."
-        )
+        dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs)
+        dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list]
+        dims_list_sort = np.argsort(dims_list)
+        # find the new stack dim
+        stack_dim = dims_list_sort[self.stack_dim]
+        # remove that dim from the dims_list
+        dims_list = [
+            d if d < self.stack_dim else d - 1 for d in dims_list if d != self.stack_dim
+        ]
+        result = LazyStackedTensorDict.lazy_stack(
+            [td.permute(dims_list) for td in self.tensordicts], stack_dim
+        )
+        result._td_dim_name = self._td_dim_name
+        return result

    def _squeeze(self, dim=None):
-        raise RuntimeError(
-            "Cannot call `squeeze` on a lazy stacked tensordict. Make it dense before calling this method by calling `to_tensordict`."
-        )
+        if dim is not None:
+            new_dim = dim
+            if new_dim < 0:
+                new_dim = self.batch_dims + new_dim
+            if new_dim > self.batch_dims - 1 or new_dim < 0:
+                raise RuntimeError(
+                    f"The dim provided to squeeze is incompatible with the tensordict shape: dim={dim} and batch_size={self.batch_size}."
+                )
+            dim = new_dim
+            if self.batch_size[dim] != 1:
+                return self
+            if dim == self.stack_dim:
+                return self.tensordicts[0]
+            if dim > self.stack_dim:
+                dim = dim - 1
+                stack_dim = self.stack_dim
+            else:
+                stack_dim = self.stack_dim - 1
+            result = LazyStackedTensorDict.lazy_stack(
+                [td.squeeze(dim) for td in self.tensordicts], stack_dim
+            )
+            result._td_dim_name = self._td_dim_name
+        else:
+            result = self
+            for dim in range(self.batch_dims - 1, -1, -1):
+                if self.batch_size[dim] == 1:
+                    result = result.squeeze(dim)
+        return result

    def _unsqueeze(self, dim):
-        raise RuntimeError(
-            "Cannot call `unsqueeze` on a lazy stacked tensordict. Make it dense before calling this method by calling `to_tensordict`."
-        )
+        new_dim = dim
+        if new_dim < 0:
+            new_dim = self.batch_dims + new_dim + 1
+        if new_dim > self.batch_dims or new_dim < 0:
+            raise RuntimeError(
+                f"The dim provided to unsqueeze is incompatible with the tensordict shape: dim={dim} and batch_size={self.batch_size}."
+            )
+        dim = new_dim
+        if dim > self.stack_dim:
+            dim = dim - 1
+            stack_dim = self.stack_dim
+        else:
+            stack_dim = self.stack_dim + 1
+        result = LazyStackedTensorDict.lazy_stack(
+            [td.unsqueeze(dim) for td in self.tensordicts], stack_dim
+        )
+        result._td_dim_name = self._td_dim_name
+        return result

    lock_ = TensorDictBase.lock_
    lock = _renamed_inplace_method(lock_)
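
The permute branch above keeps the stack dimension lazy by relocating it: after a permutation `dims_list`, the stack dim moves to `argsort(dims_list)[stack_dim]`, and that dim is dropped from the dims forwarded to each stacked tensordict. A self-contained check of that bookkeeping (plain Python/NumPy; `new_stack_dim` and `child_dims` are hypothetical names for illustration):

import numpy as np

def new_stack_dim(dims_list, stack_dim):
    # argsort(dims_list)[k] is the output position of input dim k,
    # so it gives the new location of the stack dim.
    return int(np.argsort(dims_list)[stack_dim])

def child_dims(dims_list, stack_dim):
    # drop the stack dim and renumber the remaining dims for the children
    return [d if d < stack_dim else d - 1 for d in dims_list if d != stack_dim]

# A stack along dim 0 with batch_size [2, 3, 4], permuted with (2, 1, 0):
# the stack dim lands at position 2 and each child permutes with (1, 0).
assert new_stack_dim([2, 1, 0], 0) == 2
assert child_dims([2, 1, 0], 0) == [1, 0]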
21 changes: 9 additions & 12 deletions tensordict/_td.py
@@ -1027,24 +1027,21 @@ def _permute(self, *args, **kwargs):
        if np.array_equal(dims_list, range(len(dims_list))):
            return self

-        # min_dim, max_dim = -self.batch_dims, self.batch_dims - 1
-        # seen = [False for dim in range(max_dim + 1)]
-        # for idx in dims_list:
-        #     if idx < min_dim or idx > max_dim:
-        #         raise IndexError(
-        #             f"dimension out of range (expected to be in range of [{min_dim}, {max_dim}], but got {idx})"
-        #         )
-        #     if seen[idx]:
-        #         raise RuntimeError("repeated dim in permute")
-        #     seen[idx] = True
        def _permute(tensor):
            return tensor.permute(*dims_list, *range(len(dims_list), tensor.ndim))

        batch_size = self.batch_size
        batch_size = [batch_size[p] for p in dims_list] + list(
            batch_size[len(dims_list) :]
        )
-        result = self._fast_apply(_permute, batch_size=batch_size, call_on_nested=True)
+        if self._has_names():
+            names = self.names
+            names = [names[i] for i in dims_list]
+        else:
+            names = None
+        result = self._fast_apply(
+            _permute, batch_size=batch_size, call_on_nested=True, names=names
+        )
        self._maybe_set_shared_attributes(result)
        return result

@@ -1119,7 +1116,7 @@ def _unsqueeze(self, dim):
        batch_size = torch.Size(batch_size)

        names = copy(self.names)
-        names.insert(dim, None)
+        names.insert(newdim, None)

        def _unsqueeze(tensor):
            return tensor.unsqueeze(newdim)
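
Taken together, these two hunks thread dimension names through the eager ops: `permute` now reorders `names` with the same `dims_list` it applies to the batch size, and `unsqueeze` inserts the placeholder name at the normalized `newdim` instead of the possibly negative user-supplied `dim`. A short sketch of the resulting behaviour, assuming the `TensorDict.names` API used above:

import torch
from tensordict import TensorDict

td = TensorDict({"x": torch.zeros(2, 3)}, batch_size=[2, 3])
td.names = ["batch", "time"]

# permute reorders the names together with the batch dims
assert td.permute(1, 0).names == ["time", "batch"]

# unsqueeze(-1) normalizes -1 to newdim=2 before inserting None;
# inserting at the raw -1 would have placed None one slot too early
assert td.unsqueeze(-1).names == ["batch", "time", None]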
27 changes: 23 additions & 4 deletions tensordict/_torch_func.py
@@ -6,13 +6,14 @@
from __future__ import annotations

import functools
-
+import warnings
from typing import Any, Callable, Sequence, TypeVar

import torch

from tensordict._lazy import LazyStackedTensorDict
from tensordict._td import TensorDict

from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase
from tensordict.persistent import PersistentTensorDict
from tensordict.utils import (
@@ -24,7 +25,6 @@
)
from torch import Tensor

-
TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
T = TypeVar("T", bound="TensorDictBase")
@@ -367,6 +367,25 @@ def _stack(
    )

    # check that all tensordict match
+    # Read lazy_legacy
+    _lazy_legacy = lazy_legacy(allow_none=True)
+    if _lazy_legacy is None:
+        warnings.warn(
+            """You did not specify whether torch.stack should return a dense or a lazy
+stack of tensordicts. Up until v0.3 included, a lazy stack was returned.
+From v0.4 onward, a dense stack will be returned and an explicit call to
+LazyStackedTensorDict.lazy_stack will be required to build a lazy stack.
+To silence this warning, choose one of the following options:
+- set LAZY_LEGACY_OP to 'True' or 'False' depending on
+  the behaviour you want to use. Another way to achieve this is to call
+  `tensordict.set_lazy_legacy(True).set()` at the beginning of your script.
+- use the decorator/context manager `tensordict.set_lazy_legacy(True)` (recommended) around
+  the function or code block where stack is used.
+- use `LazyStackedTensorDict.lazy_stack()` if a lazy stack is what you want.""",
+            category=DeprecationWarning,
+        )
+        # get default
+        _lazy_legacy = lazy_legacy()

    if out is None:
        # We need to handle tensordicts with exclusive keys and tensordicts with
@@ -375,11 +394,11 @@
        # don't match exactly.
        # The second requires a check over the tensor shapes.
        device = list_of_tensordicts[0].device
-        if contiguous or not lazy_legacy():
+        if contiguous or not _lazy_legacy:
            try:
                keys = _check_keys(list_of_tensordicts, strict=True)
            except KeyError:
-                if not lazy_legacy() and not contiguous:
+                if not _lazy_legacy and not contiguous:
                    with set_lazy_legacy(True):
                        return _stack(list_of_tensordicts, dim=dim)
                raise
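
The three options listed in the warning map onto code roughly as follows (a sketch, assuming the `set_lazy_legacy` API the message refers to):

import torch
from tensordict import LazyStackedTensorDict, TensorDict, set_lazy_legacy

tds = [TensorDict({"a": torch.zeros(3)}, batch_size=[3]) for _ in range(2)]

# 1. Pick a global default once, at the start of the script.
set_lazy_legacy(False).set()

# 2. Scope the behaviour with the context manager (also usable as a decorator).
with set_lazy_legacy(True):
    lazy = torch.stack(tds, dim=0)   # LazyStackedTensorDict
dense = torch.stack(tds, dim=0)      # dense TensorDict under the global default

# 3. Ask for a lazy stack explicitly, regardless of the global setting.
lazy = LazyStackedTensorDict.lazy_stack(tds, dim=0)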