diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 3b254731d..2f2502353 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -336,7 +336,7 @@ def __init__( raise ValueError( "default_interaction_mode is deprecated, use default_interaction_type instead." ) - self.default_interaction_type = default_interaction_type + self.default_interaction_type = InteractionType(default_interaction_type) if isinstance(distribution_class, str): distribution_class = distributions_maps.get(distribution_class.lower()) @@ -630,8 +630,20 @@ def forward( tensordict_out: TensorDictBase | None = None, **kwargs, ) -> TensorDictBase: - tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs) - return self.module[-1](tensordict_out, _requires_sample=self._requires_sample) + if (tensordict_out is None and self._select_before_return) or (tensordict_out is not None): + tensordict_exec = tensordict.copy() + else: + tensordict_exec = tensordict + tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs) + tensordict_exec = self.module[-1](tensordict_exec, _requires_sample=self._requires_sample) + if tensordict_out is not None: + result = tensordict_out + result.update(tensordict_exec, keys_to_update=self.out_keys) + else: + result = tensordict_exec + if self._select_before_return: + return tensordict.update(result.select(*self.out_keys)) + return result def _dynamo_friendly_to_dict(data): diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index d33577975..2dc5f8f52 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -470,13 +470,11 @@ def forward( tensordict_out: TensorDictBase | None = None, **kwargs: Any, ) -> TensorDictBase: - if tensordict_out is None and self._select_before_return: + if (tensordict_out is None and self._select_before_return) or (tensordict_out is not None): tensordict_exec = tensordict.copy() else: tensordict_exec = tensordict if not len(kwargs): - if tensordict_out is not None: - tensordict_exec = tensordict_exec.copy() for module in self.module: tensordict_exec = self._run_module(module, tensordict_exec, **kwargs) else: