From fb35d6b46a243046ac207d3c8af01f1979214ca4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 21 Nov 2024 16:58:49 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_nn.py | 56 ++++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 274d08b29..a3b630956 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -16,7 +16,6 @@ from tensordict._C import unravel_key_list from tensordict.nn import ( dispatch, - probabilistic as nn_probabilistic, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, TensorDictModuleBase, @@ -31,7 +30,11 @@ ) from tensordict.nn.distributions.composite import CompositeDistribution from tensordict.nn.ensemble import EnsembleModule -from tensordict.nn.probabilistic import InteractionType, set_interaction_type +from tensordict.nn.probabilistic import ( + interaction_type, + InteractionType, + set_interaction_type, +) from tensordict.nn.utils import ( _set_dispatch_td_nn_modules, set_skip_existing, @@ -299,10 +302,8 @@ class Data: @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) @pytest.mark.parametrize("lazy", [True, False]) - @pytest.mark.parametrize( - "interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): + @pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None]) + def test_stateful_probabilistic_deprec(self, lazy, it, out_keys): torch.manual_seed(0) param_multiplier = 2 if lazy: @@ -332,10 +333,10 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): tensordict_module = ProbabilisticTensorDictSequential(net, prob_module) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_interaction_type(interaction_type): + with set_interaction_type(it): with ( pytest.warns(UserWarning, match="deterministic_sample") - if interaction_type in (InteractionType.DETERMINISTIC, None) + if it in (InteractionType.DETERMINISTIC, None) else contextlib.nullcontext() ): tensordict_module(td) @@ -345,12 +346,8 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): @pytest.mark.parametrize("out_keys", [["low"], ["low1"], [("stuff", "low1")]]) @pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize("max_dist", [1.0, 2.0]) - @pytest.mark.parametrize( - "interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic_kwargs( - self, lazy, interaction_type, out_keys, max_dist - ): + @pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None]) + def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist): torch.manual_seed(0) if lazy: net = nn.LazyLinear(4) @@ -376,10 +373,10 @@ def test_stateful_probabilistic_kwargs( tensordict_module = ProbabilisticTensorDictSequential(net, prob_module) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_interaction_type(interaction_type): + with set_interaction_type(it): with ( pytest.warns(UserWarning, match="deterministic_sample") - if interaction_type in (None, InteractionType.DETERMINISTIC) + if it in (None, InteractionType.DETERMINISTIC) else contextlib.nullcontext() ): tensordict_module(td) @@ -409,10 +406,8 @@ def test_nontensor(self): ], ) @pytest.mark.parametrize("lazy", [True, False]) - @pytest.mark.parametrize( - "interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic(self, lazy, interaction_type, out_keys): + @pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None]) + def test_stateful_probabilistic(self, lazy, it, out_keys): torch.manual_seed(0) param_multiplier = 2 if lazy: @@ -441,10 +436,10 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys): ) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_interaction_type(interaction_type): + with set_interaction_type(it): with ( pytest.warns(UserWarning, match="deterministic_sample") - if interaction_type in (None, InteractionType.DETERMINISTIC) + if it in (None, InteractionType.DETERMINISTIC) else contextlib.nullcontext() ): tensordict_module(td) @@ -1043,18 +1038,16 @@ def test_subsequence_weight_update(self): assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight) -@pytest.mark.parametrize( - "interaction_type", [InteractionType.RANDOM, InteractionType.MODE] -) +@pytest.mark.parametrize("it", [InteractionType.RANDOM, InteractionType.MODE]) class TestSIM: - def test_cm(self, interaction_type): - with set_interaction_type(interaction_type): - assert nn_probabilistic._INTERACTION_TYPE == interaction_type + def test_cm(self, it): + with set_interaction_type(it): + assert interaction_type() == it - def test_dec(self, interaction_type): - @set_interaction_type(interaction_type) + def test_dec(self, it): + @set_interaction_type(it) def dummy(): - assert nn_probabilistic._INTERACTION_TYPE == interaction_type + assert interaction_type() == it dummy() @@ -1950,6 +1943,7 @@ class MyModule(nn.Module): params_m0 = params_m params_m0 = params_m0.apply(lambda x: x.data * 0) assert (params_m0 == 0).all() + assert not (params_m == 0).all() with params_m0.to_module(m): assert (params_m == 0).all() assert not (params_m == 0).all()