Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 29, 2024
1 parent 2a6262e commit e851a5c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
13 changes: 12 additions & 1 deletion tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,18 @@ def forward(
else:
result = tensordict_exec
if self._select_before_return:
return tensordict.update(result, keys_to_update=self.out_keys)
# We must also update any value that has been updated during the course of execution
# from the input data.
if is_compiling():
keys = [ # noqa: C416
k
for k in {k for k in self.out_keys}.union( # noqa: C416
{k for k in tensordict.keys(True, True)} # noqa: C416
)
]
else:
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
return tensordict.update(result, keys_to_update=keys)
return result


Expand Down
14 changes: 13 additions & 1 deletion tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
)
FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."

try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling

__all__ = ["TensorDictSequential"]

Expand Down Expand Up @@ -491,7 +495,15 @@ def forward(
if self._select_before_return:
# We must also update any value that has been updated during the course of execution
# from the input data.
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
if is_compiling():
keys = [ # noqa: C416
k
for k in {k for k in self.out_keys}.union( # noqa: C416
{k for k in tensordict.keys(True, True)} # noqa: C416
)
]
else:
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
return tensordict.update(result, keys_to_update=keys)
return result

Expand Down

0 comments on commit e851a5c

Please sign in to comment.