diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 488681dcd..d060beda2 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -19,7 +19,7 @@
 import torch
 from functorch import dim as ftdim
 from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
-from tensordict._tensordict import _unravel_key_to_tuple
+from tensordict._tensordict import _unravel_key_to_tuple, unravel_key
 from tensordict.base import (
     _ACCEPTED_CLASSES,
     _is_tensor_collection,
@@ -133,6 +133,8 @@ class LazyStackedTensorDict(TensorDictBase):

     """

+    _is_vmapped: bool = False
+
     @classmethod
     def __torch_function__(
         cls,
@@ -362,7 +364,7 @@ def _set_str(
         if not validated:
             value = self._validate_value(value)
             validated = True
-        if self.hook_in is not None:
+        if self._is_vmapped:
             value = self.hook_in(value)
         values = value.unbind(self.stack_dim)
         for tensordict, item in zip(self.tensordicts, values):
@@ -397,7 +399,7 @@ def _set_tuple(
         if not validated:
             value = self._validate_value(value)
             validated = True
-        if self.hook_in is not None:
+        if self._is_vmapped:
             value = self.hook_in(value)
         values = value.unbind(self.stack_dim)
         for tensordict, item in zip(self.tensordicts, values):
@@ -554,7 +556,7 @@ def _set_at_str(self, key, value, index, *, validated):
         if not validated:
             value = self._validate_value(value, check_shape=False)
             validated = True
-        if self.hook_in is not None:
+        if self._is_vmapped:
             value = self.hook_in(value)
         split_index = self._split_index(index)
         converted_idx = split_index["index_dict"]
@@ -649,7 +651,7 @@ def _set_at_tuple(self, key, value, idx, *, validated):
         if not validated:
             value = self._validate_value(value, check_shape=False)
             validated = True
-        if self.hook_in is not None:
+        if self._is_vmapped:
             value = self.hook_in(value)
         item = td._get_str(key, NO_DEFAULT)
         item[idx] = value
@@ -778,10 +780,22 @@ def _get_str(
                     # then it's a LazyStackedTD
                     out.hook_out = self.hook_out
                     out.hook_in = self.hook_in
+                    out._is_vmapped = self._is_vmapped
+                    incr = 0 if not self._is_vmapped else 1
+                    out._batch_size = (
+                        self._batch_size
+                        + out.batch_size[(len(self._batch_size) + incr) :]
+                    )
                 else:
                     # then it's a tensorclass
                     out._tensordict.hook_out = self.hook_out
                     out._tensordict.hook_in = self.hook_in
+                    out._tensordict._is_vmapped = self._is_vmapped
+                    incr = 0 if not self._is_vmapped else 1
+                    out._tensordict._batch_size = (
+                        self._batch_size
+                        + out._tensordict.batch_size[(len(self._batch_size) + incr) :]
+                    )
             elif self.hook_out is not None:
                 out = self.hook_out(out)
             return out
@@ -802,7 +816,7 @@ def _get_str(
     def _get_tuple(self, key, default):
         first = self._get_str(key[0], None)
         if first is None:
-            return self._default_get(first, default)
+            return self._default_get(key[0], default)
         if len(key) == 1:
             return first
         try:
@@ -850,7 +864,7 @@ def _cached_add_batch_dims(cls, td, in_dim, vmap_level):
         # we return a stack with hook_out, and hack the batch_size and names
         # Per se it is still a LazyStack but the stacking dim is "hidden" from
         # the outside
-        out = td.clone(False)
+        out = td.copy()

         def hook_out(tensor, in_dim=in_dim, vmap_level=vmap_level):
             return _add_batch_dim(tensor, in_dim, vmap_level)
@@ -869,6 +883,7 @@ def hook_in(

         out.hook_out = hook_out
         out.hook_in = hook_in
+        out._is_vmapped = True
         out._batch_size = torch.Size(
             [dim for i, dim in enumerate(out._batch_size) if i != out.stack_dim]
         )
@@ -1570,7 +1585,7 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
             isinstance(input_dict_or_td, LazyStackedTensorDict)
             and input_dict_or_td.stack_dim == self.stack_dim
         ):
-            if not input_dict_or_td.shape[self.stack_dim] == len(self.tensordicts):
+            if len(input_dict_or_td.tensordicts) != len(self.tensordicts):
                 raise ValueError(
                     "cannot update stacked tensordicts with different shapes."
                 )
@@ -1580,36 +1595,37 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
                 td_dest.update(td_source, clone=clone, **kwargs)
             return self

-        keys = self.keys(False)
+        inplace = kwargs.get("inplace", False)
         for key, value in input_dict_or_td.items():
             if clone and hasattr(value, "clone"):
                 value = value.clone()
-            else:
+            elif clone:
                 value = tree_map(torch.clone, value)
+            key = unravel_key(key)
             if isinstance(key, tuple):
-                key, subkey = key[0], key[1:]
-            else:
-                subkey = ()
-            # the key must be a string by now. Let's check if it is present
-            if key in keys:
-                target_class = self.entry_class(key)
-                if _is_tensor_collection(target_class):
-                    if isinstance(value, dict):
-                        value_unbind = TensorDict(
-                            value, self.batch_size, _run_checks=False
-                        ).unbind(self.stack_dim)
-                    else:
-                        value_unbind = value.unbind(self.stack_dim)
-                    for t, _value in zip(self.tensordicts, value_unbind):
-                        if len(subkey):
-                            t.update({key: {subkey: _value}}, clone=clone, **kwargs)
-                        else:
-                            t.update({key: _value}, clone=clone, **kwargs)
-                    continue
-            if len(subkey):
-                self.set((key, *subkey), value, **kwargs)
+                # we must check that the target is not a leaf
+                target = self._get_str(key[0], default=None)
+                if is_tensor_collection(target):
+                    target.update({key[1:]: value}, inplace=inplace, clone=clone)
+                elif target is None:
+                    self._set_tuple(key, value, inplace=inplace, validated=False)
+                else:
+                    raise TypeError(
+                        f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection."
+                    )
             else:
-                self.set(key, value, **kwargs)
+                target = self._get_str(key, default=None)
+                if is_tensor_collection(target) and (
+                    is_tensor_collection(value) or isinstance(value, dict)
+                ):
+                    target.update(value, inplace=inplace, clone=clone)
+                elif target is None or not is_tensor_collection(value):
+                    self._set_str(key, value, inplace=inplace, validated=False)
+                else:
+                    raise TypeError(
+                        f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}."
+                    )
+
         return self

     def update_(
diff --git a/tensordict/_td.py b/tensordict/_td.py
index 2e01b91cd..5c34d8fbd 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -271,11 +271,11 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
             # For batch-size it is a minor issue (unlikely that a td with batch-size
             # is passed with to_module) but for the device it could be a problem.
             if swap_dest is None:
-                swap = self.empty()
-                swap.clear_device_()
+                swap = TensorDict({}, batch_size=[])
             else:
                 swap = swap_dest
             memo[id(module)] = swap
+            _swap = {}
         for key, value in self.items():
             if isinstance(value, (Tensor, ftdim.Tensor)):
@@ -320,8 +320,13 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
                     swap = swap.to(device=local_out.device)
             if return_swap:
-                assert local_out is not None, key
-                swap._set_str(key, local_out, inplace=False, validated=True)
+                _swap[key] = local_out
+        if return_swap:
+            if isinstance(swap, TensorDict):
+                # this is very ad-hoc but faster than calling _set_str every time
+                swap._tensordict.update(_swap)
+            else:
+                swap.update(_swap)
         return swap

     def __ne__(self, other: object) -> T | bool:
@@ -1242,12 +1247,13 @@ def _set_str(
         inplace: bool,
         validated: bool,
     ) -> T:
-        best_attempt = inplace is BEST_ATTEMPT_INPLACE
-        inplace = self._convert_inplace(inplace, key)
+        if inplace is not False:
+            best_attempt = inplace is BEST_ATTEMPT_INPLACE
+            inplace = self._convert_inplace(inplace, key)
         if not validated:
             value = self._validate_value(value, check_shape=True)
         if not inplace:
-            if self.is_locked:
+            if self._is_locked:
                 raise RuntimeError(_LOCK_ERROR)
             self._tensordict[key] = value
         else:
@@ -1703,14 +1709,13 @@ def contiguous(self) -> T:
     def empty(self, recurse=False) -> T:
         if not recurse:
             return TensorDict(
-                device=self.device,
-                batch_size=self.batch_size,
+                device=self._device,
+                batch_size=self._batch_size,
                 source={},
-                # names=self.names if self._has_names() else None,
                 names=self._td_dim_names,
                 _run_checks=False,
-                _is_memmap=self._is_memmap,
-                _is_shared=self._is_shared,
+                _is_memmap=False,
+                _is_shared=False,
             )
         return super().empty(recurse=recurse)

diff --git a/tensordict/base.py b/tensordict/base.py
index 74e4980f3..e9f0b7b5a 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3125,8 +3125,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             return self.lock_()
         if last_op == self.__class__.to_module.__name__:
             if is_tensor_collection(out):
-                with out.unlock_():
-                    return self.to_module(*args, **kwargs, swap_dest=out)
+                return self.to_module(*args, **kwargs, swap_dest=out)
             else:
                 raise RuntimeError(
                     "to_module cannot be used as a decorator when return_swap=False."
                 )
@@ -3520,6 +3519,10 @@ def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
             result._set_str(
                 leaf_flat, self.get(leaf), validated=True, inplace=False
             )
+        shared = result._is_shared = self._is_shared
+        mmap = result._is_memmap = self._is_memmap
+        if shared or mmap:
+            result._is_locked = True
         return result

     @cache  # noqa: B019
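# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): a minimal, illustrative example of the
# LazyStackedTensorDict.update() path rewritten above, where a nested ("tuple")
# key is unravelled and dispatched to the existing sub-entry's own update()
# rather than going through set(). Key names and shapes are made up for the
# example.
# ---------------------------------------------------------------------------
import torch

from tensordict import LazyStackedTensorDict, TensorDict

# Two tensordicts of batch_size [3] stacked lazily along dim 0 give a lazy
# stack with batch_size [2, 3] and stack_dim=0.
td = LazyStackedTensorDict(
    TensorDict({"a": {"b": torch.zeros(3)}}, batch_size=[3]),
    TensorDict({"a": {"b": torch.zeros(3)}}, batch_size=[3]),
    stack_dim=0,
)

# ("a", "b") resolves against the existing "a" sub-tensordict, whose update()
# receives {("b",): value}; the leaf is then unbound along stack_dim and
# written into each stacked tensordict.
td.update({("a", "b"): torch.ones(2, 3)})
assert (td.get(("a", "b")) == 1).all()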