Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 22, 2024
1 parent 3bafc0a commit b0e9b81
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
18 changes: 15 additions & 3 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def __init__(
raise ValueError(
"default_interaction_mode is deprecated, use default_interaction_type instead."
)
self.default_interaction_type = default_interaction_type
self.default_interaction_type = InteractionType(default_interaction_type)

if isinstance(distribution_class, str):
distribution_class = distributions_maps.get(distribution_class.lower())
Expand Down Expand Up @@ -630,8 +630,20 @@ def forward(
tensordict_out: TensorDictBase | None = None,
**kwargs,
) -> TensorDictBase:
tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs)
return self.module[-1](tensordict_out, _requires_sample=self._requires_sample)
if (tensordict_out is None and self._select_before_return) or (tensordict_out is not None):
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
tensordict_exec = self.get_dist_params(tensordict_exec, tensordict_out, **kwargs)
tensordict_exec = self.module[-1](tensordict_exec, _requires_sample=self._requires_sample)
if tensordict_out is not None:
result = tensordict_out
result.update(tensordict_exec, keys_to_update=self.out_keys)
else:
result = tensordict_exec
if self._select_before_return:
return tensordict.update(result.select(*self.out_keys))
return result


def _dynamo_friendly_to_dict(data):
Expand Down
4 changes: 1 addition & 3 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,13 +470,11 @@ def forward(
tensordict_out: TensorDictBase | None = None,
**kwargs: Any,
) -> TensorDictBase:
if tensordict_out is None and self._select_before_return:
if (tensordict_out is None and self._select_before_return) or (tensordict_out is not None):
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
if not len(kwargs):
if tensordict_out is not None:
tensordict_exec = tensordict_exec.copy()
for module in self.module:
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
else:
Expand Down

0 comments on commit b0e9b81

Please sign in to comment.