diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index faa2f60a0..e668f37fe 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -5,9 +5,10 @@ from __future__ import annotations +import collections import logging from copy import deepcopy -from typing import Any, Iterable, List +from typing import Any, Callable, Iterable, List, OrderedDict, overload from tensordict._nestedkey import NestedKey @@ -170,19 +171,57 @@ class TensorDictSequential(TensorDictModule): module: nn.ModuleList _select_before_return = False + @overload def __init__( self, - *modules: TensorDictModuleBase, + modules: OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]], + *, + partial_tolerant: bool = False, + selected_out_keys: List[NestedKey] | None = None, + ) -> None: ... + + @overload + def __init__( + self, + modules: List[Callable[[TensorDictBase], TensorDictBase]], + *, + partial_tolerant: bool = False, + selected_out_keys: List[NestedKey] | None = None, + ) -> None: ... + + def __init__( + self, + *modules: Callable[[TensorDictBase], TensorDictBase], partial_tolerant: bool = False, selected_out_keys: List[NestedKey] | None = None, ) -> None: - modules = self._convert_modules(modules) - in_keys, out_keys = self._compute_in_and_out_keys(modules) - self._complete_out_keys = list(out_keys) - super().__init__( - module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys - ) + if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict): + modules_vals = self._convert_modules(modules[0].values()) + in_keys, out_keys = self._compute_in_and_out_keys(modules_vals) + self._complete_out_keys = list(out_keys) + modules = collections.OrderedDict( + **{key: val for key, val in zip(modules[0], modules_vals)} + ) + super().__init__( + module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys + ) + elif len(modules) == 1 and isinstance( + modules[0], collections.abc.MutableSequence + ): + modules = self._convert_modules(modules[0]) + in_keys, out_keys = self._compute_in_and_out_keys(modules) + self._complete_out_keys = list(out_keys) + super().__init__( + module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys + ) + else: + modules = self._convert_modules(modules) + in_keys, out_keys = self._compute_in_and_out_keys(modules) + self._complete_out_keys = list(out_keys) + super().__init__( + module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys + ) self.partial_tolerant = partial_tolerant if selected_out_keys: @@ -408,7 +447,7 @@ def select_subsequence( out_keys = deepcopy(self.out_keys) out_keys = unravel_key_list(out_keys) - module_list = list(self.module) + module_list = list(self._module_iter()) id_to_keep = set(range(len(module_list))) for i, module in enumerate(module_list): if ( @@ -445,8 +484,12 @@ def select_subsequence( raise ValueError( "No modules left after selection. Make sure that in_keys and out_keys are coherent." ) - - return type(self)(*modules) + if isinstance(self.module, nn.ModuleList): + return type(self)(*modules) + else: + keys = [key for key in self.module if self.module[key] in modules] + modules_dict = OrderedDict(**{key: val for key, val in zip(keys, modules)}) + return type(self)(modules_dict) def _run_module( self, @@ -466,6 +509,12 @@ def _run_module( module(sub_td, **kwargs) return tensordict + def _module_iter(self): + if isinstance(self.module, nn.ModuleDict): + yield from self.module.children() + else: + yield from self.module + @dispatch(auto_batch_size=False) @_set_skip_existing_None() def forward( @@ -481,7 +530,7 @@ def forward( else: tensordict_exec = tensordict if not len(kwargs): - for module in self.module: + for module in self._module_iter(): tensordict_exec = self._run_module(module, tensordict_exec, **kwargs) else: raise RuntimeError( @@ -510,8 +559,8 @@ def forward( def __len__(self) -> int: return len(self.module) - def __getitem__(self, index: int | slice) -> TensorDictModuleBase: - if isinstance(index, int): + def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase: + if isinstance(index, (int, str)): return self.module.__getitem__(index) else: return type(self)(*self.module.__getitem__(index)) diff --git a/test/test_nn.py b/test/test_nn.py index 630b8d3d2..5b6f741e5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -10,6 +10,7 @@ import pickle import unittest import weakref +from collections import OrderedDict import pytest import torch @@ -797,6 +798,58 @@ def test_tdmodule_inplace(self): class TestTDSequence: + def test_ordered_dict(self): + linear = nn.Linear(3, 4) + linear.weight.data.fill_(0) + linear.bias.data.fill_(1) + layer0 = TensorDictModule(linear, in_keys=["x"], out_keys=["y"]) + ordered_dict = OrderedDict( + layer0=layer0, + layer1=lambda x: x + 1, + ) + seq = TensorDictSequential(ordered_dict) + td = seq(TensorDict(x=torch.ones(3))) + assert (td["x"] == 2).all() + assert (td["y"] == 2).all() + assert seq["layer0"] is layer0 + + def test_ordered_dict_select_subsequence(self): + ordered_dict = OrderedDict( + layer0=TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]), + layer1=TensorDictModule(lambda x: x - 1, in_keys=["y"], out_keys=["z"]), + layer2=TensorDictModule( + lambda x, y: x + y, in_keys=["x", "y"], out_keys=["a"] + ), + ) + seq = TensorDictSequential(ordered_dict) + assert len(seq) == 3 + assert isinstance(seq.module, nn.ModuleDict) + seq_select = seq.select_subsequence(out_keys=["a"]) + assert len(seq_select) == 2 + assert isinstance(seq_select.module, nn.ModuleDict) + assert list(seq_select.module) == ["layer0", "layer2"] + + def test_ordered_dict_select_outkeys(self): + ordered_dict = OrderedDict( + layer0=TensorDictModule( + lambda x: x + 1, in_keys=["x"], out_keys=["intermediate"] + ), + layer1=TensorDictModule( + lambda x: x - 1, in_keys=["intermediate"], out_keys=["z"] + ), + layer2=TensorDictModule( + lambda x, y: x + y, in_keys=["x", "z"], out_keys=["a"] + ), + ) + seq = TensorDictSequential(ordered_dict) + assert len(seq) == 3 + assert isinstance(seq.module, nn.ModuleDict) + seq.select_out_keys("z", "a") + td = seq(TensorDict(x=0)) + assert "intermediate" not in td + assert "z" in td + assert "a" in td + @pytest.mark.parametrize("args", [True, False]) def test_input_keys(self, args): module0 = TensorDictModule(lambda x: x + 0, in_keys=["input"], out_keys=["1"])