[Feature] Improve in-place ops for TensorDictParams (#609)
vmoens authored Jan 5, 2024
1 parent ada9e3c commit da449cf
Showing 6 changed files with 143 additions and 48 deletions.
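
In short: in-place operations on a TensorDictParams are now forwarded to the underlying .data leaves instead of raising on leaf parameters that require grad. A minimal usage sketch, mirroring the new test_inplace_ops test added below:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

td = TensorDict(
    {"a": {"b": torch.zeros((), requires_grad=True)}, "e": torch.zeros((), requires_grad=True)},
    [],
)
param_td = TensorDictParams(td)

# copy_ and zero_ now write into the detached .data tensors in-place
param_td.copy_(param_td.data.apply(lambda x: x + 1))
assert (param_td == 1).all()
param_td.zero_()
assert (param_td == 0).all()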
4 changes: 1 addition & 3 deletions tensordict/__init__.py
@@ -25,8 +25,8 @@
set_lazy_legacy,
)
from tensordict._pytree import *

from tensordict._tensordict import unravel_key, unravel_key_list
from tensordict.nn import TensorDictParams

try:
from tensordict.version import __version__
@@ -53,5 +53,3 @@
"dense_stack_tds",
"NonTensorData",
]

# from tensordict._pytree import *
7 changes: 5 additions & 2 deletions tensordict/base.py
@@ -4017,8 +4017,11 @@ def empty(self, recurse=False) -> T:
# Filling
def zero_(self) -> T:
"""Zeros all tensors in the tensordict in-place."""
for key in self.keys():
self.fill_(key, 0)

def fn(item):
item.zero_()

self._fast_apply(fn=fn, call_on_nested=True)
return self

def fill_(self, key: NestedKey, value: float | bool) -> T:
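Behaviourally, the reworked zero_ above still zeroes every leaf in-place and returns self, but it now does so in a single _fast_apply pass that recurses into nested entries. A small sketch of the expected behaviour (tensor names are illustrative):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.ones(3), "nested": {"b": torch.ones(3)}}, batch_size=[3])
out = td.zero_()  # zeroes all leaves in-place, including nested ones
assert out is td  # zero_ returns self
assert (td["nested", "b"] == 0).all()
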
116 changes: 100 additions & 16 deletions tensordict/nn/params.py
@@ -213,6 +213,15 @@ def new_func(self, *args, **kwargs):
return new_func


def _apply_on_data(func):
@wraps(func)
def new_func(self, *args, **kwargs):
getattr(self.data, func.__name__)(*args, **kwargs)
return self

return new_func


class TensorDictParams(TensorDictBase, nn.Module):
r"""Holds a TensorDictBase instance full of parameters.
@@ -428,10 +437,6 @@ def rename_key_(
) -> TensorDictBase:
...

@_unlock_and_set
def apply_(self, fn: Callable, *others) -> TensorDictBase:
...

def map(
self,
fn: Callable,
@@ -514,9 +519,18 @@ def cuda(self, device=None):
def clone(self, recurse: bool = True) -> TensorDictBase:
"""Clones the TensorDictParams.
The effect of this call is different from a regular torch.Tensor.clone call
in that it will create a TensorDictParams instance with a new copy of the
parameters and buffers __detached__ from the current graph.
.. warning::
The effect of this call is different from a regular torch.Tensor.clone call
in that it will create a TensorDictParams instance with a new copy of the
parameters and buffers __detached__ from the current graph. For a
regular clone (ie, cloning leaf parameters onto a new tensor that
is part of the graph), simply call
>>> params.apply(torch.clone)
.. note::
If a parameter is duplicated in the tree, ``clone`` will preserve this
identity (ie, parameter tying is preserved).
See :meth:`tensordict.TensorDictBase.clone` for more info on the clone
method.
@@ -525,14 +539,21 @@ def clone(self, recurse: bool = True) -> TensorDictBase:
if not recurse:
return TensorDictParams(self._param_td.clone(False), no_convert=True)

def _clone(tensor):
memo = {}

def _clone(tensor, memo=memo):
result = memo.get(tensor, None)
if result is not None:
return result

if isinstance(tensor, nn.Parameter):
tensor = nn.Parameter(
result = nn.Parameter(
tensor.data.clone(), requires_grad=tensor.requires_grad
)
else:
tensor = Buffer(tensor.data.clone(), requires_grad=tensor.requires_grad)
return tensor
result = Buffer(tensor.data.clone(), requires_grad=tensor.requires_grad)
memo[tensor] = result
return result

return TensorDictParams(self._param_td.apply(_clone), no_convert=True)
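
The memo dictionary added to _clone above preserves parameter tying: entries that pointed to the same Parameter before cloning still point to a single (new, detached) Parameter afterwards. A hedged sketch, following the updated docstring and the new test_tdparams_clone_tying test below:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

p = nn.Parameter(torch.zeros(()))
params = TensorDictParams(TensorDict({"a": {"b": {"c": p}}, "c": p}, []), no_convert=True)

detached_copy = params.clone()
# Tying is preserved: both entries resolve to the same new Parameter.
assert detached_copy["c"] is detached_copy["a", "b", "c"]

# For a clone whose leaves stay attached to the autograd graph, the
# docstring suggests applying torch.clone leaf-wise instead.
graph_clone = params.apply(torch.clone)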

@@ -652,10 +673,6 @@ def _set_tuple(self, *args, **kwargs):
def _create_nested_str(self, *args, **kwargs):
...

@_fallback
def _stack_onto_(self, *args, **kwargs):
...

@_fallback_property
def batch_size(self) -> torch.Size:
...
@@ -984,6 +1001,73 @@ def items(
continue
yield k, self._apply_get_post_hook(v)

@_apply_on_data
def zero_(self) -> T:
...

@_apply_on_data
def fill_(self, key: NestedKey, value: float | bool) -> T:
...

@_apply_on_data
def copy_(self, tensordict: T, non_blocking: bool = None) -> T:
...

@_apply_on_data
def set_at_(self, key: NestedKey, value: CompatibleType, index: IndexType) -> T:
...

@_apply_on_data
def set_(
self,
key: NestedKey,
item: CompatibleType,
) -> T:
...

@_apply_on_data
def _stack_onto_(
self,
list_item: list[CompatibleType],
dim: int,
) -> T:
...

@_apply_on_data
def _stack_onto_at_(
self,
key: str,
list_item: list[CompatibleType],
dim: int,
idx: IndexType,
) -> T:
...

@_apply_on_data
def update_(
self,
input_dict_or_td: dict[str, CompatibleType] | T,
clone: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
) -> T:
...

@_apply_on_data
def update_at_(
self,
input_dict_or_td: dict[str, CompatibleType] | T,
idx: IndexType,
clone: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
) -> T:
...

@_apply_on_data
def apply_(self, fn: Callable, *others) -> T:
...

def _apply(self, fn, recurse=True):
"""Modifies torch.nn.Module._apply to work with Buffer class."""
if recurse:
@@ -1097,7 +1181,7 @@ def _empty_like(td: TensorDictBase, *args, **kwargs) -> TensorDictBase:
"cloned, preventing empty_like to be called. "
"Consider calling tensordict.to_tensordict() first."
) from err
return tdclone.data.apply_(lambda x: torch.empty_like(x, *args, **kwargs))
return tdclone.apply_(lambda x: torch.empty_like(x, *args, **kwargs))


_register_tensor_class(TensorDictParams)
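
All of the _apply_on_data methods added above follow the same route: the call is forwarded to the identically named method on self.data and self is returned, which is what lets these in-place ops sidestep autograd's restriction on leaf tensors that require grad. A sketch of the resulting equivalence (not library code):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

params = TensorDictParams(TensorDict({"w": torch.zeros(3, requires_grad=True)}, []))

# With the decorator in place, the two calls below have the same effect:
# the in-place fill runs on the detached .data leaves.
params.fill_("w", 1.0)
params.data.fill_("w", 1.0)
assert (params == 1).all()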
5 changes: 5 additions & 0 deletions tensordict/persistent.py
@@ -514,6 +514,11 @@ def empty(self, recurse=False) -> T:
names=self.names if self._has_names() else None,
)

def zero_(self) -> T:
for key in self.keys():
self.fill_(key, 0)
return self

def entry_class(self, key: NestedKey) -> type:
entry_class = self._get_metadata(key)
is_array = entry_class.get("array", None)
32 changes: 32 additions & 0 deletions test/test_nn.py
@@ -2915,6 +2915,38 @@ def test_tdparams_clone(self):
assert val.data_ptr() != td.get(key).data_ptr()
assert (val == td.get(key)).all()

def test_tdparams_clone_tying(self):
c = nn.Parameter(torch.zeros((), requires_grad=True))
td = TensorDict(
{
"a": {
"b": {"c": c},
},
"c": c,
},
[],
)
td = TensorDictParams(td, no_convert=True)
td_clone = td.clone()
assert td_clone["c"] is td_clone["a", "b", "c"]

def test_inplace_ops(self):
td = TensorDict(
{
"a": {
"b": {"c": torch.zeros((), requires_grad=True)},
"d": torch.zeros((), requires_grad=True),
},
"e": torch.zeros((), requires_grad=True),
},
[],
)
param_td = TensorDictParams(td)
param_td.copy_(param_td.data.apply(lambda x: x + 1))
assert (param_td == 1).all()
param_td.zero_()
assert (param_td == 0).all()


class TestCompositeDist:
def test_const(self):
27 changes: 0 additions & 27 deletions test/test_tensordict.py
@@ -1080,12 +1080,6 @@ def test_masked_fill(self, td_name, device):
def test_zero_(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
if td_name == "td_params":
with pytest.raises(
RuntimeError, match="a leaf Variable that requires grad"
):
new_td = td.zero_()
return
new_td = td.zero_()
assert new_td is td
for k in td.keys():
@@ -1581,10 +1575,6 @@ def test_squeeze_with_none(self, td_name, device, squeeze_dim=None):
td = getattr(self, td_name)(device)
td_squeeze = torch.squeeze(td, dim=None)
tensor = torch.ones_like(td.get("a").squeeze())
if td_name == "td_params":
with pytest.raises(ValueError, match="Failed to update"):
td_squeeze.set_("a", tensor)
return
td_squeeze.set_("a", tensor)
assert (td_squeeze.get("a") == tensor).all()
if td_name == "unsqueezed_td":
@@ -1688,10 +1678,6 @@ def test_update(self, td_name, device, clone):
def test_update_at_(self, td_name, device):
td = getattr(self, td_name)(device)
td0 = td[1].clone().zero_()
if td_name == "td_params":
with pytest.raises(RuntimeError, match="a view of a leaf Variable"):
td.update_at_(td0, 0)
return
td.update_at_(td0, 0)
assert (td[0] == 0).all()

@@ -2325,10 +2311,6 @@ def test_stack_tds_on_subclass(self, td_name, device):
with pytest.raises(IndexError, match="storages of the indexed tensors"):
torch.stack(tds_list, 0, out=td)
return
if td_name == "td_params":
with pytest.raises(RuntimeError, match="arguments don't support automatic"):
torch.stack(tds_list, 0, out=td)
return
data_ptr_set_before = {val.data_ptr() for val in decompose(td)}

stacked_td = torch.stack(tds_list, 0, out=td)
@@ -3068,10 +3050,6 @@ def test_empty_like(self, td_name, device):
# we do not call skip to avoid systematic skips in internal code base
return
td_empty = torch.empty_like(td)
if td_name == "td_params":
with pytest.raises(ValueError, match="Failed to update"):
td.apply_(lambda x: x + 1.0)
return

td.apply_(lambda x: x + 1.0)
assert type(td) is type(td_empty)
@@ -3104,11 +3082,6 @@ def test_add_batch_dim_cache(self, td_name, device, nested):
return
fun(td)

if td_name == "td_params":
with pytest.raises(RuntimeError, match="leaf Variable that requires grad"):
td.zero_()
return

td.zero_()
# this value should be cached
std = fun(td)
