diff --git a/benchmarks/nn/functional_benchmarks_test.py b/benchmarks/nn/functional_benchmarks_test.py
index 096b3901b..0308797eb 100644
--- a/benchmarks/nn/functional_benchmarks_test.py
+++ b/benchmarks/nn/functional_benchmarks_test.py
@@ -120,6 +120,16 @@ def test_instantiation_td(benchmark, net):
 
 # Execution
 def test_exec_functorch(benchmark, net):
+    x = torch.randn(2, 2)
+    sd = net.state_dict()
+
+    def fun(x, sd):
+        torch.func.functional_call(net, sd, x)
+
+    benchmark(fun, x, sd)
+
+
+def test_exec_functional_call(benchmark, net):
     x = torch.randn(2, 2)
     fmodule, params, buffers = functorch_make_functional(net)
     benchmark(fmodule, params, buffers, x)
@@ -132,6 +142,18 @@ def test_exec_td(benchmark, net):
     benchmark(fmodule, x, params=params)
 
 
+def test_exec_td_decorator(benchmark, net):
+    x = torch.randn(2, 2)
+    fmodule = net
+    params = TensorDict.from_module(fmodule)
+
+    def fun(x, params):
+        with params.to_module(net):
+            net(x)
+
+    benchmark(fun, x, params)
+
+
 @torch.no_grad()
 @pytest.mark.parametrize("stack", [True, False])
 @pytest.mark.parametrize("tdmodule", [True, False])
@@ -169,6 +191,48 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule):
     benchmark(fun, x, params)
 
 
+@torch.no_grad()
+@pytest.mark.parametrize("stack", [True, False])
+@pytest.mark.parametrize("tdmodule", [True, False])
+def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule):
+    # tests speed of vmapping over a transformer
+    device = "cuda" if torch.cuda.device_count() else "cpu"
+    t = nn.Sequential(
+        nn.Linear(64, 64, device=device),
+        nn.ReLU(),
+        nn.Linear(64, 64, device=device),
+        nn.ReLU(),
+        nn.Linear(64, 64, device=device),
+        nn.ReLU(),
+        nn.Linear(64, 64, device=device),
+        nn.ReLU(),
+    )
+    if tdmodule:
+        t = TensorDictModule(t, in_keys=["x"], out_keys=["y"])
+
+    x = torch.randn(1, 1, 64, device=device)
+    t.eval()
+    params = TensorDict.from_module(t)
+    if not stack:
+        params = params.expand(2).to_tensordict().lock_()
+    else:
+        params = torch.stack([params, params.clone()], 0).lock_()
+
+    def fun(x, params):
+        with params.to_module(t):
+            return t(x)
+
+    vfun = vmap(fun, (None, 0))
+
+    if tdmodule:
+        data = TensorDict({"x": x}, [])
+        vfun(data, params)
+        benchmark(vfun, data, params)
+    else:
+        vfun(x, params)
+        benchmark(vfun, x, params)
+
+
 @torch.no_grad()
 @pytest.mark.skipif(
     not torch.cuda.device_count(), reason="cuda device required for test"
@@ -208,6 +272,53 @@ def test_vmap_transformer_speed(benchmark, stack, tdmodule):
     benchmark(fun, x, x, params)
 
 
+@torch.no_grad()
+@pytest.mark.skipif(
+    not torch.cuda.device_count(), reason="cuda device required for test"
+)
+@pytest.mark.parametrize("stack", [True, False])
+@pytest.mark.parametrize("tdmodule", [True, False])
+def test_vmap_transformer_speed_decorator(benchmark, stack, tdmodule):
+    # tests speed of vmapping over a transformer
+    device = "cuda" if torch.cuda.device_count() else "cpu"
+    t = torch.nn.Transformer(
+        8,
+        dim_feedforward=8,
+        device=device,
+        batch_first=False,
+    )
+    if tdmodule:
+        t = TensorDictModule(t, in_keys=["x", "x"], out_keys=["y"])
+
+    x = torch.randn(2, 2, 8, device=device)
+    t.eval()
+    params = TensorDict.from_module(t)
+    if not stack:
+        params = params.expand(2).to_tensordict().lock_()
+    else:
+        params = torch.stack([params, params.clone()], 0).lock_()
+
+    if tdmodule:
+
+        def fun(x, params):
+            with params.to_module(t):
+                return t(x)
+
+        vfun = vmap(fun, (None, 0))
+        data = TensorDict({"x": x}, [])
+        vfun(data, params)
+        benchmark(vfun, data, params)
+    else:
+
+        def fun(x, params):
+            with params.to_module(t):
+                return t(x, x)
+
+        vfun = vmap(fun, (None, 0))
+        vfun(x, params)
+        benchmark(vfun, x, params)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 84b5bf1bb..488681dcd 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -863,6 +863,8 @@ def hook_in(
             batch_size=n,
             vmap_level=vmap_level,
         ):
+            if _is_tensor_collection(type(tensor)):
+                return tensor._remove_batch_dim(vmap_level, batch_size, out_dim)
             return _remove_batch_dim(tensor, vmap_level, batch_size, out_dim)
 
         out.hook_out = hook_out
diff --git a/tensordict/_td.py b/tensordict/_td.py
index fd84fa9a3..2e01b91cd 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -259,20 +259,23 @@ def from_module(
         return td_struct
 
     @as_decorator()
-    def to_module(self, module, return_swap: bool = True, swap_dest=None):
-
+    def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None):
         # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
         __dict__ = module.__dict__
-        out = None
+        swap = None
         has_set_device = False
+        if memo is None:
+            memo = {}
         if return_swap:
             # this could break if the device and batch-size are not congruent.
             # For batch-size it is a minor issue (unlikely that a td with batch-size
             # is passed with to_module) but for the device it could be a problem.
             if swap_dest is None:
-                out = self.empty()
+                swap = self.empty()
+                swap.clear_device_()
             else:
-                out = swap_dest
+                swap = swap_dest
+            memo[id(module)] = swap
 
         for key, value in self.items():
             if isinstance(value, (Tensor, ftdim.Tensor)):
@@ -295,24 +298,31 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None):
                     local_dest = swap_dest._get_str(key, default=NO_DEFAULT)
                 else:
                     local_dest = None
-                local_out = value.to_module(
-                    __dict__["_modules"][key],
-                    return_swap=return_swap,
-                    swap_dest=local_dest,
-                )
-            if return_swap:
+                child = __dict__["_modules"][key]
+                if id(child) in memo:
+                    local_out = memo[id(child)]
+                else:
+                    local_out = value.to_module(
+                        child,
+                        return_swap=return_swap,
+                        swap_dest=local_dest,
+                        memo=memo,
+                    )
                 # we don't want to do this op more than once
-                if (
+                if return_swap and (
                     not has_set_device
-                    and out.device is not None
+                    and swap.device is not None
                     and local_out.device is not None
-                    and local_out.device != out.device
+                    and local_out.device != swap.device
                 ):
                     has_set_device = True
                     # map out to the local_out device
-                    out = out.to(device=local_out.device)
-                out._set_str(key, local_out, inplace=False, validated=True)
-        return out
+                    swap = swap.to(device=local_out.device)
+
+            if return_swap:
+                assert local_out is not None, key
+                swap._set_str(key, local_out, inplace=False, validated=True)
+        return swap
 
     def __ne__(self, other: object) -> T | bool:
         if _is_tensorclass(other):
diff --git a/tensordict/base.py b/tensordict/base.py
index aec2f9fec..74e4980f3 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -328,7 +328,9 @@ def from_module(module, as_module: bool = False, lock: bool = True):
         ...
 
     @abc.abstractmethod
-    def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None):
+    def to_module(
+        self, module: nn.Module, return_swap: bool = False, swap_dest=None, memo=None
+    ):
         """Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively.
 
         Args:
@@ -337,6 +339,10 @@ def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None):
                 will be returned. Defaults to ``False``.
             swap_dest (TensorDictBase, optional): if ``return_swap`` is ``True``,
                 the tensordict where the swap should be written.
+            memo (dict, optional): when the same module is present multiple times
+                in the input module, a memo is used to avoid fetching the params
+                that have just been set. This argument should be ignored during
+                regular calls to `to_module`.
 
         Examples:
             >>> from torch import nn
@@ -3119,7 +3125,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             return self.lock_()
         if last_op == self.__class__.to_module.__name__:
             if is_tensor_collection(out):
-                return self.to_module(*args, **kwargs, swap_dest=out)
+                with out.unlock_():
+                    return self.to_module(*args, **kwargs, swap_dest=out)
             else:
                 raise RuntimeError(
                     "to_module cannot be used as a decorator when return_swap=False."
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 7ec0056b5..8c462d1dc 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1123,10 +1123,11 @@ def new_func(_self, *args, **kwargs):
             out = func(_self, *args, **kwargs)
             if self.attr is not None:
                 _attr_post = getattr(_self, self.attr)
-            if self.attr is None or (_attr_post is not _attr_pre):
-                out._last_op = (new_func.__name__, (args, kwargs, _self))
-            else:
-                out._last_op = None
+            if out is not None:
+                if self.attr is None or (_attr_post is not _attr_pre):
+                    out._last_op = (new_func.__name__, (args, kwargs, _self))
+                else:
+                    out._last_op = None
             return out
 
         return new_func
diff --git a/test/test_nn.py b/test/test_nn.py
index c45158340..13e420a67 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3198,6 +3198,92 @@ def test_sd_module(self):
             assert isinstance(val, nn.Parameter)
 
 
+@pytest.mark.parametrize(
+    "module_name,input_name", [["_module_shared", "_x"], ["_transformer", "_tuple_x"]]
+)
+@pytest.mark.parametrize("as_module", [True, False])
+class TestToModule:
+    @property
+    def _transformer(self):
+        # we use transformer because it's deep, has buffers etc.
+        return nn.Transformer(d_model=8, dim_feedforward=8).eval()
+
+    @property
+    def _module_shared(self):
+        # a module with the same layer appearing twice
+        l0 = nn.Linear(8, 9)
+        l1 = nn.Linear(9, 8)
+        return nn.Sequential(
+            l0,
+            l1,
+            nn.Sequential(
+                l0,
+            ),
+        )
+
+    @property
+    def _tuple_x(self):
+        x = torch.randn(2, 2, 8)
+        return (x, x)
+
+    @property
+    def _x(self):
+        return (torch.randn(2, 2, 8),)
+
+    def test_static(self, module_name, input_name, as_module):
+        torch.manual_seed(0)
+        module = getattr(self, module_name)
+        x = getattr(self, input_name)
+        params = TensorDict.from_module(module, as_module=as_module)
+        params0 = params.clone().apply(
+            lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
+            params,
+        )
+        y = module(*x)
+        params0.to_module(module)
+        y0 = module(*x)
+        params.to_module(module)
+        y1 = module(*x)
+        torch.testing.assert_close(y, y1)
+        assert (y0 == 0).all()
+        assert (y0 != y1).all()
+
+    def test_cm(self, module_name, input_name, as_module):
+        torch.manual_seed(0)
+        module = getattr(self, module_name)
+        x = getattr(self, input_name)
+        params = TensorDict.from_module(module, as_module=as_module)
+        params0 = params.clone().apply(
+            lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
+            params,
+        )
+        y = module(*x)
+        with params0.to_module(module):
+            y0 = module(*x)
+            assert (params0 == TensorDict.from_module(module)).all()
+        y1 = module(*x)
+        torch.testing.assert_close(y, y1)
+        assert (y0 == 0).all()
+        assert (y0 != y1).all()
+        assert (TensorDict.from_module(module) == params).all()
+
+    def test_cm_meta(self, module_name, input_name, as_module):
+        torch.manual_seed(0)
+        module = getattr(self, module_name)
+        x = getattr(self, input_name)
+        params = TensorDict.from_module(module, as_module=as_module)
+        params_meta = params.detach().to("meta")
+        y = module(*x)
+        with params_meta.to_module(module):
+            module_meta = copy.deepcopy(module)
+        y1 = module(*x)
+        with params.to_module(module_meta):
+            y2 = module_meta(*x)
+        torch.testing.assert_close(y, y1)
+        torch.testing.assert_close(y, y2)
+        assert (TensorDict.from_module(module) == params).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
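
Note (reviewer sketch, not part of the patch): the snippet below illustrates the to_module context-manager behaviour that TestToModule.test_cm exercises and that the new *_decorator benchmarks time. The toy module, shapes and variable names are illustrative only, and it assumes a tensordict build that includes this change.

import torch
from torch import nn
from tensordict import TensorDict

# Toy module standing in for the transformer / shared-layer fixtures used in the tests.
module = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)).eval()
x = torch.randn(2, 4)

# Capture the module's parameters and buffers as a TensorDict.
params = TensorDict.from_module(module)

# Build an all-zero parameter set, keeping nn.Parameter leaves as parameters
# (same pattern as TestToModule.test_static / test_cm above).
params0 = params.clone().apply(
    lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0,
    params,
)

y = module(x)

# Used as a context manager, to_module swaps the zeroed parameters in and puts
# the original ones back on exit (the __exit__ branch patched in base.py above).
with params0.to_module(module):
    y0 = module(x)

y1 = module(x)
assert (y0 == 0).all()
torch.testing.assert_close(y, y1)

The vmap benchmarks reuse the same pattern with a stacked parameter set (torch.stack([params, params.clone()], 0).lock_()) and vmap(fun, (None, 0)), which is the path served by the _remove_batch_dim dispatch added to tensordict/_lazy.py.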