diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index a9b2cdf5a..621388f2c 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -510,9 +510,12 @@ def forward( kwargs = {"aggregate_probabilities": False} log_prob = dist.log_prob(out_tensors, **kwargs) if log_prob is not out_tensors: - # Composite dists return the tensordict_out directly when aggrgate_prob is False - out_tensors.set(self.log_prob_key, log_prob) - else: + if is_tensor_collection(log_prob): + out_tensors.update(log_prob) + else: + # Composite dists return the tensordict_out directly when aggrgate_prob is False + out_tensors.set(self.log_prob_key, log_prob) + elif dist.log_prob_key in out_tensors: out_tensors.rename_key_(dist.log_prob_key, self.log_prob_key) tensordict_out.update(out_tensors) else: