From a0a431583b9c2f7d7cd767d64e363ee95d8faf7a Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 11 Dec 2023 17:29:26 +0000
Subject: [PATCH 1/3] init

---
 tensordict/_td.py       | 48 ++++++++++++++++++++++++--
 tensordict/base.py      | 10 +++++-
 test/test_tensordict.py | 74 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 128 insertions(+), 4 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index c2f493274..5279dd878 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -234,8 +234,14 @@ def __init__(
 
     @staticmethod
     def from_module(
-        module: torch.nn.Module, as_module: bool = False, lock: bool = False
+        module: torch.nn.Module,
+        as_module: bool = False,
+        lock: bool = False,
+        use_state_dict: bool = False,
     ):
+        if use_state_dict:
+            return TensorDict(module.state_dict(), batch_size=[]).unflatten_keys(".")
+
         td_struct = TensorDict({}, [])
         for key, param in module.named_parameters(recurse=False):
             td_struct._set_str(key, param, validated=True, inplace=False)
@@ -262,7 +268,14 @@ def is_empty(self):
         return True
 
     @as_decorator()
-    def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None):
+    def to_module(
+        self,
+        module,
+        return_swap: bool = True,
+        swap_dest=None,
+        memo=None,
+        use_state_dict: bool = False,
+    ):
         # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
         __dict__ = module.__dict__
 
@@ -285,8 +298,36 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
             swap = swap_dest
         memo[id(module)] = swap
         _swap = {}
+        if use_state_dict:
+            # execute module's pre-hooks
+            state_dict = self.flatten_keys(".")
+            prefix = ""
+            strict = True
+            local_metadata = {}
+            missing_keys = []
+            unexpected_keys = []
+            error_msgs = []
+            for hook in module._load_state_dict_pre_hooks.values():
+                hook(
+                    state_dict,
+                    prefix,
+                    local_metadata,
+                    strict,
+                    missing_keys,
+                    unexpected_keys,
+                    error_msgs,
+                )
 
-        for key, value in self.items():
+            def convert_type(x, y):
+                if isinstance(y, torch.nn.Parameter):
+                    return torch.nn.Parameter(x)
+                return x
+
+            input = state_dict.unflatten_keys(".").apply(convert_type, self)
+        else:
+            input = self
+
+        for key, value in input.items():
             if isinstance(value, (Tensor, ftdim.Tensor)):
                 if module.__class__.__setattr__ is __base__setattr__:
                     # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict
@@ -315,6 +356,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
                     return_swap=return_swap,
                     swap_dest=local_dest,
                     memo=memo,
+                    use_state_dict=use_state_dict,
                 )
                 # we don't want to do this op more than once
                 if return_swap and (
diff --git a/tensordict/base.py b/tensordict/base.py
index 99406c957..c1ffe471c 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -303,7 +303,9 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
 
     # Module interaction
     @staticmethod
-    def from_module(module, as_module: bool = False, lock: bool = True):
+    def from_module(
+        module, as_module: bool = False, lock: bool = True, use_state_dict: bool = False
+    ):
         """Copies the params and buffers of a module in a tensordict.
 
         Args:
@@ -312,6 +314,12 @@ def from_module(module, as_module: bool = False, lock: bool = True):
                 within a :class:`torch.nn.Module`. Defaults to ``False``.
             lock (bool, optional): if ``True``, the resulting tensordict will be locked.
                 Defaults to ``True``.
+            use_state_dict (bool, optional): if ``True``, the state-dict from the
+                module will be used and unflattened into a TensorDict with
+                the tree structure of the model. Defaults to ``False``.
+                .. note::
+                  This is particularly useful when state-dict hooks have to be
+                  used.
 
         Examples:
             >>> from torch import nn
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index bb6ef79a8..890305341 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -6288,6 +6288,80 @@ def test_from_module(memmap, params):
     assert set(td.parameters()) == set(net.parameters())
 
 
+def test_from_module_state_dict():
+    net = nn.Transformer(
+        d_model=16,
+        nhead=2,
+        num_encoder_layers=3,
+        dim_feedforward=12,
+    )
+
+    def adder(module, *args, **kwargs):
+        for p in module.parameters(recurse=False):
+            p.data += 1
+
+    def remover(module, *args, **kwargs):
+        for p in module.parameters(recurse=False):
+            p.data = p.data - 1
+
+    for module in net.modules():
+        module.register_state_dict_pre_hook(adder)
+        module._register_state_dict_hook(remover)
+    params_reg = TensorDict.from_module(net)
+    params_reg = params_reg.select(*params_reg.keys(True, True))
+
+    params_sd = TensorDict.from_module(net, use_state_dict=True)
+    assert_allclose_td(params_sd, params_reg.apply(lambda x: x + 1))
+
+    sd = net.state_dict()
+    assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, []))
+
+
+def test_to_module_state_dict():
+    net0 = nn.Transformer(
+        d_model=16,
+        nhead=2,
+        num_encoder_layers=3,
+        dim_feedforward=12,
+    )
+    net1 = nn.Transformer(
+        d_model=16,
+        nhead=2,
+        num_encoder_layers=3,
+        dim_feedforward=12,
+    )
+
+    def hook(
+        module,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        for key, val in list(state_dict.items()):
+            state_dict[key] = val * 0
+
+    for module in net0.modules():
+        module._register_load_state_dict_pre_hook(hook, with_module=True)
+    for module in net1.modules():
+        module._register_load_state_dict_pre_hook(hook, with_module=True)
+
+    params_reg = TensorDict.from_module(net0)
+    params_reg.to_module(net0, use_state_dict=True)
+    params_reg = TensorDict.from_module(net0)
+
+    sd = net1.state_dict()
+    net1.load_state_dict(sd)
+    sd = net1.state_dict()
+
+    assert (params_reg == 0).all()
+    assert set(params_reg.flatten_keys(".").keys()) == set(sd.keys())
+    assert_allclose_td(params_reg.flatten_keys("."), TensorDict(sd, []))
+
+
 @pytest.mark.parametrize("batch_size", [None, [3, 4]])
 @pytest.mark.parametrize("batch_dims", [None, 1, 2])
 @pytest.mark.parametrize("device", get_available_devices())

From e8a51d71926128a4ff6b213fd5cb4a8ea9213536 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 13 Dec 2023 21:12:06 +0000
Subject: [PATCH 2/3] amend

---
 tensordict/_td.py       | 82 ++++++++++++++++++++++++++++++++---------
 tensordict/base.py      |  8 +++-
 test/test_tensordict.py |  1 +
 3 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index 24b91a263..7965ad63c 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -232,34 +232,82 @@ def __init__(
             for key, value in source.items():
                 self.set(key, value)
 
-    @staticmethod
+    @classmethod
     def from_module(
+        cls,
         module: torch.nn.Module,
         as_module: bool = False,
         lock: bool = False,
         use_state_dict: bool = False,
     ):
+        result = cls._from_module(
+            module=module, as_module=as_module, use_state_dict=use_state_dict
+        )
+        if lock:
+            result.lock_()
+        return result
+
+    @classmethod
+    def _from_module(
+        cls,
+        module: torch.nn.Module,
+        as_module: bool = False,
+        use_state_dict: bool = False,
+        prefix="",
+    ):
+        destination = {}
         if use_state_dict:
-            return TensorDict(module.state_dict(), batch_size=[]).unflatten_keys(".")
-
-        td_struct = TensorDict({}, [])
-        for key, param in module.named_parameters(recurse=False):
-            td_struct._set_str(key, param, validated=True, inplace=False)
-        for key, param in module.named_buffers(recurse=False):
-            td_struct._set_str(key, param, validated=True, inplace=False)
-        for key, mod in module.named_children():
-            td_struct._set_str(
-                key,
-                TensorDict.from_module(mod, as_module=False, lock=False),
-                validated=True,
-                inplace=False,
-            )
+            keep_vars = False
+            # do we need this feature atm?
+            local_metadata = {}
+            # if hasattr(destination, "_metadata"):
+            #     destination._metadata[prefix[:-1]] = local_metadata
+            for hook in module._state_dict_pre_hooks.values():
+                hook(module, prefix, keep_vars)
+            module._save_to_state_dict(destination, "", keep_vars)
+        else:
+            for name, param in module._parameters.items():
+                if param is None:
+                    continue
+                destination[name] = param
+            for name, buffer in module._buffers.items():
+                if buffer is None:
+                    continue
+                destination[name] = buffer
+
+        if use_state_dict:
+            for hook in module._state_dict_hooks.values():
+                hook_result = hook(module, destination, prefix, local_metadata)
+                if hook_result is not None:
+                    destination = hook_result
+        destination = TensorDict(destination, batch_size=[])
+        for name, submodule in module._modules.items():
+            if submodule is not None:
+                subtd = cls._from_module(
+                    module=submodule,
+                    as_module=as_module,
+                    use_state_dict=use_state_dict,
+                    prefix=prefix + name + ".",
+                )
+                destination._set_str(name, subtd, validated=True, inplace=False)
+        return destination
+
+        # td_struct = TensorDict({}, [])
+        # for key, param in module.named_parameters(recurse=False):
+        #     td_struct._set_str(key, param, validated=True, inplace=False)
+        # for key, param in module.named_buffers(recurse=False):
+        #     td_struct._set_str(key, param, validated=True, inplace=False)
+        # for key, mod in module.named_children():
+        #     td_struct._set_str(
+        #         key,
+        #         TensorDict.from_module(mod, as_module=False, lock=False),
+        #         validated=True,
+        #         inplace=False,
+        #     )
         if as_module:
             from tensordict.nn.params import TensorDictParams
 
             return TensorDictParams(td_struct, no_convert=True)
-        if lock:
-            td_struct.lock_()
         return td_struct
 
     def is_empty(self):
diff --git a/tensordict/base.py b/tensordict/base.py
index dc58bbe92..ab9612956 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -308,9 +308,13 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
         ...
 
     # Module interaction
-    @staticmethod
+    @classmethod
     def from_module(
-        module, as_module: bool = False, lock: bool = True, use_state_dict: bool = False
+        cls,
+        module,
+        as_module: bool = False,
+        lock: bool = True,
+        use_state_dict: bool = False,
     ):
         """Copies the params and buffers of a module in a tensordict.
 
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 0f8c57f37..4f0366d61 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -6351,6 +6351,7 @@ def remover(module, *args, **kwargs):
     params_reg = params_reg.select(*params_reg.keys(True, True))
 
     params_sd = TensorDict.from_module(net, use_state_dict=True)
+    params_sd = params_sd.select(*params_sd.keys(True, True))
     assert_allclose_td(params_sd, params_reg.apply(lambda x: x + 1))
 
     sd = net.state_dict()

From d13e896f81ef3f99bc4c5faea9e0a9de22665629 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 14 Dec 2023 10:10:03 +0000
Subject: [PATCH 3/3] fixes

---
 tensordict/_td.py | 24 +++++++++---------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index 7965ad63c..46f75ed8b 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -64,6 +64,7 @@
     _sub_index,
     _unravel_key_to_tuple,
     as_decorator,
+    Buffer,
     cache,
     convert_ellipsis_to_idx,
     DeviceType,
@@ -290,25 +291,12 @@ def _from_module(
                     prefix=prefix + name + ".",
                 )
             destination._set_str(name, subtd, validated=True, inplace=False)
-        return destination
 
-        # td_struct = TensorDict({}, [])
-        # for key, param in module.named_parameters(recurse=False):
-        #     td_struct._set_str(key, param, validated=True, inplace=False)
-        # for key, param in module.named_buffers(recurse=False):
-        #     td_struct._set_str(key, param, validated=True, inplace=False)
-        # for key, mod in module.named_children():
-        #     td_struct._set_str(
-        #         key,
-        #         TensorDict.from_module(mod, as_module=False, lock=False),
-        #         validated=True,
-        #         inplace=False,
-        #     )
         if as_module:
             from tensordict.nn.params import TensorDictParams
 
-            return TensorDictParams(td_struct, no_convert=True)
-        return td_struct
+            return TensorDictParams(destination, no_convert=True)
+        return destination
 
     def is_empty(self):
         for _ in self._tensordict:
@@ -369,6 +357,8 @@ def to_module(
             def convert_type(x, y):
                 if isinstance(y, torch.nn.Parameter):
                     return torch.nn.Parameter(x)
+                if isinstance(y, Buffer):
+                    return Buffer(x)
                 return x
 
             input = state_dict.unflatten_keys(".").apply(convert_type, self)
@@ -2635,6 +2625,10 @@ def _items(
         tensordict = self.tensordict
         if isinstance(tensordict, TensorDict) or is_tensorclass(tensordict):
             return tensordict._tensordict.items()
+        from tensordict.nn import TensorDictParams
+
+        if isinstance(tensordict, TensorDictParams):
+            return tensordict._param_td.items()
         if isinstance(tensordict, KeyedJaggedTensor):
             return tuple((key, tensordict[key]) for key in tensordict.keys())
         from tensordict._lazy import (
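
A minimal usage sketch of the reading path this series adds (not part of the
patches themselves). The module below is illustrative only, and the sketch
assumes a tensordict build that includes all three commits.

import torch
from torch import nn
from tensordict import TensorDict

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))

# Default path: parameters and buffers are gathered child by child.
params = TensorDict.from_module(net)

# New path: values come from the module's state-dict machinery (running any
# registered state-dict hooks) and are rebuilt into the same nested layout,
# so leaves are addressed as e.g. params_sd["0", "weight"].
params_sd = TensorDict.from_module(net, use_state_dict=True)

# With no hooks registered, the two paths agree on every leaf.
flat, flat_sd = params.flatten_keys("."), params_sd.flatten_keys(".")
assert set(flat.keys()) == set(flat_sd.keys())
assert all((flat[k] == flat_sd[k]).all() for k in flat.keys())

The select(...) calls in the tests exist because both extraction paths also
record parameter-free nodes (here the ReLU) as empty sub-tensordicts;
flattening the keys drops those empty nodes before comparison.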
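
A companion sketch for the writing path, mirroring test_to_module_state_dict
above: a load-state-dict pre-hook rewrites incoming values, and
to_module(use_state_dict=True) routes the tensordict through that hook before
any value reaches the module. The hook body is the one used in the test;
_register_load_state_dict_pre_hook is the private torch API the test itself
relies on, and the hook names here are hypothetical.

import torch
from torch import nn
from tensordict import TensorDict

net = nn.Linear(2, 2)

def zero_incoming(module, state_dict, prefix, local_metadata, strict,
                  missing_keys, unexpected_keys, error_msgs):
    # Runs before assignment: zero every incoming entry. state_dict is the
    # flattened TensorDict built inside to_module(use_state_dict=True).
    for key, val in list(state_dict.items()):
        state_dict[key] = val * 0

# with_module=True prepends the module to the hook's arguments.
net._register_load_state_dict_pre_hook(zero_incoming, with_module=True)

params = TensorDict.from_module(net)
params.to_module(net, use_state_dict=True)  # values pass through the hook

# The module now holds the hooked (zeroed) values, matching what
# net.load_state_dict(net.state_dict()) would have produced.
assert all((p == 0).all() for p in net.parameters())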