From 4939c7409f3bc09fa70e9b2bd9fb8a8c788068ae Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 09:04:33 -0400 Subject: [PATCH 1/7] init --- tensordict/nn/params.py | 45 +++++++++++++++++++++++++++++++++++++++- tensordict/tensordict.py | 21 +++++++++++++------ test/test_nn.py | 41 ++++++++++++++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 7 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index a6a602bb0..6bf3b18a0 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -10,7 +10,7 @@ import re from copy import copy from functools import wraps -from typing import Any, Callable, Iterator, Sequence +from typing import Any, Callable, Iterator, OrderedDict, Sequence import torch @@ -744,6 +744,49 @@ def values( continue yield self._apply_get_post_hook(v) + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + sd = self._param_td.flatten_keys(".").state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + return sd + + def load_state_dict(self, state_dict: OrderedDict[str, Any], *Args, **kwargs): + state_dict_tensors = {} + state_dict = dict(state_dict) + for k, v in list(state_dict.items()): + if isinstance(v, torch.Tensor): + del state_dict[k] + state_dict_tensors[k] = v + state_dict_tensors = dict( + TensorDict(state_dict_tensors, []).unflatten_keys(".") + ) + self.data.load_state_dict({**state_dict_tensors, **state_dict}, *Args, **kwargs) + return self + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + data = ( + TensorDict( + { + key: val + for key, val in state_dict.items() + if key.startswith(prefix) and val is not None + }, + [], + ) + .unflatten_keys(".") + .get(prefix[:-1]) + ) + self.data.load_state_dict(data) + def items( self, include_nested: bool = False, leaves_only: bool = False ) -> Iterator[CompatibleType]: diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 9ecf1c0cb..1839a88c7 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -691,11 +691,17 @@ def is_shared(self) -> bool: return self.device.type == "cuda" or self._is_shared return self._is_shared - def state_dict(self) -> OrderedDict[str, Any]: + def state_dict( + self, destination=None, prefix="", keep_vars=False + ) -> OrderedDict[str, Any]: out = collections.OrderedDict() for key, item in self.apply(memmap_tensor_as_tensor).items(): - out[key] = ( - item if not _is_tensor_collection(item.__class__) else item.state_dict() + out[prefix + key] = ( + item + if keep_vars + else item.detach().clone() + if not _is_tensor_collection(item.__class__) + else item.state_dict(keep_vars=keep_vars) ) if "__batch_size" in out: raise KeyError( @@ -705,15 +711,18 @@ def state_dict(self) -> OrderedDict[str, Any]: raise KeyError( "Cannot retrieve the state_dict of a TensorDict with `'__batch_size'` key" ) - out["__batch_size"] = self.batch_size - out["__device"] = self.device + out[prefix + "__batch_size"] = self.batch_size + out[prefix + "__device"] = self.device + if destination is not None: + destination.update(out) + return destination return out def load_state_dict(self, state_dict: OrderedDict[str, Any]) -> T: # copy since we'll be using pop state_dict = copy(state_dict) self.batch_size = state_dict.pop("__batch_size") - device = state_dict.pop("__device") + device = state_dict.pop("__device", None) if device is not None: self.to(device) for key, item in state_dict.items(): diff --git a/test/test_nn.py b/test/test_nn.py index e9f76114e..e4771872e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3159,6 +3159,47 @@ def test_add_scale_sequence(self, num_outputs=4): assert (scale > 0).all() +class TestStateDict: + @pytest.mark.parametrize("detach", [True, False]) + def test_sd_params(self, detach): + td = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, []) + td = TensorDictParams(td) + if detach: + sd = td.detach().clone().zero_().state_dict() + else: + sd = td.state_dict() + sd = { + k: v if not isinstance(v, torch.Tensor) else v * 0 + for k, v in sd.items() + } + print(sd) + # do some op to create a graph + td.apply(lambda x: x + 1) + # load the data + td.load_state_dict(sd) + # check that data has been loaded + assert (td == 0).all() + + def test_sd_module(self): + td = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, []) + td = TensorDictParams(td) + module = nn.Linear(3, 4) + module.td = td + + sd = module.state_dict() + assert "td.1" in sd + assert "td.3.3" in sd + sd = {k: v * 0 if isinstance(v, torch.Tensor) else v for k, v in sd.items()} + + # load the data + module.load_state_dict(sd) + + # check that data has been loaded + assert (module.td == 0).all() + for val in td.values(True, True): + assert isinstance(val, nn.Parameter) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 196d01bd47228b2c99626ee72bc71c79a9bb84dd Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 10:01:03 -0400 Subject: [PATCH 2/7] amend --- tensordict/nn/params.py | 8 ++- tensordict/tensorclass.py | 18 ++++-- tensordict/tensordict.py | 113 ++++++++++++++++++++++++++++++++++++-- test/test_tensordict.py | 28 ++++++++++ 4 files changed, 155 insertions(+), 12 deletions(-) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 6bf3b18a0..5bc4fb723 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -750,7 +750,9 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): ) return sd - def load_state_dict(self, state_dict: OrderedDict[str, Any], *Args, **kwargs): + def load_state_dict( + self, state_dict: OrderedDict[str, Any], strict=True, assign=False + ): state_dict_tensors = {} state_dict = dict(state_dict) for k, v in list(state_dict.items()): @@ -760,7 +762,9 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], *Args, **kwargs): state_dict_tensors = dict( TensorDict(state_dict_tensors, []).unflatten_keys(".") ) - self.data.load_state_dict({**state_dict_tensors, **state_dict}, *Args, **kwargs) + self.data.load_state_dict( + {**state_dict_tensors, **state_dict}, strict=True, assign=False + ) return self def _load_from_state_dict( diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2113a81e2..657cb814e 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -740,14 +740,22 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 self._tensordict._batch_size_setter(new_size) -def _state_dict(self) -> dict[str, Any]: +def _state_dict( + self, destination=None, prefix="", keep_vars=False, flatten=False +) -> dict[str, Any]: """Returns a state_dict dictionary that can be used to save and load data from a tensorclass.""" - state_dict = {"_tensordict": self._tensordict.state_dict()} + state_dict = { + "_tensordict": self._tensordict.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars, flatten=flatten + ) + } state_dict["_non_tensordict"] = copy(self._non_tensordict) return state_dict -def _load_state_dict(self, state_dict: dict[str, Any]): +def _load_state_dict( + self, state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False +): """Loads a state_dict attemptedly in-place on the destination tensorclass.""" for key, item in state_dict.items(): # keys will never be nested which facilitates everything, but let's @@ -778,7 +786,9 @@ def _load_state_dict(self, state_dict: dict[str, Any]): f"Key '{sub_key}' wasn't expected in the state-dict." ) - self._tensordict.load_state_dict(item) + self._tensordict.load_state_dict( + item, strict=strict, assign=assign, from_flatten=from_flatten + ) else: raise KeyError(f"Key '{key}' wasn't expected in the state-dict.") diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 1839a88c7..d9a3ff5ea 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -692,10 +692,46 @@ def is_shared(self) -> bool: return self._is_shared def state_dict( - self, destination=None, prefix="", keep_vars=False + self, + destination=None, + prefix="", + keep_vars=False, + flatten=False, ) -> OrderedDict[str, Any]: + """Produces a state_dict from the tensordict. The structure of the state-dict will still be nested, unless ``flatten`` is set to ``True``. + + A tensordict state-dict contains all the tensors and meta-data needed + to rebuild the tensordict (names are currently not supported). + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + flatten (bool, optional): whether the structure should be flattened + with the ``"."`` character or not. + Defaults to ``False``. + + Examples: + >>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, []) + >>> sd = data.state_dict() + >>> print(sd) + OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3', OrderedDict([('3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)])), ('__batch_size', torch.Size([])), ('__device', None)]) + >>> sd = data.state_dict(flatten=True) + OrderedDict([('1', tensor(1)), ('2', tensor(2)), ('3.3', tensor(3)), ('__batch_size', torch.Size([])), ('__device', None)]) + + """ out = collections.OrderedDict() - for key, item in self.apply(memmap_tensor_as_tensor).items(): + source = self.apply(memmap_tensor_as_tensor) + if flatten: + source = source.flatten_keys(".") + for key, item in source.items(): out[prefix + key] = ( item if keep_vars @@ -718,7 +754,63 @@ def state_dict( return destination return out - def load_state_dict(self, state_dict: OrderedDict[str, Any]) -> T: + def load_state_dict( + self, + state_dict: OrderedDict[str, Any], + strict=True, + assign=False, + from_flatten=False, + ) -> T: + """Loads a state-dict, formatted as in :meth:`~.state_dict`, into the tensordict. + + Args: + state_dict (OrderedDict): the state_dict of to be copied. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + assign (bool, optional): whether to assign items in the state + dictionary to their corresponding keys in the tensordict instead + of copying them inplace into the module's current parameters and buffers. + When ``False``, the properties of the tensors in the current + module are preserved while when ``True``, the properties of the + Tensors in the state dict are preserved. + Default: ``False`` + from_flatten (bool, optional): if ``True``, the input state_dict is + assumed to be flattened. + Defaults to ``False``. + + Examples: + >>> data = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, []) + >>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, []) + >>> sd = data.state_dict() + >>> data_zeroed.load_state_dict(sd) + >>> print(data_zeroed["3", "3"]) + tensor(3) + >>> # with flattening + >>> data_zeroed = TensorDict({"1": 0, "2": 0, "3": {"3": 0}}, []) + >>> data_zeroed.load_state_dict(data.state_dict(flatten=True), from_flatten=True) + >>> print(data_zeroed["3", "3"]) + tensor(3) + + + """ + if from_flatten: + self_flatten = self.flatten_keys(".") + self_flatten.load_state_dict(state_dict, strict=strict, assign=assign) + if not assign: + # modifications are done in-place so we should be fine returning self + return self + else: + # run a check over keys, if we any key with a '.' in name we're doomed + DOT_ERROR = "Cannot use load_state_dict(..., from_flatten=True, assign=True) when some keys contain a dot character." + for key in self.keys(True, True): + if isinstance(key, tuple): + for subkey in key: + if "." in subkey: + raise RuntimeError(DOT_ERROR) + elif "." in key: + raise RuntimeError(DOT_ERROR) + return self.update(self_flatten.unflatten_keys(".")) # copy since we'll be using pop state_dict = copy(state_dict) self.batch_size = state_dict.pop("__batch_size") @@ -729,11 +821,20 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any]) -> T: if isinstance(item, dict): self.set( key, - self.get(key, default=TensorDict({}, [])).load_state_dict(item), - inplace=True, + self.get(key, default=TensorDict({}, [])).load_state_dict( + item, assign=assign, strict=strict + ), + inplace=not assign, ) else: - self.set(key, item, inplace=True) + self.set(key, item, inplace=not assign) + if strict and set(state_dict.keys()) != set(self.keys()): + set_sd = set(state_dict.keys()) + set_td = set(self.keys()) + raise RuntimeError( + "Cannot load state-dict because the key sets don't match: got " + f"state_dict extra keys \n{set_sd-set_td}\n and tensordict extra keys\n{set_td-set_sd}\n" + ) return self def is_memmap(self) -> bool: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index fa7303738..c473421f3 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1380,6 +1380,34 @@ def test_cpu_cuda(self, td_name, device): assert td_device.device == torch.device("cuda") assert td_back.device == torch.device("cpu") + def test_state_dict(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sd = td.state_dict() + td_zero = td.clone().detach().zero_() + td_zero.load_state_dict(sd) + assert_allclose_td(td, td_zero) + + def test_state_dict_strict(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sd = td.state_dict() + td_zero = td.clone().detach().zero_() + del sd["a"] + td_zero.load_state_dict(sd, strict=False) + with pytest.raises(RuntimeError): + td_zero.load_state_dict(sd, strict=True) + + def test_state_dict_assign(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sd = td.state_dict() + td_zero = td.clone().detach().zero_() + shallow_copy = td_zero.clone(False) + td_zero.load_state_dict(sd, assign=True) + assert (shallow_copy == 0).all() + assert_allclose_td(td, td_zero) + @pytest.mark.parametrize("dim", range(4)) def test_unbind(self, td_name, device, dim): if td_name not in ["sub_td", "idx_td", "td_reset_bs"]: From 30e1030ae944ec837026f35c6357deb92a06ccfc Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 10:07:56 -0400 Subject: [PATCH 3/7] amend --- tensordict/tensordict.py | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index d9a3ff5ea..1c27542c8 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -376,7 +376,7 @@ def from_module(module, as_module: bool = False): """Copies the params and buffers of a module in a tensordict. Args: - as_module (bool, optional): if ``True``, a :class:`tensordict.nn.TensorDictParams` + as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams` instance will be returned which can be used to store parameters within a :class:`torch.nn.Module`. Defaults to ``False``. @@ -704,13 +704,13 @@ def state_dict( to rebuild the tensordict (names are currently not supported). Args: - destination (dict, optional): If provided, the state of module will + destination (dict, optional): If provided, the state of tensordict will be updated into the dict and the same object is returned. Otherwise, an ``OrderedDict`` will be created and returned. Default: ``None``. - prefix (str, optional): a prefix added to parameter and buffer + prefix (str, optional): a prefix added to tensor names to compose the keys in state_dict. Default: ``''``. - keep_vars (bool, optional): by default the :class:`~torch.Tensor` s + keep_vars (bool, optional): by default the :class:`torch.Tensor` s returned in the state dict are detached from autograd. If it's set to ``True``, detaching will not be performed. Default: ``False``. @@ -766,11 +766,11 @@ def load_state_dict( Args: state_dict (OrderedDict): the state_dict of to be copied. strict (bool, optional): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + in :attr:`state_dict` match the keys returned by this tensordict's + :meth:`torch.nn.Module.state_dict` function. Default: ``True`` assign (bool, optional): whether to assign items in the state dictionary to their corresponding keys in the tensordict instead - of copying them inplace into the module's current parameters and buffers. + of copying them inplace into the tensordict's current tensors. When ``False``, the properties of the tensors in the current module are preserved while when ``True``, the properties of the Tensors in the state dict are preserved. @@ -2522,7 +2522,7 @@ def to_h5( **kwargs: kwargs to be passed to :meth:`h5py.File.create_dataset`. Returns: - A :class:`PersitentTensorDict` instance linked to the newly created file. + A :class:`~.tensordict.PersitentTensorDict` instance linked to the newly created file. Examples: >>> import tempfile @@ -2660,7 +2660,7 @@ def clone(self, recurse: bool = True) -> T: TensorDict will be copied too. Default is `True`. .. note:: - For some TensorDictBase subtypes, such as :class:`SubTensorDict`, cloning + For some TensorDictBase subtypes, such as :class:`~.tensordict.SubTensorDict`, cloning recursively makes little sense (in this specific case it would involve copying the parent tensordict too). In those cases, :meth:`~.clone` will fall back onto :meth:`~.to_tensordict`. @@ -2732,7 +2732,7 @@ def to(self, *args, **kwargs) -> T: other (TensorDictBase, optional): TensorDict instance whose dtype and device are the desired dtype and device for all tensors in this TensorDict. - .. note:: Since :class:`TensorDictBase` instances do not have + .. note:: Since :class:`~tensordict.TensorDictBase` instances do not have a dtype, the dtype is gathered from the example leaves. If there are more than one dtype, then no dtype casting is undertook. @@ -4040,7 +4040,7 @@ def __init__( @classmethod def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): - """Returns a TensorDict created from a dictionary or another :class:`TensorDict`. + """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`. If ``batch_size`` is not specified, returns the maximum batch size possible. @@ -5974,8 +5974,8 @@ def clone(self, recurse: bool = True) -> SubTensorDict: Args: recurse (bool, optional): if ``True`` (default), a regular - :class:`TensorDict` instance will be created from the :class:`SubTensorDict`. - Otherwise, another :class:`SubTensorDict` with identical content + :class:`~.tensordict.TensorDict` instance will be created from the :class:`~.tensordict.SubTensorDict`. + Otherwise, another :class:`~.tensordict.SubTensorDict` with identical content will be returned. Examples: @@ -8233,8 +8233,8 @@ def clone(self, recurse: bool = True) -> T: Args: recurse (bool, optional): if ``True`` (default), a regular - :class:`TensorDict` instance will be returned. - Otherwise, another :class:`SubTensorDict` with identical content + :class:`~.tensordict.TensorDict` instance will be returned. + Otherwise, another :class:`~.tensordict.SubTensorDict` with identical content will be returned. """ if not recurse: @@ -8917,17 +8917,17 @@ def dense_stack_tds( td_list: Sequence[TensorDictBase] | LazyStackedTensorDict, dim: int = None, ) -> T: - """Densely stack a list of :class:`tensordict.TensorDictBase` objects (or a :class:`tensordict.LazyStackedTensorDict`) given that they have the same structure. + """Densely stack a list of :class:`~tensordict.TensorDictBase` objects (or a :class:`~tensordict.LazyStackedTensorDict`) given that they have the same structure. - This function is called with a list of :class:`tensordict.TensorDictBase` (either passed directly or obtrained from - a :class:`tensordict.LazyStackedTensorDict`). - Instead of calling ``torch.stack(td_list)``, which would return a :class:`tensordict.LazyStackedTensorDict`, + This function is called with a list of :class:`~tensordict.TensorDictBase` (either passed directly or obtrained from + a :class:`~tensordict.LazyStackedTensorDict`). + Instead of calling ``torch.stack(td_list)``, which would return a :class:`~tensordict.LazyStackedTensorDict`, this function expands the first element of the input list and stacks the input list onto that element. This works only when all the elements of the input list have the same structure. - The :class:`tensordict.TensorDictBase` returned will have the same type of the elements of the input list. + The :class:`~tensordict.TensorDictBase` returned will have the same type of the elements of the input list. - This function is useful when some of the :class:`tensordict.TensorDictBase` objects that need to be stacked - are :class:`tensordict.LazyStackedTensorDict` or have :class:`tensordict.LazyStackedTensorDict` + This function is useful when some of the :class:`~tensordict.TensorDictBase` objects that need to be stacked + are :class:`~tensordict.LazyStackedTensorDict` or have :class:`~tensordict.LazyStackedTensorDict` among entries (or nested entries). In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible. Thus, this function provides an alternative for densely stacking the list provided. From 9a517dd5653c20d03d009cc6a3356b55ea541145 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 11:09:04 -0400 Subject: [PATCH 4/7] fixes --- tensordict/tensordict.py | 22 ++++++++++++---------- test/test_tensordict.py | 4 ++-- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 1c27542c8..a3c89f4bf 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -732,13 +732,13 @@ def state_dict( if flatten: source = source.flatten_keys(".") for key, item in source.items(): - out[prefix + key] = ( - item - if keep_vars - else item.detach().clone() - if not _is_tensor_collection(item.__class__) - else item.state_dict(keep_vars=keep_vars) - ) + if not _is_tensor_collection(item.__class__): + if not keep_vars: + out[prefix + key] = item.detach().clone() + else: + out[prefix + key] = item + else: + out[prefix + key] = item.state_dict(keep_vars=keep_vars) if "__batch_size" in out: raise KeyError( "Cannot retrieve the state_dict of a TensorDict with `'__batch_size'` key" @@ -747,8 +747,8 @@ def state_dict( raise KeyError( "Cannot retrieve the state_dict of a TensorDict with `'__batch_size'` key" ) - out[prefix + "__batch_size"] = self.batch_size - out[prefix + "__device"] = self.device + out[prefix + "__batch_size"] = source.batch_size + out[prefix + "__device"] = source.device if destination is not None: destination.update(out) return destination @@ -816,7 +816,9 @@ def load_state_dict( self.batch_size = state_dict.pop("__batch_size") device = state_dict.pop("__device", None) if device is not None: - self.to(device) + if device != self.device: + raise RuntimeError("Loading data from another device is not yet supproted.") + for key, item in state_dict.items(): if isinstance(item, dict): self.set( diff --git a/test/test_tensordict.py b/test/test_tensordict.py index c473421f3..8a6612118 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -5143,7 +5143,7 @@ def test_inplace(self, save_name): td.memmap_() assert isinstance(td["b", "c"], MemmapTensor) - app_state = {"state": torchsnapshot.StateDict(**{save_name: td.state_dict()})} + app_state = {"state": torchsnapshot.StateDict(**{save_name: td.state_dict(keep_vars=True)})} path = f"/tmp/{uuid.uuid4()}" snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=path) @@ -5160,7 +5160,7 @@ def test_inplace(self, save_name): td_dest.memmap_() assert isinstance(td_dest["b", "c"], MemmapTensor) app_state = { - "state": torchsnapshot.StateDict(**{save_name: td_dest.state_dict()}) + "state": torchsnapshot.StateDict(**{save_name: td_dest.state_dict(keep_vars=True)}) } snapshot.restore(app_state=app_state) From 84ac82652dd64e3a6e716a1a9285294f0f264479 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 11:11:53 -0400 Subject: [PATCH 5/7] update doc --- docs/source/saving.rst | 14 +++++++------- tensordict/tensordict.py | 4 +++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/source/saving.rst b/docs/source/saving.rst index 06a26abc6..f03d7b62b 100644 --- a/docs/source/saving.rst +++ b/docs/source/saving.rst @@ -79,17 +79,17 @@ that we will re-populate with the saved data. Again, two lines of code are sufficient to save the data: >>> app_state = { - ... "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict()) + ... "state": torchsnapshot.StateDict(tensordict=tensordict_source.state_dict(keep_vars=True)) ... } >>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path="/path/to/my/snapshot") We have been using :obj:`torchsnapshot.StateDict` and we explicitly called -:obj:`my_tensordict_source.state_dict()`, unlike the previous example. +:obj:`my_tensordict_source.state_dict(keep_vars=True)`, unlike the previous example. Now, to load this onto a destination tensordict: >>> snapshot = Snapshot(path="/path/to/my/snapshot") >>> app_state = { - ... "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict()) + ... "state": torchsnapshot.StateDict(tensordict=tensordict_target.state_dict(keep_vars=True)) ... } >>> snapshot.restore(app_state=app_state) @@ -117,7 +117,7 @@ Here is a full example: >>> assert isinstance(td["b", "c"], MemmapTensor) >>> >>> app_state = { - ... "state": torchsnapshot.StateDict(tensordict=td.state_dict()) + ... "state": torchsnapshot.StateDict(tensordict=td.state_dict(keep_vars=True)) ... } >>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}") >>> @@ -126,7 +126,7 @@ Here is a full example: >>> td_dest.memmap_() >>> assert isinstance(td_dest["b", "c"], MemmapTensor) >>> app_state = { - ... "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict()) + ... "state": torchsnapshot.StateDict(tensordict=td_dest.state_dict(keep_vars=True)) ... } >>> snapshot.restore(app_state=app_state) >>> # sanity check @@ -157,7 +157,7 @@ Finally, tensorclass also supports this feature. The code is fairly similar to t >>> assert isinstance(tc.y.x, MemmapTensor) >>> >>> app_state = { - ... "state": torchsnapshot.StateDict(tensordict=tc.state_dict()) + ... "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True)) ... } >>> snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=f"/tmp/{uuid.uuid4()}") >>> @@ -165,7 +165,7 @@ Finally, tensorclass also supports this feature. The code is fairly similar to t >>> tc_dest.memmap_() >>> assert isinstance(tc_dest.y.x, MemmapTensor) >>> app_state = { - ... "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict()) + ... "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True)) ... } >>> snapshot.restore(app_state=app_state) >>> diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index a3c89f4bf..b64bd7712 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -817,7 +817,9 @@ def load_state_dict( device = state_dict.pop("__device", None) if device is not None: if device != self.device: - raise RuntimeError("Loading data from another device is not yet supproted.") + raise RuntimeError( + "Loading data from another device is not yet supproted." + ) for key, item in state_dict.items(): if isinstance(item, dict): From 070822972a1de067c6c9503fd5e451faad1c27b0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 11:12:52 -0400 Subject: [PATCH 6/7] lint --- test/test_tensordict.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 8a6612118..abad0dd81 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -5143,7 +5143,11 @@ def test_inplace(self, save_name): td.memmap_() assert isinstance(td["b", "c"], MemmapTensor) - app_state = {"state": torchsnapshot.StateDict(**{save_name: td.state_dict(keep_vars=True)})} + app_state = { + "state": torchsnapshot.StateDict( + **{save_name: td.state_dict(keep_vars=True)} + ) + } path = f"/tmp/{uuid.uuid4()}" snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=path) @@ -5160,7 +5164,9 @@ def test_inplace(self, save_name): td_dest.memmap_() assert isinstance(td_dest["b", "c"], MemmapTensor) app_state = { - "state": torchsnapshot.StateDict(**{save_name: td_dest.state_dict(keep_vars=True)}) + "state": torchsnapshot.StateDict( + **{save_name: td_dest.state_dict(keep_vars=True)} + ) } snapshot.restore(app_state=app_state) From 05d010542e2665c0b1ba5409c11c3b0ad69eccba Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 12:22:25 -0400 Subject: [PATCH 7/7] fix tensorclass tests --- test/test_tensorclass.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index e43686fd1..fb44fa949 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -1352,7 +1352,9 @@ class MyClass: assert isinstance(tc.y.x, MemmapTensor) assert tc.z == z - app_state = {"state": torchsnapshot.StateDict(tensordict=tc.state_dict())} + app_state = { + "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True)) + } snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=str(tmp_path)) tc_dest = MyClass( @@ -1363,7 +1365,9 @@ class MyClass: ) tc_dest.memmap_() assert isinstance(tc_dest.y.x, MemmapTensor) - app_state = {"state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict())} + app_state = { + "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True)) + } snapshot.restore(app_state=app_state) assert (tc_dest == tc).all()