Commit
[Feature] TensorDictModule in_keys allowed as Dict[str, tuple | list] to enable multi use of a sample feature (#1101)
bachdj-px authored Nov 29, 2024
1 parent b539beb commit e871b7d
Showing 2 changed files with 131 additions and 8 deletions.
51 changes: 47 additions & 4 deletions tensordict/nn/common.py
@@ -773,11 +773,20 @@ class TensorDictModule(TensorDictModuleBase):
order given by the in_keys iterable.
If ``in_keys`` is a dictionary, its keys must correspond to the key
to be read in the tensordict and its values must match the name of
the keyword argument in the function signature. If `out_to_in_map` is ``True``,
the mapping gets inverted so that the keys correspond to the keyword
arguments in the function signature.
out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a key avoids writing that tensor to the output (see the short aside after the first example below).
Keyword Args:
out_to_in_map (bool, optional): if ``True``, `in_keys` is read as if the keys are the argument keys of
the :meth:`~.forward` method and the values are the keys in the input :class:`~tensordict.TensorDict`. If
``False`` or ``None`` (default), the keys are considered to be the input keys and the values the method's argument keys.
.. warning::
The default value of `out_to_in_map` will change from ``False`` to ``True`` in the v0.9 release.
inplace (bool or string, optional): if ``True`` (default), the outputs of the module are written to the tensordict
provided to the :meth:`~.forward` method. If ``False``, a new :class:`~tensordict.TensorDict` with an empty
batch-size and no device is created. If ``"empty"``, :meth:`~tensordict.TensorDict.empty` will be used to
@@ -865,12 +874,24 @@ class TensorDictModule(TensorDictModuleBase):
Examples:
>>> module = TensorDictModule(lambda x, *, y: x+y,
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
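(Editorial aside, not part of the committed docstring: a minimal sketch of the "_" convention described under ``out_keys`` above; the key names below are made up for illustration.)
>>> mod = TensorDictModule(lambda x: (x + 1, x - 1), in_keys=['x'], out_keys=['plus', '_'])
>>> td = mod(TensorDict({'x': torch.zeros(())}, []))
>>> 'plus' in td.keys(), '_' in td.keys()
(True, False)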
If `out_to_in_map` is set to ``True``, then the `in_keys` mapping is reversed. This way,
one can use the same input key for different keyword arguments.
Examples:
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
... in_keys={'x': '1', 'y': '2', 'z': '2'}, out_keys=['t'], out_to_in_map=True
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)
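(Editorial aside, not part of the committed docstring: the commit title also allows tuple/list values in the ``in_keys`` dict. The sketch below assumes, per that title, that a nested key can serve as the tensordict pointer and be reused for several keyword arguments when ``out_to_in_map=True``; the key names are invented for illustration.)
>>> module = TensorDictModule(lambda x, *, y: x * y,
... in_keys={'x': ('data', 'obs'), 'y': ('data', 'obs')},
... out_keys=['sq'], out_to_in_map=True)
>>> td = module(TensorDict({'data': {'obs': torch.full((), 3.0)}}, []))
>>> td['sq']
tensor(9.)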
Functional calls to a tensordict module are easy:
Examples:
@@ -923,17 +944,39 @@ def __init__(
in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str],
out_keys: NestedKey | List[NestedKey],
*,
out_to_in_map: bool | None = None,
inplace: bool | str = True,
) -> None:
super().__init__()

if out_to_in_map is not None and not isinstance(in_keys, dict):
warnings.warn(
"out_to_in_map is not None but is only used when in_key is a dictionary."
)

if isinstance(in_keys, dict):
if out_to_in_map is None:
out_to_in_map = False
warnings.warn(
"Using a dictionary in_keys without specifying out_to_in_map is deprecated. "
"By default, out_to_in_map is `False` (`in_keys` keys as tensordict pointers, "
"values as kwarg name), but from version>=0.9, default will be `True` "
"(`in_keys` keys as func kwarg name, values as tensordict pointers). "
"Please use explicit out_to_in_map to indicate the ordering of the input keys. ",
DeprecationWarning,
stacklevel=2,
)

# write the kwargs and create a list instead
_in_keys = []
self._kwargs = []
for key, value in in_keys.items():
self._kwargs.append(value)
_in_keys.append(key)
if out_to_in_map: # arg: td_key
self._kwargs.append(key)
_in_keys.append(value)
else: # td_key: arg
self._kwargs.append(value)
_in_keys.append(key)
in_keys = _in_keys
else:
if isinstance(in_keys, (str, tuple)):
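(Editorial aside, not part of the diff: a minimal standalone sketch of the key-normalization branch in `__init__` above. The helper name `_split_in_keys` is hypothetical and only mirrors the logic; it is not the library's API.)

# Hypothetical helper mirroring the dict branch of __init__ above: split a
# dict-valued ``in_keys`` into tensordict keys and kwarg names depending on
# ``out_to_in_map``.
def _split_in_keys(in_keys: dict, out_to_in_map: bool):
    td_keys, kwarg_names = [], []
    for key, value in in_keys.items():
        if out_to_in_map:  # {kwarg_name: td_key}
            kwarg_names.append(key)
            td_keys.append(value)
        else:  # {td_key: kwarg_name}
            kwarg_names.append(value)
            td_keys.append(key)
    return td_keys, kwarg_names

# _split_in_keys({'x': '1', 'y': '2'}, out_to_in_map=True)   -> (['1', '2'], ['x', 'y'])
# _split_in_keys({'1': 'x', '2': 'y'}, out_to_in_map=False)  -> (['1', '2'], ['x', 'y'])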
88 changes: 84 additions & 4 deletions test/test_nn.py
@@ -157,7 +157,9 @@ def fn(a, b=None, *, c=None):
return a + 1

if kwargs:
module = TensorDictModule(
fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False
)
td = TensorDict(
{
"1": torch.ones(1),
@@ -171,6 +173,76 @@ def fn(a, b=None, *, c=None):
td = TensorDict({"1": torch.ones(1)}, [])
assert (module(td)["a"] == 2).all()

def test_unused_out_to_in_map(self):
def fn(x, y):
return x + y

with pytest.warns(
match="out_to_in_map is not None but is only used when in_key is a dictionary."
):
_ = TensorDictModule(fn, in_keys=["x"], out_keys=["a"], out_to_in_map=False)

def test_input_keys_dict_reversed(self):
in_keys = {"x": "1", "y": "2"}

def fn(x, y):
return x + y

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])
assert (module(td)["a"] == 4).all()

def test_input_keys_match_reversed(self):
in_keys = {"1": "x", "2": "y"}
reversed_in_keys = {v: k for k, v in in_keys.items()}

def fn(x, y):
return y - x

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=False
)
reversed_module = TensorDictModule(
fn, in_keys=reversed_in_keys, out_keys=["a"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])

assert module(td)["a"] == reversed_module(td)["a"] == torch.Tensor([2])

@pytest.mark.parametrize("out_to_in_map", [True, False])
def test_input_keys_wrong_mapping(self, out_to_in_map):
in_keys = {"1": "x", "2": "y"}
if not out_to_in_map:
in_keys = {v: k for k, v in in_keys.items()}

def fn(x, y):
return x + y

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=out_to_in_map
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])

with pytest.raises(TypeError, match="got an unexpected keyword argument '1'"):
module(td)

def test_input_keys_dict_deprecated_warning(self):
in_keys = {"1": "x", "2": "y"}

def fn(x, y):
return x + y

with pytest.warns(
DeprecationWarning,
match="Using a dictionary in_keys without specifying out_to_in_map is deprecated.",
):
_ = TensorDictModule(fn, in_keys=in_keys, out_keys=["a"])

def test_reset(self):
torch.manual_seed(0)
net = nn.ModuleList([nn.Sequential(nn.Linear(1, 1), nn.ReLU())])
@@ -478,7 +550,10 @@ def test_functional_functorch(self):

def test_vmap_kwargs(self):
module = TensorDictModule(
lambda x, *, y: x + y,
in_keys={"1": "x", "2": "y"},
out_keys=["z"],
out_to_in_map=False,
)
td = TensorDict(
{"1": torch.ones((10,)), "2": torch.ones((10,)) * 2}, batch_size=[10]
@@ -723,7 +798,9 @@ def fn(a, b=None, *, c=None):
return a + 1

if kwargs:
module1 = TensorDictModule(
fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False
)
td = TensorDict(
{
"input": torch.ones(1),
@@ -1176,7 +1253,10 @@ def mycallable():
module, in_keys=[("i", "i2")], out_keys=[(("o", "o2"), ("o3",))]
)
TensorDictModule(
module,
in_keys={"i": "i1", (("i2",),): "i3"},
out_keys=[("o", "o2")],
out_to_in_map=False,
)

# corner cases that should work

1 comment on commit e871b7d

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

| Benchmark suite | Current: e871b7d | Previous: b539beb | Ratio |
| --- | --- | --- | --- |
| benchmarks/tensorclass/test_torch_functions.py::test_ones_like | 114.35480220350493 iter/sec (stddev: 0.0027535001409373172) | 228.90161665832144 iter/sec (stddev: 0.00013188668840879867) | 2.00 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
