[Refactor] Refactor to create empty tds where necessary #522

Merged · 1 commit · Sep 6, 2023
39 changes: 28 additions & 11 deletions tensordict/tensordict.py
@@ -372,9 +372,14 @@ def __setstate__(self, state: dict[str, Any]) -> dict[str, Any]:
         self.__dict__.update(state)

     @staticmethod
-    def from_module(module):
+    def from_module(module, as_module: bool = False):
         """Copies the params and buffers of a module into a tensordict.

+        Args:
+            as_module (bool, optional): if ``True``, a :class:`tensordict.nn.TensorDictParams`
+                instance will be returned which can be used to store parameters
+                within a :class:`torch.nn.Module`. Defaults to ``False``.
+
         Examples:
             >>> from torch import nn
             >>> module = nn.TransformerDecoder(
@@ -390,10 +395,17 @@ def from_module(module):
                 device=None,
                 is_shared=False)
         """
-        td = TensorDict(dict(module.named_parameters()), [])
-        td.update(dict(module.named_buffers()))
-        td = td.unflatten_keys(".")
+        td_struct = {k: {} for k in dict(module.named_modules()).keys()}
+        del td_struct[""]
+        td_struct = TensorDict(td_struct, []).unflatten_keys(".")
+        td_params = TensorDict(dict(module.named_parameters()), []).unflatten_keys(".")
+        td_buffers = TensorDict(dict(module.named_buffers()), []).unflatten_keys(".")
+        td = td_struct.update(td_params).update(td_buffers)
         td.lock_()
+        if as_module:
+            from tensordict.nn import TensorDictParams
+
+            return TensorDictParams(td, no_convert=True)
         return td

     @property
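
For context, a minimal sketch of what the refactor buys: modules that own no parameters now appear as empty sub-tensordicts, and the new `as_module=True` path returns a `TensorDictParams`. The small `nn.Sequential` below is illustrative, not part of the PR:

```python
import torch
from torch import nn
from tensordict import TensorDict

# The ReLU at index "1" owns no params or buffers, so the refactored
# from_module() represents it as an empty sub-tensordict.
module = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 4))
td = TensorDict.from_module(module)
assert td["1"].is_empty()
assert td["0", "weight"].shape == torch.Size([4, 3])

# With as_module=True, a TensorDictParams is returned; being an
# nn.Module itself, it exposes the parameters for optimizers etc.
params = TensorDict.from_module(module, as_module=True)
assert set(params.parameters()) == set(module.parameters())
```
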
@@ -3329,7 +3341,11 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:

         keys = set(out.keys())
         for key, list_of_keys in to_unflatten.items():
-            if key in keys:
+            # if the key is present and either (1) it is not a tensor collection or (2) it is but it's not empty, then we raise an error.
+            if key in keys and (
+                not is_tensor_collection(out.get(key)) or not out.get(key).is_empty()
+            ):
+                print(out.get(key))
Contributor (review comment on the `print` line above): left over

                 raise KeyError(
                     "Unflattening key(s) in tensordict will override existing unflattened key"
                 )
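
As a sketch of the relaxed check (the tiny tensordicts below are illustrative): unflattening may now overwrite an existing key when that key holds an empty tensor collection, while a clash with a leaf tensor or a non-empty sub-tensordict still raises:

```python
import torch
from tensordict import TensorDict

# "a" exists but is an empty sub-tensordict: unflattening "a.b" into it is now allowed.
td = TensorDict({"a": TensorDict({}, []), "a.b": torch.zeros(3)}, [])
out = td.unflatten_keys(".")
assert out["a", "b"].shape == torch.Size([3])

# A clash with a plain tensor still raises, as before.
td_bad = TensorDict({"a": torch.zeros(3), "a.b": torch.zeros(3)}, [])
try:
    td_bad.unflatten_keys(".")
except KeyError:
    pass  # "Unflattening key(s) in tensordict will override existing unflattened key"
```
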
@@ -3594,7 +3610,8 @@ def empty(self, recurse=False) -> T:
         return self.exclude(*self.keys(True, True))

     def is_empty(self) -> bool:
-        for _ in self.keys():
+        """Checks if the tensordict contains any leaf."""
+        for _ in self.keys(True, True):
             return False
         return True
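
The new semantics make `is_empty()` recursive over leaves, so a tensordict whose only entries are empty sub-tensordicts is itself considered empty. A short illustration (values are illustrative):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"sub": TensorDict({}, [])}, [])
assert td.is_empty()        # keys(True, True) yields no leaves
td["sub", "x"] = torch.zeros(3)
assert not td.is_empty()    # a leaf tensor now exists
```
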

@@ -4435,11 +4452,11 @@ def memmap_(
             raise RuntimeError(
                 "memmap and shared memory are mutually exclusive features."
             )
-        if not self._tensordict.keys():
-            raise Exception(
-                "memmap_() must be called when the TensorDict is (partially) "
-                "populated. Set a tensor first."
-            )
+        # if not self._tensordict.keys():
+        #     raise Exception(
+        #         "memmap_() must be called when the TensorDict is (partially) "
+        #         "populated. Set a tensor first."
+        #     )
         for key, value in self.items():
             if value.requires_grad:
                 raise Exception(
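Commenting out this guard is what lets the empty tensordicts above flow through: a module without parameters now yields an empty tensordict, and `memmap_()` must accept it. A minimal sketch (the ReLU module is illustrative):

```python
from torch import nn
from tensordict import TensorDict

td = TensorDict.from_module(nn.ReLU())  # no params/buffers -> empty tensordict
mm = td.detach().memmap_()              # previously raised; now succeeds
assert mm.is_memmap()
```
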
21 changes: 18 additions & 3 deletions test/test_tensordict.py
@@ -5960,13 +5960,28 @@ def test_stacked_append_and_insert(self):


@pytest.mark.parametrize("memmap", [True, False])
def test_from_module(memmap):
net = nn.Transformer()
td = TensorDict.from_module(net)
@pytest.mark.parametrize("params", [False, True])
def test_from_module(memmap, params):
net = nn.Transformer(
d_model=16,
nhead=2,
num_encoder_layers=3,
dim_feedforward=12,
)
td = TensorDict.from_module(net, as_module=params)
# check that we have empty tensordicts, reflecting modules wihout params
for subtd in td.values(True):
if isinstance(subtd, TensorDictBase) and subtd.is_empty():
break
else:
raise RuntimeError
if memmap:
td = td.detach().memmap_()
net.load_state_dict(td.flatten_keys("."))

if not memmap and params:
assert set(td.parameters()) == set(net.parameters())


@pytest.mark.parametrize("batch_size", [None, [3, 4]])
@pytest.mark.parametrize("batch_dims", [None, 1, 2])