Skip to content

Commit

Permalink
[Test] fix inline TDParams kwargs for nontensordata
Browse files Browse the repository at this point in the history
ghstack-source-id: da8b7f40d05715170a3e9f0b47763efe356afe5e
Pull Request resolved: #1095
  • Loading branch information
vmoens committed Nov 20, 2024
1 parent a5656cb commit 978eb6c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensordict.nn.sequence import TensorDictSequential

from tensordict.nn.utils import _set_skip_existing_None
from tensordict.tensorclass import is_non_tensor
from tensordict.tensordict import TensorDictBase
from tensordict.utils import _zip_strict
from torch import distributions as D, Tensor
Expand Down Expand Up @@ -638,5 +639,9 @@ def _dynamo_friendly_to_dict(data):
return data
if isinstance(data, TensorDictBase):
# to_dict is recursive and we don't want that
return {data[key] for key in data.keys()}
items = dict(data.items())
for k, v in items.items():
if is_non_tensor(v):
items[k] = v.data
return items
return data
5 changes: 4 additions & 1 deletion test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from tensordict import (
assert_close,
NonTensorData,
PYTREE_REGISTERED_LAZY_TDS,
PYTREE_REGISTERED_TDS,
tensorclass,
Expand Down Expand Up @@ -665,7 +666,9 @@ def test_dispatch_tensor(self, mode):
torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))

def test_prob_module_with_kwargs(self, mode):
kwargs = TensorDictParams(TensorDict(scale=1.0), no_convert=True)
kwargs = TensorDictParams(
TensorDict(scale=1.0, validate_args=NonTensorData(False)), no_convert=True
)
dist_cls = torch.distributions.Normal
mod = Mod(torch.nn.Linear(3, 3), in_keys=["inp"], out_keys=["loc"])
prob_mod = Seq(
Expand Down

0 comments on commit 978eb6c

Please sign in to comment.