Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 21, 2024
1 parent d5c452d commit fb35d6b
Showing 1 changed file with 25 additions and 31 deletions.
56 changes: 25 additions & 31 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from tensordict._C import unravel_key_list
from tensordict.nn import (
dispatch,
probabilistic as nn_probabilistic,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModuleBase,
Expand All @@ -31,7 +30,11 @@
)
from tensordict.nn.distributions.composite import CompositeDistribution
from tensordict.nn.ensemble import EnsembleModule
from tensordict.nn.probabilistic import InteractionType, set_interaction_type
from tensordict.nn.probabilistic import (
interaction_type,
InteractionType,
set_interaction_type,
)
from tensordict.nn.utils import (
_set_dispatch_td_nn_modules,
set_skip_existing,
Expand Down Expand Up @@ -299,10 +302,8 @@ class Data:

@pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
"interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys):
@pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None])
def test_stateful_probabilistic_deprec(self, lazy, it, out_keys):
torch.manual_seed(0)
param_multiplier = 2
if lazy:
Expand Down Expand Up @@ -332,10 +333,10 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys):
tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
with set_interaction_type(interaction_type):
with set_interaction_type(it):
with (
pytest.warns(UserWarning, match="deterministic_sample")
if interaction_type in (InteractionType.DETERMINISTIC, None)
if it in (InteractionType.DETERMINISTIC, None)
else contextlib.nullcontext()
):
tensordict_module(td)
Expand All @@ -345,12 +346,8 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys):
@pytest.mark.parametrize("out_keys", [["low"], ["low1"], [("stuff", "low1")]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("max_dist", [1.0, 2.0])
@pytest.mark.parametrize(
"interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic_kwargs(
self, lazy, interaction_type, out_keys, max_dist
):
@pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None])
def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
torch.manual_seed(0)
if lazy:
net = nn.LazyLinear(4)
Expand All @@ -376,10 +373,10 @@ def test_stateful_probabilistic_kwargs(
tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
with set_interaction_type(interaction_type):
with set_interaction_type(it):
with (
pytest.warns(UserWarning, match="deterministic_sample")
if interaction_type in (None, InteractionType.DETERMINISTIC)
if it in (None, InteractionType.DETERMINISTIC)
else contextlib.nullcontext()
):
tensordict_module(td)
Expand Down Expand Up @@ -409,10 +406,8 @@ def test_nontensor(self):
],
)
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
"interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic(self, lazy, interaction_type, out_keys):
@pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None])
def test_stateful_probabilistic(self, lazy, it, out_keys):
torch.manual_seed(0)
param_multiplier = 2
if lazy:
Expand Down Expand Up @@ -441,10 +436,10 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys):
)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
with set_interaction_type(interaction_type):
with set_interaction_type(it):
with (
pytest.warns(UserWarning, match="deterministic_sample")
if interaction_type in (None, InteractionType.DETERMINISTIC)
if it in (None, InteractionType.DETERMINISTIC)
else contextlib.nullcontext()
):
tensordict_module(td)
Expand Down Expand Up @@ -1043,18 +1038,16 @@ def test_subsequence_weight_update(self):
assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight)


@pytest.mark.parametrize(
"interaction_type", [InteractionType.RANDOM, InteractionType.MODE]
)
@pytest.mark.parametrize("it", [InteractionType.RANDOM, InteractionType.MODE])
class TestSIM:
def test_cm(self, interaction_type):
with set_interaction_type(interaction_type):
assert nn_probabilistic._INTERACTION_TYPE == interaction_type
def test_cm(self, it):
with set_interaction_type(it):
assert interaction_type() == it

def test_dec(self, interaction_type):
@set_interaction_type(interaction_type)
def test_dec(self, it):
@set_interaction_type(it)
def dummy():
assert nn_probabilistic._INTERACTION_TYPE == interaction_type
assert interaction_type() == it

dummy()

Expand Down Expand Up @@ -1950,6 +1943,7 @@ class MyModule(nn.Module):
params_m0 = params_m
params_m0 = params_m0.apply(lambda x: x.data * 0)
assert (params_m0 == 0).all()
assert not (params_m == 0).all()
with params_m0.to_module(m):
assert (params_m == 0).all()
assert not (params_m == 0).all()
Expand Down

0 comments on commit fb35d6b

Please sign in to comment.