From 04f6375532056eea42825581a3081f7e68b5fe3d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 24 Nov 2023 17:40:29 +0000 Subject: [PATCH] [Feature, Test] Add tests for partial update (#578) --- tensordict/_lazy.py | 53 +++++++++++++++++++++++------ tensordict/_td.py | 45 +++++++++++++++---------- tensordict/base.py | 72 ++++++++++++++++++++++++++++------------ tensordict/nn/params.py | 10 +++++- tensordict/persistent.py | 3 +- tensordict/utils.py | 8 +++++ test/test_tensordict.py | 35 +++++++++++++++++++ 7 files changed, 176 insertions(+), 50 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index d060beda2..979646776 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -19,7 +19,7 @@ import torch from functorch import dim as ftdim from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict -from tensordict._tensordict import _unravel_key_to_tuple, unravel_key +from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list from tensordict.base import ( _ACCEPTED_CLASSES, _is_tensor_collection, @@ -36,6 +36,7 @@ _getitem_batch_size, _is_number, _parse_to, + _prune_selected_keys, _renamed_inplace_method, _shape, _td_fields, @@ -1576,10 +1577,21 @@ def expand(self, *args: int, inplace: bool = False) -> T: return self return torch.stack(tensordicts, stack_dim) - def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: + def update( + self, + input_dict_or_td: T, + clone: bool = False, + *, + keys_to_update: Sequence[NestedKey] | None = None, + **kwargs: Any, + ) -> T: if input_dict_or_td is self: # no op return self + if keys_to_update is not None: + keys_to_update = unravel_key_list(keys_to_update) + if len(keys_to_update) == 0: + return self if ( isinstance(input_dict_or_td, LazyStackedTensorDict) @@ -1592,7 +1604,9 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: for td_dest, td_source in zip( self.tensordicts, input_dict_or_td.tensordicts ): - td_dest.update(td_source, clone=clone, **kwargs) + td_dest.update( + td_source, clone=clone, keys_to_update=keys_to_update, **kwargs + ) return self inplace = kwargs.get("inplace", False) @@ -1601,12 +1615,25 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: value = value.clone() elif clone: value = tree_map(torch.clone, value) - key = unravel_key(key) - if isinstance(key, tuple): + key = _unravel_key_to_tuple(key) + firstkey, subkey = key[0], key[1:] + if keys_to_update and not any( + firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] + for ktu in keys_to_update + ): + continue + + if subkey: # we must check that the target is not a leaf - target = self._get_str(key[0], default=None) + target = self._get_str(firstkey, default=None) if is_tensor_collection(target): - target.update({key[1:]: value}, inplace=inplace, clone=clone) + sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey) + target.update( + {subkey: value}, + inplace=inplace, + clone=clone, + keys_to_update=sub_keys_to_update, + ) elif target is None: self._set_tuple(key, value, inplace=inplace, validated=False) else: @@ -1614,13 +1641,19 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection." ) else: - target = self._get_str(key, default=None) + target = self._get_str(firstkey, default=None) if is_tensor_collection(target) and ( is_tensor_collection(value) or isinstance(value, dict) ): - target.update(value, inplace=inplace, clone=clone) + sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey) + target.update( + value, + inplace=inplace, + clone=clone, + keys_to_update=sub_keys_to_update, + ) elif target is None or not is_tensor_collection(value): - self._set_str(key, value, inplace=inplace, validated=False) + self._set_str(firstkey, value, inplace=inplace, validated=False) else: raise TypeError( f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}." diff --git a/tensordict/_td.py b/tensordict/_td.py index 4b5ff73c5..37d66b72e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -55,6 +55,7 @@ _NON_STR_KEY_ERR, _NON_STR_KEY_TUPLE_ERR, _parse_to, + _prune_selected_keys, _set_item, _set_max_batch_size, _shape, @@ -2105,18 +2106,18 @@ def update( # no op return self if keys_to_update is not None: + if len(keys_to_update) == 0: + return self keys_to_update = unravel_key_list(keys_to_update) - else: - keys_to_update = () keys = set(self.keys(False)) for key, value in input_dict_or_td.items(): key = _unravel_key_to_tuple(key) firstkey, subkey = key[0], key[1:] - if keys_to_update: - if (subkey and key in keys_to_update) or ( - not subkey and firstkey in keys_to_update - ): - continue + if keys_to_update and not any( + firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] + for ktu in keys_to_update + ): + continue if clone and hasattr(value, "clone"): value = value.clone() elif clone: @@ -2127,12 +2128,22 @@ def update( if _is_tensor_collection(target_class): target = self._source.get(firstkey)._get_sub_tensordict(self.idx) if len(subkey): - target._set_tuple(subkey, value, inplace=False, validated=False) + sub_keys_to_update = _prune_selected_keys( + keys_to_update, firstkey + ) + target.update( + {subkey: value}, + inplace=False, + keys_to_update=sub_keys_to_update, + ) continue elif isinstance(value, dict) or _is_tensor_collection( value.__class__ ): - target.update(value) + sub_keys_to_update = _prune_selected_keys( + keys_to_update, firstkey + ) + target.update(value, keys_to_update=sub_keys_to_update) continue raise ValueError( f"Tried to replace a tensordict with an incompatible object of type {type(value)}" @@ -2173,17 +2184,17 @@ def update_at_( keys_to_update: Sequence[NestedKey] | None = None, ) -> _SubTensorDict: if keys_to_update is not None: + if len(keys_to_update) == 0: + return self keys_to_update = unravel_key_list(keys_to_update) - else: - keys_to_update = () for key, value in input_dict.items(): key = _unravel_key_to_tuple(key) - firstkey, *keys = key - if keys_to_update: - if (keys and key in keys_to_update) or ( - not keys and firstkey in keys_to_update - ): - continue + firstkey, _ = key[0], key[1:] + if keys_to_update and not any( + firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] + for ktu in keys_to_update + ): + continue if not isinstance(value, tuple(_ACCEPTED_CLASSES)): raise TypeError( f"Expected value to be one of types {_ACCEPTED_CLASSES} " diff --git a/tensordict/base.py b/tensordict/base.py index 00646e236..9c279b709 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -34,6 +34,7 @@ _is_tensorclass, _KEY_ERROR, _proc_init, + _prune_selected_keys, _shape, _split_tensordict, _td_fields, @@ -50,7 +51,6 @@ lock_blocked, NestedKey, prod, - unravel_key, unravel_key_list, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor @@ -1871,17 +1871,17 @@ def update( # no op return self if keys_to_update is not None: + if len(keys_to_update) == 0: + return self keys_to_update = unravel_key_list(keys_to_update) - else: - keys_to_update = () for key, value in input_dict_or_td.items(): key = _unravel_key_to_tuple(key) firstkey, subkey = key[0], key[1:] - if keys_to_update: - if (subkey and key in keys_to_update) or ( - not subkey and firstkey in keys_to_update - ): - continue + if keys_to_update and not any( + firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] + for ktu in keys_to_update + ): + continue target = self._get_str(firstkey, None) if clone and hasattr(value, "clone"): value = value.clone() @@ -1891,7 +1891,15 @@ def update( if target is not None: if _is_tensor_collection(type(target)): if subkey: - target.update({subkey: value}, inplace=inplace, clone=clone) + sub_keys_to_update = _prune_selected_keys( + keys_to_update, firstkey + ) + target.update( + {subkey: value}, + inplace=inplace, + clone=clone, + keys_to_update=sub_keys_to_update, + ) continue elif isinstance(value, (dict,)) or _is_tensor_collection( value.__class__ @@ -1899,17 +1907,33 @@ def update( if isinstance(value, LazyStackedTensorDict) and not isinstance( target, LazyStackedTensorDict ): + sub_keys_to_update = _prune_selected_keys( + keys_to_update, firstkey + ) self._set_tuple( key, LazyStackedTensorDict( *target.unbind(value.stack_dim), stack_dim=value.stack_dim, - ).update(value, inplace=inplace, clone=clone), + ).update( + value, + inplace=inplace, + clone=clone, + keys_to_update=sub_keys_to_update, + ), validated=True, inplace=False, ) else: - target.update(value, inplace=inplace, clone=clone) + sub_keys_to_update = _prune_selected_keys( + keys_to_update, firstkey + ) + target.update( + value, + inplace=inplace, + clone=clone, + keys_to_update=sub_keys_to_update, + ) continue self._set_tuple( key, @@ -1960,12 +1984,15 @@ def update_( # no op return self if keys_to_update is not None: + if len(keys_to_update) == 0: + return self keys_to_update = unravel_key_list(keys_to_update) - else: - keys_to_update = () for key, value in input_dict_or_td.items(): - key = unravel_key(key) - if key in keys_to_update: + firstkey, *nextkeys = _unravel_key_to_tuple(key) + if keys_to_update and not any( + firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] + for ktu in keys_to_update + ): continue # if not isinstance(value, _accepted_classes): # raise TypeError( @@ -1974,7 +2001,7 @@ def update_( # ) if clone: value = value.clone() - self.set_(key, value) + self.set_((firstkey, *nextkeys), value) return self def update_at_( @@ -2025,12 +2052,15 @@ def update_at_( """ if keys_to_update is not None: + if len(keys_to_update) == 0: + return self keys_to_update = unravel_key_list(keys_to_update) - else: - keys_to_update = () for key, value in input_dict_or_td.items(): - key = unravel_key(key) - if key in keys_to_update: + firstkey, *nextkeys = _unravel_key_to_tuple(key) + if keys_to_update and not any( + firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] + for ktu in keys_to_update + ): continue if not isinstance(value, tuple(_ACCEPTED_CLASSES)): raise TypeError( @@ -2039,7 +2069,7 @@ def update_at_( ) if clone: value = value.clone() - self.set_at_(key, value, idx) + self.set_at_((firstkey, *nextkeys), value, idx) return self @lock_blocked diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 781f6351b..788aceb5e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -387,6 +387,8 @@ def update( input_dict_or_td: dict[str, CompatibleType] | TensorDictBase, clone: bool = False, inplace: bool = False, + *, + keys_to_update: Sequence[NestedKey] | None = None, ) -> TensorDictBase: if not self.no_convert: func = _maybe_make_param @@ -397,7 +399,13 @@ def update( else: input_dict_or_td = tree_map(func, input_dict_or_td) with self._param_td.unlock_(): - TensorDictBase.update(self, input_dict_or_td, clone=clone, inplace=inplace) + TensorDictBase.update( + self, + input_dict_or_td, + clone=clone, + inplace=inplace, + keys_to_update=keys_to_update, + ) self._reset_params() return self diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 7f0ce86a6..e8aaf25fe 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -77,7 +77,8 @@ def __iter__(self): yield from self.tensordict._valid_keys() def __contains__(self, key): - if isinstance(key, tuple) and len(key) == 1: + key = _unravel_key_to_tuple(key) + if len(key) == 1: key = key[0] for a_key in self: if isinstance(a_key, tuple) and len(a_key) == 1: diff --git a/tensordict/utils.py b/tensordict/utils.py index b9e5b4b95..4c8308be8 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1740,3 +1740,11 @@ def _proc_init(base_seed, queue): torch.manual_seed(seed) np_seed = _generate_state(base_seed, worker_id) np.random.seed(np_seed) + + +def _prune_selected_keys(keys_to_update, prefix): + if keys_to_update is None: + return None + return tuple( + key[1:] for key in keys_to_update if isinstance(key, tuple) and key[0] == prefix + ) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 4ee3298f1..ab658c79c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3054,6 +3054,41 @@ def test_id(): assert lazy_legacy() test_id() + def test_update_select(self, td_name, device): + if td_name in ("memmap_td",): + pytest.skip(reason="update not possible with memory-mapped td") + td = getattr(self, td_name)(device) + t = lambda: torch.zeros(()).expand((4, 3, 2, 1)) + other_td = TensorDict( + { + "My": {"father": {"was": t(), "a": t()}, "relentlessly": t()}, + "self-improving": t(), + }, + batch_size=(4, 3, 2, 1), + ) + td.update( + other_td, + keys_to_update=(("My", ("father",), "was"), ("My", "relentlessly")), + ) + assert ("My", "father", "was") in td.keys(True) + assert ("My", ("father",), "was") in td.keys(True) + assert ("My", "relentlessly") in td.keys(True) + assert ("My", "father", "a") in td.keys(True) + assert ("self-improving",) not in td.keys(True) + t = lambda: torch.ones(()).expand((4, 3, 2, 1)) + other_td = TensorDict( + { + "My": {"father": {"was": t(), "a": t()}, "relentlessly": t()}, + "self-improving": t(), + }, + batch_size=(4, 3, 2, 1), + ) + td.update(other_td, keys_to_update=(("My", "relentlessly"),)) + assert (td["My", "relentlessly"] == 1).all() + assert (td["My", "father", "was"] == 0).all() + td.update(other_td, keys_to_update=(("My", ("father",), "was"),)) + assert (td["My", "father", "was"] == 1).all() + @pytest.mark.parametrize("device", [None, *get_available_devices()]) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])