Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 24, 2023
1 parent 103773b commit 3d718a0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
30 changes: 18 additions & 12 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,20 +259,31 @@ def from_module(
td_struct.lock_()
return td_struct

def is_empty(self):
for _ in self._tensordict:
return False
return True

@as_decorator()
def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None):
# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
__dict__ = module.__dict__

swap = None
has_set_device = False
if memo is None:
memo = {}
hooks = getattr(
torch.nn.modules.module, "_global_parameter_registration_hooks", {}
)
memo = {"hooks": tuple(hooks.values())}
else:
hooks = memo["hooks"]
if return_swap:
# this could break if the device and batch-size are not congruent.
# For batch-size it is a minor issue (unlikely that a td with batch-size
# is passed with to_module) but for the device it could be a problem.
if swap_dest is None:
swap = TensorDict({}, batch_size=[])
swap = TensorDict({}, batch_size=torch.Size(()), _run_checks=False)
else:
swap = swap_dest
memo[id(module)] = swap
Expand All @@ -282,18 +293,16 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
if isinstance(value, (Tensor, ftdim.Tensor)):
if module.__class__.__setattr__ is __base__setattr__:
# if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict
local_out = _set_tensor_dict(__dict__, module, key, value)
local_out = _set_tensor_dict(__dict__, hooks, module, key, value)
else:
if return_swap:
local_out = getattr(module, key)
# use specialized __setattr__ if needed
setattr(module, key, value)
else:
for _ in value.keys():
if value.is_empty():
# if there is at least one key, we must populate the module.
# Otherwise we just go to the next key
break
else:
continue
if swap_dest is not None:
local_dest = swap_dest._get_str(key, default=NO_DEFAULT)
Expand Down Expand Up @@ -2571,7 +2580,7 @@ def __repr__(self):


def _set_tensor_dict( # noqa: F811
module_dict, module, name: str, tensor: torch.Tensor
module_dict, hooks, module, name: str, tensor: torch.Tensor
) -> None:
"""Simplified version of torch.nn.utils._named_member_accessor."""
was_buffer = False
Expand All @@ -2580,13 +2589,10 @@ def _set_tensor_dict( # noqa: F811
out = module_dict["_buffers"].pop(name, None)
was_buffer = out is not None
if out is None:
out = module_dict.pop(name, None)
out = module_dict.pop(name)

if isinstance(tensor, torch.nn.Parameter):
# module.register_parameter(name, tensor)
for hook in getattr(
torch.nn.modules.module, "_global_parameter_registration_hooks", {}
).values():
for hook in hooks:
output = hook(module, name, tensor)
if output is not None:
tensor = output
Expand Down
29 changes: 19 additions & 10 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,19 +1116,28 @@ def __init__(self, attr=None):
self.attr = attr

def __call__(self, func):
@wraps(func)
def new_func(_self, *args, **kwargs):
if self.attr is not None:
if self.attr is not None:

@wraps(func)
def new_func(_self, *args, **kwargs):
_attr_pre = getattr(_self, self.attr)
out = func(_self, *args, **kwargs)
if self.attr is not None:
out = func(_self, *args, **kwargs)
_attr_post = getattr(_self, self.attr)
if out is not None:
if self.attr is None or (_attr_post is not _attr_pre):
if out is not None:
if _attr_post is not _attr_pre:
out._last_op = (new_func.__name__, (args, kwargs, _self))
else:
out._last_op = None
return out

else:

@wraps(func)
def new_func(_self, *args, **kwargs):
out = func(_self, *args, **kwargs)
if out is not None:
out._last_op = (new_func.__name__, (args, kwargs, _self))
else:
out._last_op = None
return out
return out

return new_func

Expand Down

0 comments on commit 3d718a0

Please sign in to comment.