From d6970cfc80e25b1e69e3c1499aaecdd6bc8cca34 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 6 Sep 2023 10:11:54 -0400
Subject: [PATCH] init

---
 tensordict/tensordict.py | 39 ++++++++++++++++++++++++++++-----------
 test/test_tensordict.py  | 21 ++++++++++++++++++---
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 17cfe0a76..9ecf1c0cb 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -372,9 +372,14 @@ def __setstate__(self, state: dict[str, Any]) -> dict[str, Any]:
         self.__dict__.update(state)
 
     @staticmethod
-    def from_module(module):
+    def from_module(module, as_module: bool = False):
         """Copies the params and buffers of a module in a tensordict.
 
+        Args:
+            as_module (bool, optional): if ``True``, a :class:`tensordict.nn.TensorDictParams`
+                instance will be returned which can be used to store parameters
+                within a :class:`torch.nn.Module`. Defaults to ``False``.
+
         Examples:
             >>> from torch import nn
             >>> module = nn.TransformerDecoder(
@@ -390,10 +395,17 @@ def from_module(module):
                 device=None,
                 is_shared=False)
         """
-        td = TensorDict(dict(module.named_parameters()), [])
-        td.update(dict(module.named_buffers()))
-        td = td.unflatten_keys(".")
+        td_struct = {k: {} for k in dict(module.named_modules()).keys()}
+        del td_struct[""]
+        td_struct = TensorDict(td_struct, []).unflatten_keys(".")
+        td_params = TensorDict(dict(module.named_parameters()), []).unflatten_keys(".")
+        td_buffers = TensorDict(dict(module.named_buffers()), []).unflatten_keys(".")
+        td = td_struct.update(td_params).update(td_buffers)
         td.lock_()
+        if as_module:
+            from tensordict.nn import TensorDictParams
+
+            return TensorDictParams(td, no_convert=True)
         return td
 
     @property
@@ -3329,7 +3341,11 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
 
         keys = set(out.keys())
         for key, list_of_keys in to_unflatten.items():
-            if key in keys:
+            # if the key is present and either (1) it is not a tensor collection or (2) it is but it's not empty, then we raise an error.
+            if key in keys and (
+                not is_tensor_collection(out.get(key)) or not out.get(key).is_empty()
+            ):
+                print(out.get(key))
                 raise KeyError(
                     "Unflattening key(s) in tensordict will override existing unflattened key"
                 )
@@ -3594,7 +3610,8 @@ def empty(self, recurse=False) -> T:
         return self.exclude(*self.keys(True, True))
 
     def is_empty(self) -> bool:
-        for _ in self.keys():
+        """Checks if the tensordict contains any leaf."""
+        for _ in self.keys(True, True):
             return False
         return True
 
@@ -4435,11 +4452,11 @@ def memmap_(
             raise RuntimeError(
                 "memmap and shared memory are mutually exclusive features."
             )
-        if not self._tensordict.keys():
-            raise Exception(
-                "memmap_() must be called when the TensorDict is (partially) "
-                "populated. Set a tensor first."
-            )
+        # if not self._tensordict.keys():
+        #     raise Exception(
+        #         "memmap_() must be called when the TensorDict is (partially) "
+        #         "populated. Set a tensor first."
+        #     )
         for key, value in self.items():
             if value.requires_grad:
                 raise Exception(
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 8a76ff1f2..fa7303738 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -5960,13 +5960,28 @@ def test_stacked_append_and_insert(self):
 
 
 @pytest.mark.parametrize("memmap", [True, False])
-def test_from_module(memmap):
-    net = nn.Transformer()
-    td = TensorDict.from_module(net)
+@pytest.mark.parametrize("params", [False, True])
+def test_from_module(memmap, params):
+    net = nn.Transformer(
+        d_model=16,
+        nhead=2,
+        num_encoder_layers=3,
+        dim_feedforward=12,
+    )
+    td = TensorDict.from_module(net, as_module=params)
+    # check that we have empty tensordicts, reflecting modules without params
+    for subtd in td.values(True):
+        if isinstance(subtd, TensorDictBase) and subtd.is_empty():
+            break
+    else:
+        raise RuntimeError
     if memmap:
         td = td.detach().memmap_()
     net.load_state_dict(td.flatten_keys("."))
+    if not memmap and params:
+        assert set(td.parameters()) == set(net.parameters())
+
 
 
 @pytest.mark.parametrize("batch_size", [None, [3, 4]])
 @pytest.mark.parametrize("batch_dims", [None, 1, 2])
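
Below is a minimal usage sketch of the from_module / as_module API introduced by this patch; the nn.Linear module and the variable names are illustrative and not part of the commit.

    import torch.nn as nn
    from tensordict import TensorDict

    module = nn.Linear(3, 4)
    # Default: a locked, plain TensorDict holding the module's parameters and
    # buffers, nested according to the module hierarchy.
    params_td = TensorDict.from_module(module)
    # With as_module=True, a tensordict.nn.TensorDictParams is returned instead,
    # which can be used to store the parameters within a torch.nn.Module.
    params_module = TensorDict.from_module(module, as_module=True)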