From bcbcb1e40694e8d55fd7cc884ad7e40076ef3aac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 26 Nov 2024 14:46:02 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/nn/distributions/composite.py | 5 --- tensordict/nn/probabilistic.py | 45 +++++++++++++++++++----- test/test_nn.py | 25 +++++++++++++ 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py index 4285624f3..e6b831269 100644 --- a/tensordict/nn/distributions/composite.py +++ b/tensordict/nn/distributions/composite.py @@ -225,11 +225,6 @@ def from_distributions( def aggregate_probabilities(self): aggregate_probabilities = self._aggregate_probabilities if aggregate_probabilities is None: - warnings.warn( - "The default value of `aggregate_probabilities` will change from `False` to `True` in v0.7. " - "Please pass this value explicitly to avoid this warning.", - FutureWarning, - ) aggregate_probabilities = self._aggregate_probabilities = False return aggregate_probabilities diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 61df65d2b..e36f556b4 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -11,6 +11,8 @@ from textwrap import indent from typing import Any, Dict, List, Optional +import torch + from tensordict._nestedkey import NestedKey from tensordict.nn import CompositeDistribution @@ -367,9 +369,12 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution: raise err return dist - def log_prob(self, tensordict): + def log_prob( + self, tensordict, *, dist: torch.distributions.Distribution | None = None + ): """Writes the log-probability of the distribution sample.""" - dist = self.get_dist(tensordict) + if dist is None: + dist = self.get_dist(tensordict) if isinstance(dist, CompositeDistribution): td = dist.log_prob(tensordict, aggregate_probabilities=False) return td.get(dist.log_prob_key) @@ -560,6 +565,8 @@ def 
__init__( self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1]) super().__init__(*modules, partial_tolerant=partial_tolerant) + _dist_sample = ProbabilisticTensorDictModule._dist_sample + @property def det_part(self): return self._det_part @@ -584,17 +591,37 @@ def get_dist( **kwargs, ) -> D.Distribution: """Get the distribution that results from passing the input tensordict through the sequence, and then using the resulting parameters.""" - tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs) - return self.build_dist_from_params(tensordict_out) + td_copy = tensordict.copy() + dists = {} + for i, tdm in enumerate(self.module): + if isinstance( + tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential) + ): + dist = tdm.get_dist(td_copy) + if i < len(self.module) - 1: + sample = tdm._dist_sample(dist, interaction_type=interaction_type()) + if isinstance(tdm, ProbabilisticTensorDictModule): + if isinstance(sample, torch.Tensor): + sample = [sample] + for val, key in zip(sample, tdm.out_keys): + td_copy.set(key, val) + else: + td_copy.update(sample) + dists[tdm.out_keys[0]] = dist + else: + td_copy = tdm(td_copy) + if len(dists) == 0: + raise RuntimeError(f"No distribution module found in {self}.") + elif len(dists) == 1: + return dist + return CompositeDistribution.from_distributions(td_copy, dists) def log_prob( self, tensordict, tensordict_out: TensorDictBase | None = None, **kwargs ): - tensordict_out = self.get_dist_params( - tensordict, - tensordict_out, - **kwargs, - ) + dist = self.get_dist(tensordict) + if isinstance(dist, CompositeDistribution): + return dist.log_prob(tensordict) - return self.module[-1].log_prob(tensordict_out) + return self.module[-1].log_prob(tensordict, dist=dist) def build_dist_from_params(self, tensordict: TensorDictBase) -> D.Distribution: diff --git a/test/test_nn.py b/test/test_nn.py index 13231e927..55f735cdf 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -963,6 +963,31 @@ def test_probtdseq(self, return_log_prob, td_out): == expected 
) + def test_probtdseq_multdist(self): + + tdm0 = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["loc"]) + tdm1 = ProbabilisticTensorDictModule( + in_keys=["loc"], + out_keys=["y"], + distribution_class=torch.distributions.Normal, + distribution_kwargs={"scale": 1}, + ) + tdm2 = TensorDictModule(torch.nn.Linear(4, 5), in_keys=["y"], out_keys=["loc2"]) + tdm3 = ProbabilisticTensorDictModule( + in_keys={"loc": "loc2"}, + out_keys=["z"], + distribution_class=torch.distributions.Normal, + distribution_kwargs={"scale": 1}, + ) + + tdm = ProbabilisticTensorDictSequential(tdm0, tdm1, tdm2, tdm3) + dist = tdm.get_dist(TensorDict(x=torch.randn(10, 3))) + s = dist.sample() + assert isinstance(dist.log_prob(s), TensorDict) + v = tdm(TensorDict(x=torch.randn(10, 3))) + assert set(v.keys()) == {"x", "loc", "y", "loc2", "z"} + assert isinstance(tdm.log_prob(v), TensorDict) + @pytest.mark.parametrize("lazy", [True, False]) def test_stateful_probabilistic(self, lazy): torch.manual_seed(0)