diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index 24c9fe246..28f0ce2de 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -948,6 +948,21 @@ def get_dist(
             include_sum=self.include_sum,
         )
 
+    @property
+    def default_interaction_type(self):
+        """Returns the `default_interaction_type` of the module using an iterative heuristic.
+
+        This property iterates over all modules in reverse order, attempting to retrieve the
+        `default_interaction_type` attribute from any child module. The first non-None value
+        encountered is returned. If no such value is found, a default `interaction_type()` is returned.
+
+        """
+        for m in reversed(self.module):
+            interaction = getattr(m, "default_interaction_type", None)
+            if interaction is not None:
+                return interaction
+        return interaction_type()
+
     def log_prob(
         self,
         tensordict,
diff --git a/test/test_nn.py b/test/test_nn.py
index 0d6085ef9..ff22382a9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -410,6 +410,7 @@ def test_stateful_probabilistic_deprec(self, lazy, it, out_keys):
         )
 
         tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)
+        assert tensordict_module.default_interaction_type is not None
 
         td = TensorDict({"in": torch.randn(3, 3)}, [3])
         with set_interaction_type(it):
@@ -450,6 +451,7 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
         )
 
         tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)
+        assert tensordict_module.default_interaction_type is not None
 
         td = TensorDict({"in": torch.randn(3, 3)}, [3])
         with set_interaction_type(it):
@@ -513,6 +515,7 @@ def test_stateful_probabilistic(self, lazy, it, out_keys):
         tensordict_module = ProbabilisticTensorDictSequential(
             net, normal_params, prob_module
         )
+        assert tensordict_module.default_interaction_type is not None
 
         td = TensorDict({"in": torch.randn(3, 3)}, [3])
         with set_interaction_type(it):
@@ -962,6 +965,7 @@ def test_stateful_probabilistic_deprec(self, lazy):
         tdmodule = ProbabilisticTensorDictSequential(
             tdmodule1, dummy_tdmodule, tdmodule2, prob_module
         )
+        assert tdmodule.default_interaction_type is not None
 
         assert hasattr(tdmodule, "__setitem__")
         assert len(tdmodule) == 4
@@ -1002,6 +1006,7 @@ def test_probtdseq(self, return_log_prob, td_out):
                 default_interaction_type="random",
             ),
         )
+        assert mod.default_interaction_type is not None
         inp = TensorDict({"a": 0.0, "b": 1.0})
         inp_clone = inp.clone()
         if td_out:
@@ -1073,6 +1078,7 @@ def test_probtdseq_multdist(self, include_sum, aggregate_probabilities, inplace)
             inplace=inplace,
             return_composite=True,
         )
+        assert tdm.default_interaction_type is not None
         dist: CompositeDistribution = tdm.get_dist(TensorDict(x=torch.randn(10, 3)))
         s = dist.sample()
         assert dist.aggregate_probabilities is aggregate_probabilities
@@ -1129,6 +1135,7 @@ def test_probtdseq_intermediate_dist(
             inplace=inplace,
             return_composite=True,
         )
+        assert tdm.default_interaction_type is not None
         dist: CompositeDistribution = tdm.get_dist(TensorDict(x=torch.randn(10, 3)))
         assert isinstance(dist, CompositeDistribution)
 