From b229c59aba44915f51c3ed974e835a8e9f567a89 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 17:35:21 +0000 Subject: [PATCH] [BugFix] Better return_log_prob=True for tensordict outputs ghstack-source-id: 977af3880f39cb341c1c715f1b8c9d59b7c580a0 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1155 --- tensordict/nn/probabilistic.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index a9b2cdf5a..621388f2c 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -510,9 +510,12 @@ def forward( kwargs = {"aggregate_probabilities": False} log_prob = dist.log_prob(out_tensors, **kwargs) if log_prob is not out_tensors: - # Composite dists return the tensordict_out directly when aggrgate_prob is False - out_tensors.set(self.log_prob_key, log_prob) - else: + if is_tensor_collection(log_prob): + out_tensors.update(log_prob) + else: + # Composite dists return the tensordict_out directly when aggrgate_prob is False + out_tensors.set(self.log_prob_key, log_prob) + elif dist.log_prob_key in out_tensors: out_tensors.rename_key_(dist.log_prob_key, self.log_prob_key) tensordict_out.update(out_tensors) else: