[Feature] Consolidate functional calls (#565)
vmoens authored Nov 22, 2023
1 parent 3689afa commit 2ea264b
Showing 6 changed files with 240 additions and 23 deletions.
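For orientation: the new benchmarks and tests below exercise tensordict's functional-call pattern, in which parameters are extracted once with TensorDict.from_module and temporarily swapped into the module with to_module used as a context manager, alongside the torch.func.functional_call and functorch make_functional baselines. A minimal sketch of that pattern, assuming only the public tensordict API shown in the diff (the two-layer net, variable names and shapes are placeholders, not part of the commit):

import torch
from torch import nn
from tensordict import TensorDict

net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
x = torch.randn(2, 2)

# extract parameters and buffers once
params = TensorDict.from_module(net)

def functional_forward(x, params):
    # swap the parameters in for the duration of the call; the originals are restored on exit
    with params.to_module(net):
        return net(x)

y = functional_forward(x, params)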
111 changes: 111 additions & 0 deletions benchmarks/nn/functional_benchmarks_test.py
@@ -120,6 +120,16 @@ def test_instantiation_td(benchmark, net):

# Execution
def test_exec_functorch(benchmark, net):
x = torch.randn(2, 2)
sd = net.state_dict()

def fun(x, sd):
torch.func.functional_call(net, sd, x)

benchmark(fun, x, sd)


def test_exec_functional_call(benchmark, net):
x = torch.randn(2, 2)
fmodule, params, buffers = functorch_make_functional(net)
benchmark(fmodule, params, buffers, x)
@@ -132,6 +142,18 @@ def test_exec_td(benchmark, net):
benchmark(fmodule, x, params=params)


def test_exec_td_decorator(benchmark, net):
x = torch.randn(2, 2)
fmodule = net
params = TensorDict.from_module(fmodule)

def fun(x, params):
with params.to_module(net):
net(x)

benchmark(fun, x, params)


@torch.no_grad()
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
@@ -169,6 +191,48 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule):
benchmark(fun, x, params)


@torch.no_grad()
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule):
# tests speed of vmapping over an MLP
device = "cuda" if torch.cuda.device_count() else "cpu"
t = nn.Sequential(
nn.Linear(64, 64, device=device),
nn.ReLU(),
nn.Linear(64, 64, device=device),
nn.ReLU(),
nn.Linear(64, 64, device=device),
nn.ReLU(),
nn.Linear(64, 64, device=device),
nn.ReLU(),
)
if tdmodule:
t = TensorDictModule(t, in_keys=["x"], out_keys=["y"])

x = torch.randn(1, 1, 64, device=device)
t.eval()
params = TensorDict.from_module(t)
if not stack:
params = params.expand(2).to_tensordict().lock_()
else:
params = torch.stack([params, params.clone()], 0).lock_()

def fun(x, params):
with params.to_module(t):
return t(x)

vfun = vmap(fun, (None, 0))

if tdmodule:
data = TensorDict({"x": x}, [])
vfun(data, params)
benchmark(vfun, data, params)
else:
vfun(x, params)
benchmark(vfun, x, params)


@torch.no_grad()
@pytest.mark.skipif(
not torch.cuda.device_count(), reason="cuda device required for test"
@@ -208,6 +272,53 @@ def test_vmap_transformer_speed(benchmark, stack, tdmodule):
benchmark(fun, x, x, params)


@torch.no_grad()
@pytest.mark.skipif(
not torch.cuda.device_count(), reason="cuda device required for test"
)
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
def test_vmap_transformer_speed_decorator(benchmark, stack, tdmodule):
# tests speed of vmapping over a transformer
device = "cuda" if torch.cuda.device_count() else "cpu"
t = torch.nn.Transformer(
8,
dim_feedforward=8,
device=device,
batch_first=False,
)
if tdmodule:
t = TensorDictModule(t, in_keys=["x", "x"], out_keys=["y"])

x = torch.randn(2, 2, 8, device=device)
t.eval()
params = TensorDict.from_module(t)
if not stack:
params = params.expand(2).to_tensordict().lock_()
else:
params = torch.stack([params, params.clone()], 0).lock_()

if tdmodule:

def fun(x, params):
with params.to_module(t):
return t(x)

vfun = vmap(fun, (None, 0))
data = TensorDict({"x": x}, [])
vfun(data, params)
benchmark(vfun, data, params)
else:

def fun(x, params):
with params.to_module(t):
return t(x, x)

vfun = vmap(fun, (None, 0))
vfun(x, params)
benchmark(vfun, x, params)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
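Stripped of the benchmark fixtures, the vmap benchmarks above amount to stacking two parameter TensorDicts along a new leading dimension and mapping the context-manager call over that dimension. A condensed sketch, assuming vmap is torch.vmap (the benchmark file's actual import sits outside the hunks shown) and using illustrative layer sizes:

import torch
from torch import nn, vmap
from tensordict import TensorDict

net = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(1, 1, 64)

params = TensorDict.from_module(net)
# two parameter sets stacked along dim 0, as in the benchmark
stacked = torch.stack([params, params.clone()], 0).lock_()

def fun(x, params):
    with params.to_module(net):
        return net(x)

# map over the parameter stack only; x is shared by both parameter sets
ys = vmap(fun, (None, 0))(x, stacked)  # ys gains a leading dim of size 2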
2 changes: 2 additions & 0 deletions tensordict/_lazy.py
@@ -863,6 +863,8 @@ def hook_in(
batch_size=n,
vmap_level=vmap_level,
):
if _is_tensor_collection(type(tensor)):
return tensor._remove_batch_dim(vmap_level, batch_size, out_dim)
return _remove_batch_dim(tensor, vmap_level, batch_size, out_dim)

out.hook_out = hook_out
44 changes: 27 additions & 17 deletions tensordict/_td.py
@@ -259,20 +259,23 @@ def from_module(
return td_struct

@as_decorator()
def to_module(self, module, return_swap: bool = True, swap_dest=None):

def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None):
# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
__dict__ = module.__dict__
out = None
swap = None
has_set_device = False
if memo is None:
memo = {}
if return_swap:
# this could break if the device and batch-size are not congruent.
# For batch-size it is a minor issue (unlikely that a td with batch-size
# is passed with to_module) but for the device it could be a problem.
if swap_dest is None:
out = self.empty()
swap = self.empty()
swap.clear_device_()
else:
out = swap_dest
swap = swap_dest
memo[id(module)] = swap

for key, value in self.items():
if isinstance(value, (Tensor, ftdim.Tensor)):
@@ -295,24 +298,31 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None):
local_dest = swap_dest._get_str(key, default=NO_DEFAULT)
else:
local_dest = None
local_out = value.to_module(
__dict__["_modules"][key],
return_swap=return_swap,
swap_dest=local_dest,
)
if return_swap:
child = __dict__["_modules"][key]
if id(child) in memo:
local_out = memo[id(child)]
else:
local_out = value.to_module(
child,
return_swap=return_swap,
swap_dest=local_dest,
memo=memo,
)
# we don't want to do this op more than once
if (
if return_swap and (
not has_set_device
and out.device is not None
and swap.device is not None
and local_out.device is not None
and local_out.device != out.device
and local_out.device != swap.device
):
has_set_device = True
# map out to the local_out device
out = out.to(device=local_out.device)
out._set_str(key, local_out, inplace=False, validated=True)
return out
swap = swap.to(device=local_out.device)

if return_swap:
assert local_out is not None, key
swap._set_str(key, local_out, inplace=False, validated=True)
return swap

def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
11 changes: 9 additions & 2 deletions tensordict/base.py
@@ -328,7 +328,9 @@ def from_module(module, as_module: bool = False, lock: bool = True):
...

@abc.abstractmethod
def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None):
def to_module(
self, module: nn.Module, return_swap: bool = False, swap_dest=None, memo=None
):
"""Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively.
Args:
@@ -337,6 +339,10 @@ def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None
will be returned. Defaults to ``False``.
swap_dest (TensorDictBase, optional): if ``return_swap`` is ``True``,
the tensordict where the swap should be written.
memo (dict, optional): when the same module is present multiple times
in the input module, a memo is used to avoid fetching the params
that have just been set. This argument should be ignored during
regular calls to `to_module`.
Examples:
>>> from torch import nn
@@ -3119,7 +3125,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.lock_()
if last_op == self.__class__.to_module.__name__:
if is_tensor_collection(out):
return self.to_module(*args, **kwargs, swap_dest=out)
with out.unlock_():
return self.to_module(*args, **kwargs, swap_dest=out)
else:
raise RuntimeError(
"to_module cannot be used as a decorator when return_swap=False."
9 changes: 5 additions & 4 deletions tensordict/utils.py
@@ -1123,10 +1123,11 @@ def new_func(_self, *args, **kwargs):
out = func(_self, *args, **kwargs)
if self.attr is not None:
_attr_post = getattr(_self, self.attr)
if self.attr is None or (_attr_post is not _attr_pre):
out._last_op = (new_func.__name__, (args, kwargs, _self))
else:
out._last_op = None
if out is not None:
if self.attr is None or (_attr_post is not _attr_pre):
out._last_op = (new_func.__name__, (args, kwargs, _self))
else:
out._last_op = None
return out

return new_func
86 changes: 86 additions & 0 deletions test/test_nn.py
@@ -3198,6 +3198,92 @@ def test_sd_module(self):
assert isinstance(val, nn.Parameter)


@pytest.mark.parametrize(
"module_name,input_name", [["_module_shared", "_x"], ["_transformer", "_tuple_x"]]
)
@pytest.mark.parametrize("as_module", [True, False])
class TestToModule:
@property
def _transformer(self):
# we use transformer because it's deep, has buffers etc.
return nn.Transformer(d_model=8, dim_feedforward=8).eval()

@property
def _module_shared(self):
# a module with the same layer appearing twice
l0 = nn.Linear(8, 9)
l1 = nn.Linear(9, 8)
return nn.Sequential(
l0,
l1,
nn.Sequential(
l0,
),
)

@property
def _tuple_x(self):
x = torch.randn(2, 2, 8)
return (x, x)

@property
def _x(self):
return (torch.randn(2, 2, 8),)

def test_static(self, module_name, input_name, as_module):
torch.manual_seed(0)
module = getattr(self, module_name)
x = getattr(self, input_name)
params = TensorDict.from_module(module, as_module=as_module)
params0 = params.clone().apply(
lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
params,
)
y = module(*x)
params0.to_module(module)
y0 = module(*x)
params.to_module(module)
y1 = module(*x)
torch.testing.assert_close(y, y1)
assert (y0 == 0).all()
assert (y0 != y1).all()

def test_cm(self, module_name, input_name, as_module):
torch.manual_seed(0)
module = getattr(self, module_name)
x = getattr(self, input_name)
params = TensorDict.from_module(module, as_module=as_module)
params0 = params.clone().apply(
lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
params,
)
y = module(*x)
with params0.to_module(module):
y0 = module(*x)
assert (params0 == TensorDict.from_module(module)).all()
y1 = module(*x)
torch.testing.assert_close(y, y1)
assert (y0 == 0).all()
assert (y0 != y1).all()
assert (TensorDict.from_module(module) == params).all()

def test_cm_meta(self, module_name, input_name, as_module):
torch.manual_seed(0)
module = getattr(self, module_name)
x = getattr(self, input_name)
params = TensorDict.from_module(module, as_module=as_module)
params_meta = params.detach().to("meta")
y = module(*x)
with params_meta.to_module(module):
module_meta = copy.deepcopy(module)
y1 = module(*x)
with params.to_module(module_meta):
y2 = module_meta(*x)
torch.testing.assert_close(y, y1)
torch.testing.assert_close(y, y2)
assert (TensorDict.from_module(module) == params).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
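A closing note on test_cm_meta above: swapping a meta-device copy of the parameters into the module makes copy.deepcopy essentially storage-free, and the resulting skeleton can then be populated with the real parameters. A condensed sketch of what that test exercises (the transformer and input shape follow the test; the variable names are illustrative):

import copy

import torch
from torch import nn
from tensordict import TensorDict

module = nn.Transformer(d_model=8, dim_feedforward=8).eval()
params = TensorDict.from_module(module)

with params.detach().to("meta").to_module(module):
    # inside the block the module holds only meta tensors, so the deepcopy allocates no real storage
    module_meta = copy.deepcopy(module)

x = torch.randn(2, 2, 8)
with params.to_module(module_meta):
    y = module_meta(x, x)  # the skeleton copy runs with the real parameters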
