feat(nn.common.TensorDictModule): support tuple values in in_keys for flexible input key dispatching
bachdj-px committed Nov 26, 2024
1 parent c842730 commit 42a61e0
Showing 2 changed files with 128 additions and 12 deletions.
52 changes: 44 additions & 8 deletions tensordict/nn/common.py
@@ -773,11 +773,17 @@ class TensorDictModule(TensorDictModuleBase):
order given by the in_keys iterable.
If ``in_keys`` is a dictionary, its keys must correspond to the key
to be read in the tensordict and its values must match the name of
the keyword argument in the function signature.
the keyword argument in the function signature. If ``out_to_in_map`` is
``True``, the mapping is inverted so that the keys correspond to the
keyword arguments in the function signature.
out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a key avoids writing that tensor to the output.
Keyword Args:
out_to_in_map (bool or None, optional): if ``True``, `in_keys` is read as if the keys are the argument names of
the :meth:`~.forward` method and the values are the keys in the input :class:`~tensordict.TensorDict`. If
`False`, keys are considered to be the input keys and values the method's argument names. If `None` (default), the
behaviour is the same as for `False` but a deprecation warning is raised.
inplace (bool or string, optional): if ``True`` (default), the outputs of the module are written in the tensordict
provided to the :meth:`~.forward` method. If ``False``, a new :class:`~tensordict.TensorDict` with an empty
batch-size and no device is created. If ``"empty"``, :meth:`~tensordict.TensorDict.empty` will be used to
@@ -806,7 +812,7 @@ class TensorDictModule(TensorDictModuleBase):
Examples:
>>> from tensordict import TensorDict
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"], out_to_in_map=False)
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
@@ -827,7 +833,7 @@ class TensorDictModule(TensorDictModuleBase):
>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")], out_to_in_map=False)
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
fields={
@@ -851,7 +857,7 @@ class TensorDictModule(TensorDictModuleBase):
We can use TensorDictModule to populate a tensordict:
Examples:
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"], out_to_in_map=False)
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
fields={
@@ -865,12 +871,24 @@ class TensorDictModule(TensorDictModuleBase):
Examples:
>>> module = TensorDictModule(lambda x, *, y: x+y,
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
If `out_to_in_map` is set to `True`, then the `in_keys` mapping is reversed. This way,
one can use the same input key for different keyword arguments.
Examples:
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
... in_keys={'x': '1', 'y': '2', 'z': '2'}, out_keys=['t'], out_to_in_map=True
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)
Functional calls to a tensordict module are easy:
Examples:
@@ -914,7 +932,7 @@ class TensorDictModule(TensorDictModuleBase):
"""

_IN_KEY_ERR = "in_keys must be of type list, str or tuples of str, or dict."
_IN_KEY_ERR = "in_keys must be of type list, str or tuples of str, or dict, or dict of str and list."
_OUT_KEY_ERR = "out_keys must be of type list, str or tuples of str."

def __init__(
@@ -923,17 +941,35 @@ def __init__(
in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str],
out_keys: NestedKey | List[NestedKey],
*,
out_to_in_map: bool | None = None,
inplace: bool | str = True,
) -> None:
super().__init__()

if out_to_in_map is not None and not isinstance(in_keys, dict):
warnings.warn(
"out_to_in_map is not None but is only used when in_key` is a dictionary."
)

if isinstance(in_keys, dict):
if out_to_in_map is None:
warnings.warn(
"Using a dictionary in_keys without specifying out_to_in_map is deprecated."
"Use out_to_in_map to indicate the ordering of the input keys.",
DeprecationWarning,
stacklevel=2,
)

# write the kwargs and create a list instead
_in_keys = []
self._kwargs = []
for key, value in in_keys.items():
self._kwargs.append(value)
_in_keys.append(key)
if out_to_in_map: # arg: td_key
self._kwargs.append(key)
_in_keys.append(value)
else: # td_key: arg
self._kwargs.append(value)
_in_keys.append(key)
in_keys = _in_keys
else:
if isinstance(in_keys, (str, tuple)):
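The branch above is what decides how a dictionary in_keys is split into tensordict keys and keyword-argument names. A minimal usage sketch of the two orientations follows, assuming a tensordict build that includes this commit; the function and key names are illustrative, adapted from the docstring examples.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

def fn(x, *, y):
    return x + y

# out_to_in_map=False: dictionary keys are tensordict entries, values are
# the function's argument names (the legacy orientation).
mod_td_to_arg = TensorDictModule(
    fn, in_keys={"1": "x", "2": "y"}, out_keys=["z"], out_to_in_map=False
)

# out_to_in_map=True: dictionary keys are argument names, values are
# tensordict entries, so a single entry can feed several arguments.
mod_arg_to_td = TensorDictModule(
    fn, in_keys={"x": "1", "y": "2"}, out_keys=["z"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(()), "2": torch.ones(()) * 2}, [])
assert mod_td_to_arg(td)["z"] == 3
assert mod_arg_to_td(td)["z"] == 3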
88 changes: 84 additions & 4 deletions test/test_nn.py
@@ -150,7 +150,9 @@ def fn(a, b=None, *, c=None):
return a + 1

if kwargs:
module = TensorDictModule(fn, in_keys=kwargs, out_keys=["a"])
module = TensorDictModule(
fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False
)
td = TensorDict(
{
"1": torch.ones(1),
@@ -164,6 +166,76 @@ def fn(a, b=None, *, c=None):
td = TensorDict({"1": torch.ones(1)}, [])
assert (module(td)["a"] == 2).all()

def test_unused_out_to_in_map(self):
def fn(x, y):
return x + y

with pytest.warns(
match="out_to_in_map is not None but is only used when in_key` is a dictionary."
):
_ = TensorDictModule(fn, in_keys=["x"], out_keys=["a"], out_to_in_map=False)

def test_input_keys_dict_reversed(self):
in_keys = {"x": "1", "y": "2"}

def fn(x, y):
return x + y

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])
assert (module(td)["a"] == 4).all()

def test_input_keys_match_reversed(self):
in_keys = {"1": "x", "2": "y"}
reversed_in_keys = {v: k for k, v in in_keys.items()}

def fn(x, y):
return y - x

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=False
)
reversed_module = TensorDictModule(
fn, in_keys=reversed_in_keys, out_keys=["a"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])

assert module(td)["a"] == reversed_module(td)["a"] == torch.Tensor([2])

@pytest.mark.parametrize("out_to_in_map", [True, False])
def test_input_keys_wrong_mapping(self, out_to_in_map):
in_keys = {"1": "x", "2": "y"}
if not out_to_in_map:
in_keys = {v: k for k, v in in_keys.items()}

def fn(x, y):
return x + y

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=out_to_in_map
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])

with pytest.raises(TypeError, match="got an unexpected keyword argument '1'"):
module(td)

def test_input_keys_dict_deprecated_warning(self):
in_keys = {"1": "x", "2": "y"}

def fn(x, y):
return x + y

with pytest.warns(
DeprecationWarning,
match="Using a dictionary in_keys without specifying out_to_in_map is deprecated.",
):
_ = TensorDictModule(fn, in_keys=in_keys, out_keys=["a"])

def test_reset(self):
torch.manual_seed(0)
net = nn.ModuleList([nn.Sequential(nn.Linear(1, 1), nn.ReLU())])
@@ -471,7 +543,10 @@ def test_functional_functorch(self):

def test_vmap_kwargs(self):
module = TensorDictModule(
lambda x, *, y: x + y, in_keys={"1": "x", "2": "y"}, out_keys=["z"]
lambda x, *, y: x + y,
in_keys={"1": "x", "2": "y"},
out_keys=["z"],
out_to_in_map=False,
)
td = TensorDict(
{"1": torch.ones((10,)), "2": torch.ones((10,)) * 2}, batch_size=[10]
@@ -716,7 +791,9 @@ def fn(a, b=None, *, c=None):
return a + 1

if kwargs:
module1 = TensorDictModule(fn, in_keys=kwargs, out_keys=["a"])
module1 = TensorDictModule(
fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False
)
td = TensorDict(
{
"input": torch.ones(1),
@@ -1169,7 +1246,10 @@ def mycallable():
module, in_keys=[("i", "i2")], out_keys=[(("o", "o2"), ("o3",))]
)
TensorDictModule(
module, in_keys={"i": "i1", (("i2",),): "i3"}, out_keys=[("o", "o2")]
module,
in_keys={"i": "i1", (("i2",),): "i3"},
out_keys=[("o", "o2")],
out_to_in_map=False,
)

# corner cases that should work
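The tests above also pin down the deprecation path. A short sketch of that behaviour, under the same assumptions as the sketch after the first file: a dictionary in_keys without out_to_in_map keeps the legacy td_key-to-argument orientation but now emits a DeprecationWarning.

import warnings

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # dict in_keys with no out_to_in_map: legacy td_key -> argument mapping, plus a warning
    legacy = TensorDictModule(
        lambda x, y: x + y, in_keys={"1": "x", "2": "y"}, out_keys=["a"]
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])
assert (legacy(td)["a"] == 4).all()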
