[Refactor] Improve functional call efficiency (#567)
vmoens authored Nov 23, 2023
1 parent 2ea264b commit 57fc236
Showing 3 changed files with 70 additions and 46 deletions.
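For context, the refactor targets the functional-call path, where a TensorDict of parameters is swapped into a module and restored afterwards. A minimal sketch of that pattern, assuming the library's public `from_module`/`to_module` API (module, shapes and values below are invented for illustration):

import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 4)
params = TensorDict.from_module(module)      # parameters gathered in a TensorDict
new_params = params.clone().zero_()          # e.g. a perturbed copy of the weights

# to_module swaps the new parameters in and returns the old ones; used as a
# context manager (see the __exit__ change in tensordict/base.py below), the
# original parameters are restored on exit.
with new_params.to_module(module):
    out = module(torch.randn(2, 3))          # functional call with new_params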
80 changes: 48 additions & 32 deletions tensordict/_lazy.py
@@ -19,7 +19,7 @@
import torch
from functorch import dim as ftdim
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key
from tensordict.base import (
_ACCEPTED_CLASSES,
_is_tensor_collection,
@@ -133,6 +133,8 @@ class LazyStackedTensorDict(TensorDictBase):
"""

_is_vmapped: bool = False

@classmethod
def __torch_function__(
cls,
@@ -362,7 +364,7 @@ def _set_str(
if not validated:
value = self._validate_value(value)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
values = value.unbind(self.stack_dim)
for tensordict, item in zip(self.tensordicts, values):
@@ -397,7 +399,7 @@ def _set_tuple(
if not validated:
value = self._validate_value(value)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
values = value.unbind(self.stack_dim)
for tensordict, item in zip(self.tensordicts, values):
@@ -554,7 +556,7 @@ def _set_at_str(self, key, value, index, *, validated):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
split_index = self._split_index(index)
converted_idx = split_index["index_dict"]
@@ -649,7 +651,7 @@ def _set_at_tuple(self, key, value, idx, *, validated):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
item = td._get_str(key, NO_DEFAULT)
item[idx] = value
@@ -778,10 +780,22 @@ def _get_str(
# then it's a LazyStackedTD
out.hook_out = self.hook_out
out.hook_in = self.hook_in
out._is_vmapped = self._is_vmapped
incr = 0 if not self._is_vmapped else 1
out._batch_size = (
self._batch_size
+ out.batch_size[(len(self._batch_size) + incr) :]
)
else:
# then it's a tensorclass
out._tensordict.hook_out = self.hook_out
out._tensordict.hook_in = self.hook_in
out._tensordict._is_vmapped = self._is_vmapped
incr = 0 if not self._is_vmapped else 1
out._tensordict._batch_size = (
self._batch_size
+ out._tensordict.batch_size[(len(self._batch_size) + incr) :]
)
elif self.hook_out is not None:
out = self.hook_out(out)
return out
@@ -802,7 +816,7 @@ def _get_str(
def _get_tuple(self, key, default):
first = self._get_str(key[0], None)
if first is None:
return self._default_get(first, default)
return self._default_get(key[0], default)
if len(key) == 1:
return first
try:
@@ -850,7 +864,7 @@ def _cached_add_batch_dims(cls, td, in_dim, vmap_level):
# we return a stack with hook_out, and hack the batch_size and names
# Per se it is still a LazyStack but the stacking dim is "hidden" from
# the outside
out = td.clone(False)
out = td.copy()

def hook_out(tensor, in_dim=in_dim, vmap_level=vmap_level):
return _add_batch_dim(tensor, in_dim, vmap_level)
@@ -869,6 +883,7 @@ def hook_in(

out.hook_out = hook_out
out.hook_in = hook_in
out._is_vmapped = True
out._batch_size = torch.Size(
[dim for i, dim in enumerate(out._batch_size) if i != out.stack_dim]
)
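The hooks installed above, together with the new _is_vmapped flag (cheaper to branch on than `hook_in is not None` and propagated through nested gets), are what hide the stack dimension from torch.vmap. A rough sketch of the ensembling use-case this serves (modules and shapes are invented; details may differ across versions):

import torch
from torch import nn, vmap
from tensordict import TensorDict

# two independent parameter sets stacked into a LazyStackedTensorDict
params = torch.stack(
    [TensorDict.from_module(nn.Linear(3, 4)) for _ in range(2)], dim=0
)
module = nn.Linear(3, 4)   # "template" module the parameters are swapped into

def call(p, x):
    # inside vmap, p is the stacked params with the vmap dim hidden by the hooks
    with p.to_module(module):
        return module(x)

x = torch.randn(5, 3)
out = vmap(call, (0, None))(params, x)   # shape [2, 5, 4]: one slice per parameter set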
@@ -1570,7 +1585,7 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
isinstance(input_dict_or_td, LazyStackedTensorDict)
and input_dict_or_td.stack_dim == self.stack_dim
):
if not input_dict_or_td.shape[self.stack_dim] == len(self.tensordicts):
if len(input_dict_or_td.tensordicts) != len(self.tensordicts):
raise ValueError(
"cannot update stacked tensordicts with different shapes."
)
@@ -1580,36 +1595,37 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
td_dest.update(td_source, clone=clone, **kwargs)
return self

keys = self.keys(False)
inplace = kwargs.get("inplace", False)
for key, value in input_dict_or_td.items():
if clone and hasattr(value, "clone"):
value = value.clone()
else:
elif clone:
value = tree_map(torch.clone, value)
key = unravel_key(key)
if isinstance(key, tuple):
key, subkey = key[0], key[1:]
else:
subkey = ()
# the key must be a string by now. Let's check if it is present
if key in keys:
target_class = self.entry_class(key)
if _is_tensor_collection(target_class):
if isinstance(value, dict):
value_unbind = TensorDict(
value, self.batch_size, _run_checks=False
).unbind(self.stack_dim)
else:
value_unbind = value.unbind(self.stack_dim)
for t, _value in zip(self.tensordicts, value_unbind):
if len(subkey):
t.update({key: {subkey: _value}}, clone=clone, **kwargs)
else:
t.update({key: _value}, clone=clone, **kwargs)
continue
if len(subkey):
self.set((key, *subkey), value, **kwargs)
# we must check that the target is not a leaf
target = self._get_str(key[0], default=None)
if is_tensor_collection(target):
target.update({key[1:]: value}, inplace=inplace, clone=clone)
elif target is None:
self._set_tuple(key, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection."
)
else:
self.set(key, value, **kwargs)
target = self._get_str(key, default=None)
if is_tensor_collection(target) and (
is_tensor_collection(value) or isinstance(value, dict)
):
target.update(value, inplace=inplace, clone=clone)
elif target is None or not is_tensor_collection(value):
self._set_str(key, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}."
)

return self

def update_(
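The rewritten update above resolves keys with unravel_key and writes through to existing sub-tensordicts instead of going through entry_class/set. A small sketch of the intended semantics (keys, shapes and values are invented):

import torch
from tensordict import TensorDict

td1 = TensorDict({"a": {"b": torch.zeros(3)}}, batch_size=[])
td2 = TensorDict({"a": {"b": torch.ones(3)}}, batch_size=[])
stacked = torch.stack([td1, td2], dim=0)      # LazyStackedTensorDict, batch_size [2]

# a tuple key is unravelled and routed to the existing nested entry
stacked.update({("a", "b"): torch.full((2, 3), 2.0)})

# a dict value merges into the existing sub-tensordict rather than replacing it
stacked.update({"a": {"c": torch.arange(2.0)}})

assert (stacked["a", "b"] == 2).all()
assert ("a", "c") in stacked.keys(True)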
29 changes: 17 additions & 12 deletions tensordict/_td.py
@@ -271,11 +271,11 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
# For batch-size it is a minor issue (unlikely that a td with batch-size
# is passed with to_module) but for the device it could be a problem.
if swap_dest is None:
swap = self.empty()
swap.clear_device_()
swap = TensorDict({}, batch_size=[])
else:
swap = swap_dest
memo[id(module)] = swap
_swap = {}

for key, value in self.items():
if isinstance(value, (Tensor, ftdim.Tensor)):
@@ -320,8 +320,13 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
swap = swap.to(device=local_out.device)

if return_swap:
assert local_out is not None, key
swap._set_str(key, local_out, inplace=False, validated=True)
_swap[key] = local_out
if return_swap:
if isinstance(swap, TensorDict):
# this is very ad-hoc but faster than calling _set_str every time
swap._tensordict.update(_swap)
else:
swap.update(_swap)
return swap

def __ne__(self, other: object) -> T | bool:
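The to_module change above gathers the swapped-out parameters in a plain Python dict and commits them with a single update on the destination, rather than issuing one validated _set_str call per key. A generic sketch of that batching idea (names are illustrative, not the library's internals):

import torch
from tensordict import TensorDict

swap = TensorDict({}, batch_size=[])
collected = {}                            # plain dict: no per-key validation or lock checks
for name in ("weight", "bias"):           # stand-ins for the module's parameter names
    collected[name] = torch.randn(4, 3)   # the old parameter being swapped out

swap.update(collected)                    # one bulk write instead of many per-key sets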
@@ -1242,12 +1247,13 @@ def _set_str(
inplace: bool,
validated: bool,
) -> T:
best_attempt = inplace is BEST_ATTEMPT_INPLACE
inplace = self._convert_inplace(inplace, key)
if inplace is not False:
best_attempt = inplace is BEST_ATTEMPT_INPLACE
inplace = self._convert_inplace(inplace, key)
if not validated:
value = self._validate_value(value, check_shape=True)
if not inplace:
if self.is_locked:
if self._is_locked:
raise RuntimeError(_LOCK_ERROR)
self._tensordict[key] = value
else:
@@ -1703,14 +1709,13 @@ def contiguous(self) -> T:
def empty(self, recurse=False) -> T:
if not recurse:
return TensorDict(
device=self.device,
batch_size=self.batch_size,
device=self._device,
batch_size=self._batch_size,
source={},
# names=self.names if self._has_names() else None,
names=self._td_dim_names,
_run_checks=False,
_is_memmap=self._is_memmap,
_is_shared=self._is_shared,
_is_memmap=False,
_is_shared=False,
)
return super().empty(recurse=recurse)

7 changes: 5 additions & 2 deletions tensordict/base.py
@@ -3125,8 +3125,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.lock_()
if last_op == self.__class__.to_module.__name__:
if is_tensor_collection(out):
with out.unlock_():
return self.to_module(*args, **kwargs, swap_dest=out)
return self.to_module(*args, **kwargs, swap_dest=out)
else:
raise RuntimeError(
"to_module cannot be used as a decorator when return_swap=False."
@@ -3520,6 +3519,10 @@ def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
result._set_str(
leaf_flat, self.get(leaf), validated=True, inplace=False
)
shared = result._is_shared = self._is_shared
mmap = result._is_memmap = self._is_memmap
if shared or mmap:
result._is_locked = True
return result

@cache # noqa: B019
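The flatten_keys change above carries the shared/memmap flags over to the flattened result and locks it when either flag is set, so the result matches the source's locked state. Roughly (assuming is_shared/is_locked behave as in the public API):

import torch
from tensordict import TensorDict

td = TensorDict({"a": {"b": torch.zeros(3)}}, batch_size=[]).share_memory_()
flat = td.flatten_keys(".")
assert flat.is_shared()   # sharing status is propagated
assert flat.is_locked     # and the flattened result is locked, like its source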
