Skip to content

Commit

Permalink
[Refactor] Refactor to create empty tds where necessary (#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 6, 2023
1 parent 4b309d9 commit 14ca63b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
39 changes: 28 additions & 11 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,14 @@ def __setstate__(self, state: dict[str, Any]) -> dict[str, Any]:
self.__dict__.update(state)

@staticmethod
def from_module(module):
def from_module(module, as_module: bool = False):
"""Copies the params and buffers of a module in a tensordict.
Args:
as_module (bool, optional): if ``True``, a :class:`tensordict.nn.TensorDictParams`
instance will be returned which can be used to store parameters
within a :class:`torch.nn.Module`. Defaults to ``False``.
Examples:
>>> from torch import nn
>>> module = nn.TransformerDecoder(
Expand All @@ -390,10 +395,17 @@ def from_module(module):
device=None,
is_shared=False)
"""
td = TensorDict(dict(module.named_parameters()), [])
td.update(dict(module.named_buffers()))
td = td.unflatten_keys(".")
td_struct = {k: {} for k in dict(module.named_modules()).keys()}
del td_struct[""]
td_struct = TensorDict(td_struct, []).unflatten_keys(".")
td_params = TensorDict(dict(module.named_parameters()), []).unflatten_keys(".")
td_buffers = TensorDict(dict(module.named_buffers()), []).unflatten_keys(".")
td = td_struct.update(td_params).update(td_buffers)
td.lock_()
if as_module:
from tensordict.nn import TensorDictParams

return TensorDictParams(td, no_convert=True)
return td

@property
Expand Down Expand Up @@ -3329,7 +3341,11 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:

keys = set(out.keys())
for key, list_of_keys in to_unflatten.items():
if key in keys:
# if the key is present and either (1) it is not a tensor collection or (2) it is but it's not empty, then we raise an error.
if key in keys and (
not is_tensor_collection(out.get(key)) or not out.get(key).is_empty()
):
print(out.get(key))
raise KeyError(
"Unflattening key(s) in tensordict will override existing unflattened key"
)
Expand Down Expand Up @@ -3594,7 +3610,8 @@ def empty(self, recurse=False) -> T:
return self.exclude(*self.keys(True, True))

def is_empty(self) -> bool:
for _ in self.keys():
"""Checks if the tensordict contains any leaf."""
for _ in self.keys(True, True):
return False
return True

Expand Down Expand Up @@ -4435,11 +4452,11 @@ def memmap_(
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
if not self._tensordict.keys():
raise Exception(
"memmap_() must be called when the TensorDict is (partially) "
"populated. Set a tensor first."
)
# if not self._tensordict.keys():
# raise Exception(
# "memmap_() must be called when the TensorDict is (partially) "
# "populated. Set a tensor first."
# )
for key, value in self.items():
if value.requires_grad:
raise Exception(
Expand Down
21 changes: 18 additions & 3 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5960,13 +5960,28 @@ def test_stacked_append_and_insert(self):


@pytest.mark.parametrize("memmap", [True, False])
def test_from_module(memmap):
net = nn.Transformer()
td = TensorDict.from_module(net)
@pytest.mark.parametrize("params", [False, True])
def test_from_module(memmap, params):
net = nn.Transformer(
d_model=16,
nhead=2,
num_encoder_layers=3,
dim_feedforward=12,
)
td = TensorDict.from_module(net, as_module=params)
# check that we have empty tensordicts, reflecting modules wihout params
for subtd in td.values(True):
if isinstance(subtd, TensorDictBase) and subtd.is_empty():
break
else:
raise RuntimeError
if memmap:
td = td.detach().memmap_()
net.load_state_dict(td.flatten_keys("."))

if not memmap and params:
assert set(td.parameters()) == set(net.parameters())


@pytest.mark.parametrize("batch_size", [None, [3, 4]])
@pytest.mark.parametrize("batch_dims", [None, 1, 2])
Expand Down

0 comments on commit 14ca63b

Please sign in to comment.