diff --git a/test/test_nn.py b/test/test_nn.py index 6d6e0550b..b894b3756 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3282,10 +3282,7 @@ def test_static(self, module_name, input_name, as_module, inplace): params = TensorDict.from_module(module, as_module=as_module) if inplace: params = params.clone() - params0 = params.clone().apply( - lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0, - params, - ) + params0 = params.clone().zero_() y = module(*x) params0.to_module(module, inplace=inplace) y0 = module(*x) @@ -3324,24 +3321,23 @@ def test_cm(self, module_name, input_name, as_module, inplace): module = getattr(self, module_name) x = getattr(self, input_name) params = TensorDict.from_module(module, as_module=as_module) - params0 = params.clone().apply( - lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0, - params, - ) + params0 = params.clone().zero_() y = module(*x) with params0.to_module(module, inplace=inplace): y0 = module(*x) - assert (params0 == TensorDict.from_module(module)).all() - - # check identities - for p1, p2 in zip( - TensorDict.from_module(module).values(True, True), - params0.values(True, True), - ): - if inplace: - assert p1 is not p2 - else: - assert p1 is p2 + if as_module: + # if as_module=False, params0 is not made of parameters anymore + assert (params0 == TensorDict.from_module(module)).all() + + # check identities + for p1, p2 in zip( + TensorDict.from_module(module).values(True, True), + params0.values(True, True), + ): + if inplace: + assert p1 is not p2 + else: + assert p1 is p2 y1 = module(*x) torch.testing.assert_close(y, y1)