Skip to content

Commit

Permalink
[Feature, Test] Add tests for partial update (#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 24, 2023
1 parent fb1b589 commit 04f6375
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 50 deletions.
53 changes: 43 additions & 10 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from functorch import dim as ftdim
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
from tensordict.base import (
_ACCEPTED_CLASSES,
_is_tensor_collection,
Expand All @@ -36,6 +36,7 @@
_getitem_batch_size,
_is_number,
_parse_to,
_prune_selected_keys,
_renamed_inplace_method,
_shape,
_td_fields,
Expand Down Expand Up @@ -1576,10 +1577,21 @@ def expand(self, *args: int, inplace: bool = False) -> T:
return self
return torch.stack(tensordicts, stack_dim)

def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
def update(
self,
input_dict_or_td: T,
clone: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
**kwargs: Any,
) -> T:
if input_dict_or_td is self:
# no op
return self
if keys_to_update is not None:
keys_to_update = unravel_key_list(keys_to_update)
if len(keys_to_update) == 0:
return self

if (
isinstance(input_dict_or_td, LazyStackedTensorDict)
Expand All @@ -1592,7 +1604,9 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
for td_dest, td_source in zip(
self.tensordicts, input_dict_or_td.tensordicts
):
td_dest.update(td_source, clone=clone, **kwargs)
td_dest.update(
td_source, clone=clone, keys_to_update=keys_to_update, **kwargs
)
return self

inplace = kwargs.get("inplace", False)
Expand All @@ -1601,26 +1615,45 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
value = value.clone()
elif clone:
value = tree_map(torch.clone, value)
key = unravel_key(key)
if isinstance(key, tuple):
key = _unravel_key_to_tuple(key)
firstkey, subkey = key[0], key[1:]
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue

if subkey:
# we must check that the target is not a leaf
target = self._get_str(key[0], default=None)
target = self._get_str(firstkey, default=None)
if is_tensor_collection(target):
target.update({key[1:]: value}, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey)
target.update(
{subkey: value},
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
elif target is None:
self._set_tuple(key, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection."
)
else:
target = self._get_str(key, default=None)
target = self._get_str(firstkey, default=None)
if is_tensor_collection(target) and (
is_tensor_collection(value) or isinstance(value, dict)
):
target.update(value, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey)
target.update(
value,
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
elif target is None or not is_tensor_collection(value):
self._set_str(key, value, inplace=inplace, validated=False)
self._set_str(firstkey, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}."
Expand Down
45 changes: 28 additions & 17 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
_prune_selected_keys,
_set_item,
_set_max_batch_size,
_shape,
Expand Down Expand Up @@ -2105,18 +2106,18 @@ def update(
# no op
return self
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
keys = set(self.keys(False))
for key, value in input_dict_or_td.items():
key = _unravel_key_to_tuple(key)
firstkey, subkey = key[0], key[1:]
if keys_to_update:
if (subkey and key in keys_to_update) or (
not subkey and firstkey in keys_to_update
):
continue
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if clone and hasattr(value, "clone"):
value = value.clone()
elif clone:
Expand All @@ -2127,12 +2128,22 @@ def update(
if _is_tensor_collection(target_class):
target = self._source.get(firstkey)._get_sub_tensordict(self.idx)
if len(subkey):
target._set_tuple(subkey, value, inplace=False, validated=False)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(
{subkey: value},
inplace=False,
keys_to_update=sub_keys_to_update,
)
continue
elif isinstance(value, dict) or _is_tensor_collection(
value.__class__
):
target.update(value)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(value, keys_to_update=sub_keys_to_update)
continue
raise ValueError(
f"Tried to replace a tensordict with an incompatible object of type {type(value)}"
Expand Down Expand Up @@ -2173,17 +2184,17 @@ def update_at_(
keys_to_update: Sequence[NestedKey] | None = None,
) -> _SubTensorDict:
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict.items():
key = _unravel_key_to_tuple(key)
firstkey, *keys = key
if keys_to_update:
if (keys and key in keys_to_update) or (
not keys and firstkey in keys_to_update
):
continue
firstkey, _ = key[0], key[1:]
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if not isinstance(value, tuple(_ACCEPTED_CLASSES)):
raise TypeError(
f"Expected value to be one of types {_ACCEPTED_CLASSES} "
Expand Down
72 changes: 51 additions & 21 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_is_tensorclass,
_KEY_ERROR,
_proc_init,
_prune_selected_keys,
_shape,
_split_tensordict,
_td_fields,
Expand All @@ -50,7 +51,6 @@
lock_blocked,
NestedKey,
prod,
unravel_key,
unravel_key_list,
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
Expand Down Expand Up @@ -1871,17 +1871,17 @@ def update(
# no op
return self
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict_or_td.items():
key = _unravel_key_to_tuple(key)
firstkey, subkey = key[0], key[1:]
if keys_to_update:
if (subkey and key in keys_to_update) or (
not subkey and firstkey in keys_to_update
):
continue
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
target = self._get_str(firstkey, None)
if clone and hasattr(value, "clone"):
value = value.clone()
Expand All @@ -1891,25 +1891,49 @@ def update(
if target is not None:
if _is_tensor_collection(type(target)):
if subkey:
target.update({subkey: value}, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(
{subkey: value},
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
continue
elif isinstance(value, (dict,)) or _is_tensor_collection(
value.__class__
):
if isinstance(value, LazyStackedTensorDict) and not isinstance(
target, LazyStackedTensorDict
):
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
self._set_tuple(
key,
LazyStackedTensorDict(
*target.unbind(value.stack_dim),
stack_dim=value.stack_dim,
).update(value, inplace=inplace, clone=clone),
).update(
value,
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
),
validated=True,
inplace=False,
)
else:
target.update(value, inplace=inplace, clone=clone)
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
)
target.update(
value,
inplace=inplace,
clone=clone,
keys_to_update=sub_keys_to_update,
)
continue
self._set_tuple(
key,
Expand Down Expand Up @@ -1960,12 +1984,15 @@ def update_(
# no op
return self
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict_or_td.items():
key = unravel_key(key)
if key in keys_to_update:
firstkey, *nextkeys = _unravel_key_to_tuple(key)
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
# if not isinstance(value, _accepted_classes):
# raise TypeError(
Expand All @@ -1974,7 +2001,7 @@ def update_(
# )
if clone:
value = value.clone()
self.set_(key, value)
self.set_((firstkey, *nextkeys), value)
return self

def update_at_(
Expand Down Expand Up @@ -2025,12 +2052,15 @@ def update_at_(
"""
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
else:
keys_to_update = ()
for key, value in input_dict_or_td.items():
key = unravel_key(key)
if key in keys_to_update:
firstkey, *nextkeys = _unravel_key_to_tuple(key)
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if not isinstance(value, tuple(_ACCEPTED_CLASSES)):
raise TypeError(
Expand All @@ -2039,7 +2069,7 @@ def update_at_(
)
if clone:
value = value.clone()
self.set_at_(key, value, idx)
self.set_at_((firstkey, *nextkeys), value, idx)
return self

@lock_blocked
Expand Down
10 changes: 9 additions & 1 deletion tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ def update(
input_dict_or_td: dict[str, CompatibleType] | TensorDictBase,
clone: bool = False,
inplace: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
) -> TensorDictBase:
if not self.no_convert:
func = _maybe_make_param
Expand All @@ -397,7 +399,13 @@ def update(
else:
input_dict_or_td = tree_map(func, input_dict_or_td)
with self._param_td.unlock_():
TensorDictBase.update(self, input_dict_or_td, clone=clone, inplace=inplace)
TensorDictBase.update(
self,
input_dict_or_td,
clone=clone,
inplace=inplace,
keys_to_update=keys_to_update,
)
self._reset_params()
return self

Expand Down
3 changes: 2 additions & 1 deletion tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __iter__(self):
yield from self.tensordict._valid_keys()

def __contains__(self, key):
if isinstance(key, tuple) and len(key) == 1:
key = _unravel_key_to_tuple(key)
if len(key) == 1:
key = key[0]
for a_key in self:
if isinstance(a_key, tuple) and len(a_key) == 1:
Expand Down
8 changes: 8 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,3 +1740,11 @@ def _proc_init(base_seed, queue):
torch.manual_seed(seed)
np_seed = _generate_state(base_seed, worker_id)
np.random.seed(np_seed)


def _prune_selected_keys(keys_to_update, prefix):
if keys_to_update is None:
return None
return tuple(
key[1:] for key in keys_to_update if isinstance(key, tuple) and key[0] == prefix
)
Loading

0 comments on commit 04f6375

Please sign in to comment.