[Feature] Consolidate functional calls #565

Merged 6 commits on Nov 22, 2023
benchmarks/nn/functional_benchmarks_test.py (111 additions, 0 deletions)
@@ -120,6 +120,16 @@ def test_instantiation_td(benchmark, net):

# Execution
def test_exec_functorch(benchmark, net):
x = torch.randn(2, 2)
sd = net.state_dict()

def fun(x, sd):
torch.func.functional_call(net, sd, x)

benchmark(fun, x, sd)


def test_exec_functional_call(benchmark, net):
x = torch.randn(2, 2)
fmodule, params, buffers = functorch_make_functional(net)
benchmark(fmodule, params, buffers, x)
@@ -132,6 +142,18 @@ def test_exec_td(benchmark, net):
benchmark(fmodule, x, params=params)


def test_exec_td_decorator(benchmark, net):
x = torch.randn(2, 2)
fmodule = net
params = TensorDict.from_module(fmodule)

def fun(x, params):
with params.to_module(net):
net(x)

benchmark(fun, x, params)


@torch.no_grad()
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
@@ -169,6 +191,48 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule):
benchmark(fun, x, params)


@torch.no_grad()
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule):
# tests speed of vmapping over an MLP
device = "cuda" if torch.cuda.device_count() else "cpu"
t = nn.Sequential(
nn.Linear(64, 64, device=device),
nn.ReLU(),
nn.Linear(64, 64, device=device),
nn.ReLU(),
nn.Linear(64, 64, device=device),
nn.ReLU(),
nn.Linear(64, 64, device=device),
nn.ReLU(),
)
if tdmodule:
t = TensorDictModule(t, in_keys=["x"], out_keys=["y"])

x = torch.randn(1, 1, 64, device=device)
t.eval()
params = TensorDict.from_module(t)
if not stack:
params = params.expand(2).to_tensordict().lock_()
else:
params = torch.stack([params, params.clone()], 0).lock_()

def fun(x, params):
with params.to_module(t):
return t(x)

vfun = vmap(fun, (None, 0))

if tdmodule:
data = TensorDict({"x": x}, [])
vfun(data, params)
benchmark(vfun, data, params)
else:
vfun(x, params)
benchmark(vfun, x, params)


@torch.no_grad()
@pytest.mark.skipif(
not torch.cuda.device_count(), reason="cuda device required for test"
@@ -208,6 +272,53 @@ def test_vmap_transformer_speed(benchmark, stack, tdmodule):
benchmark(fun, x, x, params)


@torch.no_grad()
@pytest.mark.skipif(
not torch.cuda.device_count(), reason="cuda device required for test"
)
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
def test_vmap_transformer_speed_decorator(benchmark, stack, tdmodule):
# tests speed of vmapping over a transformer
device = "cuda" if torch.cuda.device_count() else "cpu"
t = torch.nn.Transformer(
8,
dim_feedforward=8,
device=device,
batch_first=False,
)
if tdmodule:
t = TensorDictModule(t, in_keys=["x", "x"], out_keys=["y"])

x = torch.randn(2, 2, 8, device=device)
t.eval()
params = TensorDict.from_module(t)
if not stack:
params = params.expand(2).to_tensordict().lock_()
else:
params = torch.stack([params, params.clone()], 0).lock_()

if tdmodule:

def fun(x, params):
with params.to_module(t):
return t(x)

vfun = vmap(fun, (None, 0))
data = TensorDict({"x": x}, [])
vfun(data, params)
benchmark(vfun, data, params)
else:

def fun(x, params):
with params.to_module(t):
return t(x, x)

vfun = vmap(fun, (None, 0))
vfun(x, params)
benchmark(vfun, x, params)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
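
For orientation, the new `*_decorator` benchmarks above all exercise the same pattern: extract the parameters once with `TensorDict.from_module` and swap them into the module with `params.to_module(module)` used as a context manager, as an alternative to `torch.func.functional_call`. A minimal standalone sketch of that comparison (not part of the diff; the network and shapes are illustrative, and it assumes PyTorch 2.x plus a tensordict build that includes this PR):

```python
import torch
from torch import nn
from tensordict import TensorDict

net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
x = torch.randn(2, 2)

# Baseline: functional call with a plain state dict.
y_ref = torch.func.functional_call(net, net.state_dict(), x)

# TensorDict route: extract the params once, then swap them in for the call.
params = TensorDict.from_module(net)
with params.to_module(net):
    y_td = net(x)

torch.testing.assert_close(y_ref, y_td)
```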
tensordict/_lazy.py (2 additions, 0 deletions)
@@ -863,6 +863,8 @@ def hook_in(
batch_size=n,
vmap_level=vmap_level,
):
if _is_tensor_collection(type(tensor)):
return tensor._remove_batch_dim(vmap_level, batch_size, out_dim)
return _remove_batch_dim(tensor, vmap_level, batch_size, out_dim)

out.hook_out = hook_out
tensordict/_td.py (27 additions, 17 deletions)
@@ -259,20 +259,23 @@ def from_module(
return td_struct

@as_decorator()
def to_module(self, module, return_swap: bool = True, swap_dest=None):

def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None):
# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
__dict__ = module.__dict__
out = None
swap = None
has_set_device = False
if memo is None:
memo = {}
if return_swap:
# this could break if the device and batch-size are not congruent.
# For batch-size it is a minor issue (unlikely that a td with batch-size
# is passed with to_module) but for the device it could be a problem.
if swap_dest is None:
out = self.empty()
swap = self.empty()
swap.clear_device_()
else:
out = swap_dest
swap = swap_dest
memo[id(module)] = swap

for key, value in self.items():
if isinstance(value, (Tensor, ftdim.Tensor)):
@@ -295,24 +298,31 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None):
local_dest = swap_dest._get_str(key, default=NO_DEFAULT)
else:
local_dest = None
local_out = value.to_module(
__dict__["_modules"][key],
return_swap=return_swap,
swap_dest=local_dest,
)
if return_swap:
child = __dict__["_modules"][key]
if id(child) in memo:
local_out = memo[id(child)]
else:
local_out = value.to_module(
child,
return_swap=return_swap,
swap_dest=local_dest,
memo=memo,
)
# we don't want to do this op more than once
if (
if return_swap and (
not has_set_device
and out.device is not None
and swap.device is not None
and local_out.device is not None
and local_out.device != out.device
and local_out.device != swap.device
):
has_set_device = True
# map out to the local_out device
out = out.to(device=local_out.device)
out._set_str(key, local_out, inplace=False, validated=True)
return out
swap = swap.to(device=local_out.device)

if return_swap:
assert local_out is not None, key
swap._set_str(key, local_out, inplace=False, validated=True)
return swap

def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
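
The `memo` dict threaded through `to_module` above exists for modules that appear more than once in their parent: on the second visit the swap captured on the first visit is reused instead of reading back the values that were just written. A minimal sketch of the case it guards, mirroring the `_module_shared` test below (standalone, illustrative layer sizes, assuming a tensordict build that includes this PR):

```python
import torch
from torch import nn
from tensordict import TensorDict

shared = nn.Linear(4, 4)
model = nn.Sequential(shared, nn.ReLU(), nn.Sequential(shared))  # same layer used twice

params = TensorDict.from_module(model)
zeros = params.clone().apply(
    lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
    params,
)

with zeros.to_module(model):
    # Both occurrences of `shared` now hold the zeroed parameters.
    assert (TensorDict.from_module(model) == zeros).all()
# On exit the originals are restored exactly once per module, including the shared one.
assert (TensorDict.from_module(model) == params).all()
```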
tensordict/base.py (9 additions, 2 deletions)
@@ -328,7 +328,9 @@ def from_module(module, as_module: bool = False, lock: bool = True):
...

@abc.abstractmethod
def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None):
def to_module(
self, module: nn.Module, return_swap: bool = False, swap_dest=None, memo=None
):
"""Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively.

Args:
@@ -337,6 +339,10 @@ def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None
will be returned. Defaults to ``False``.
swap_dest (TensorDictBase, optional): if ``return_swap`` is ``True``,
the tensordict where the swap should be written.
memo (dict, optional): when the same module is present multiple times
in the input module, a memo is used to avoid fetching the params
that have just been set. This argument should be ignored during
regular calls to `to_module`.

Examples:
>>> from torch import nn
@@ -3119,7 +3125,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.lock_()
if last_op == self.__class__.to_module.__name__:
if is_tensor_collection(out):
return self.to_module(*args, **kwargs, swap_dest=out)
with out.unlock_():
return self.to_module(*args, **kwargs, swap_dest=out)
else:
raise RuntimeError(
"to_module cannot be used as a decorator when return_swap=False."
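
Beyond the docstring, the `__exit__` tweak above temporarily unlocks the destination tensordict while the original parameters are restored, which is what lets locked parameter tensordicts (as in the new benchmarks) be used in the context-manager form. The `return_swap` behaviour the docstring describes can also be used directly; a minimal sketch of that round-trip (standalone, illustrative modules, assuming a tensordict build that includes this PR):

```python
import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 3)
donor = nn.Linear(3, 3)  # same architecture, different initialization
new_params = TensorDict.from_module(donor)

# return_swap=True hands back the parameters that were just replaced ...
old_params = new_params.to_module(module, return_swap=True)
# ... so they can be written back later to round-trip the module.
old_params.to_module(module)
assert (TensorDict.from_module(module) == old_params).all()
```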
tensordict/utils.py (5 additions, 4 deletions)
@@ -1123,10 +1123,11 @@ def new_func(_self, *args, **kwargs):
out = func(_self, *args, **kwargs)
if self.attr is not None:
_attr_post = getattr(_self, self.attr)
if self.attr is None or (_attr_post is not _attr_pre):
out._last_op = (new_func.__name__, (args, kwargs, _self))
else:
out._last_op = None
if out is not None:
if self.attr is None or (_attr_post is not _attr_pre):
out._last_op = (new_func.__name__, (args, kwargs, _self))
else:
out._last_op = None
return out

return new_func
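
The `out is not None` guard above matters because a method wrapped with `as_decorator` may legitimately return nothing. A minimal sketch of such a call (standalone, assuming the post-PR behaviour that `to_module` returns `None` when `return_swap=False`):

```python
import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(2, 2)
params = TensorDict.from_module(module)

# No swap is requested, so the wrapped method returns None and the wrapper
# skips the _last_op bookkeeping instead of raising on a None result.
out = params.to_module(module, return_swap=False)
assert out is None
```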
test/test_nn.py (86 additions, 0 deletions)
@@ -3198,6 +3198,92 @@ def test_sd_module(self):
assert isinstance(val, nn.Parameter)


@pytest.mark.parametrize(
"module_name,input_name", [["_module_shared", "_x"], ["_transformer", "_tuple_x"]]
)
@pytest.mark.parametrize("as_module", [True, False])
class TestToModule:
@property
def _transformer(self):
# we use transformer because it's deep, has buffers etc.
return nn.Transformer(d_model=8, dim_feedforward=8).eval()

@property
def _module_shared(self):
# a module with the same layer appearing twice
l0 = nn.Linear(8, 9)
l1 = nn.Linear(9, 8)
return nn.Sequential(
l0,
l1,
nn.Sequential(
l0,
),
)

@property
def _tuple_x(self):
x = torch.randn(2, 2, 8)
return (x, x)

@property
def _x(self):
return (torch.randn(2, 2, 8),)

def test_static(self, module_name, input_name, as_module):
torch.manual_seed(0)
module = getattr(self, module_name)
x = getattr(self, input_name)
params = TensorDict.from_module(module, as_module=as_module)
params0 = params.clone().apply(
lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
params,
)
y = module(*x)
params0.to_module(module)
y0 = module(*x)
params.to_module(module)
y1 = module(*x)
torch.testing.assert_close(y, y1)
assert (y0 == 0).all()
assert (y0 != y1).all()

def test_cm(self, module_name, input_name, as_module):
torch.manual_seed(0)
module = getattr(self, module_name)
x = getattr(self, input_name)
params = TensorDict.from_module(module, as_module=as_module)
params0 = params.clone().apply(
lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
params,
)
y = module(*x)
with params0.to_module(module):
y0 = module(*x)
assert (params0 == TensorDict.from_module(module)).all()
y1 = module(*x)
torch.testing.assert_close(y, y1)
assert (y0 == 0).all()
assert (y0 != y1).all()
assert (TensorDict.from_module(module) == params).all()

def test_cm_meta(self, module_name, input_name, as_module):
torch.manual_seed(0)
module = getattr(self, module_name)
x = getattr(self, input_name)
params = TensorDict.from_module(module, as_module=as_module)
params_meta = params.detach().to("meta")
y = module(*x)
with params_meta.to_module(module):
module_meta = copy.deepcopy(module)
y1 = module(*x)
with params.to_module(module_meta):
y2 = module_meta(*x)
torch.testing.assert_close(y, y1)
torch.testing.assert_close(y, y2)
assert (TensorDict.from_module(module) == params).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)