Skip to content

Commit

Permalink
[Performance] Faster dispatch (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 4, 2023
1 parent f16c076 commit 08597bc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
"Cannot pass both batch_size and batch_dims to `from_dict`."
)

batch_size_set = [] if batch_size is None else batch_size
batch_size_set = torch.Size(()) if batch_size is None else batch_size
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
# we don't know if another tensor of smaller size is coming
Expand Down
6 changes: 2 additions & 4 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
from tensordict._td import is_tensor_collection, TensorDictBase
from tensordict._tensordict import unravel_key_list
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
from tensordict.functional import make_tensordict

from tensordict.nn.functional_modules import (
Expand Down Expand Up @@ -255,9 +255,7 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any:
if isinstance(dest, str):
dest = getattr(_self, dest)
for key in source:
expected_key = (
self.separator.join(key) if isinstance(key, tuple) else key
)
expected_key = self.separator.join(_unravel_key_to_tuple(key))
if len(args):
tensordict_values[key] = args[0]
args = args[1:]
Expand Down

0 comments on commit 08597bc

Please sign in to comment.