
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 18, 2024
1 parent d6c078d commit 4d6e477
Showing 2 changed files with 33 additions and 5 deletions.
28 changes: 26 additions & 2 deletions tensordict/nn/probabilistic.py
@@ -9,7 +9,7 @@
 import warnings
 
 from textwrap import indent
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, overload, OrderedDict
 
 import torch
 
@@ -791,6 +791,30 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
"""

@overload
def __init__(
self,
modules: OrderedDict,
partial_tolerant: bool = False,
return_composite: bool | None = None,
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> None:
...

@overload
def __init__(
self,
modules: List[TensorDictModuleBase | ProbabilisticTensorDictModule],
partial_tolerant: bool = False,
return_composite: bool | None = None,
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> None:
...

def __init__(
self,
*modules: TensorDictModuleBase | ProbabilisticTensorDictModule,
@@ -815,7 +839,7 @@ def __init__(
                 "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
             )
         # if the modules not including the final probabilistic module return the sampled
-        # key we wont be sampling it again, in that case
+        # key we won't be sampling it again, in that case
         # ProbabilisticTensorDictSequential is presumably used to return the
        # distribution using `get_dist` or to sample log_probabilities
         _, out_keys = self._compute_in_and_out_keys(modules[:-1])
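
For context, a minimal usage sketch of the two construction paths the new overloads describe: positional modules (as before) and a single OrderedDict of named modules. The module names, key names, and shapes below are illustrative assumptions and do not come from the commit.

from collections import OrderedDict

import torch
from tensordict import TensorDict
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torch.distributions import Normal

# A deterministic backbone that writes the distribution parameters...
backbone = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(3, 4), NormalParamExtractor()),
    in_keys=["obs"],
    out_keys=["loc", "scale"],
)
# ...followed by a probabilistic head sampling "action" from Normal(loc, scale).
head = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=Normal,
)

# Positional construction (the pre-existing overload).
seq = ProbabilisticTensorDictSequential(backbone, head)

# Construction from a single OrderedDict of named modules (the new overload).
seq_named = ProbabilisticTensorDictSequential(
    OrderedDict([("backbone", backbone), ("head", head)])
)

td = seq_named(TensorDict({"obs": torch.randn(3)}, batch_size=[]))
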
10 changes: 7 additions & 3 deletions tensordict/nn/sequence.py
@@ -488,7 +488,9 @@ def select_subsequence(
             return type(self)(*modules)
         else:
             keys = [key for key in self.module if self.module[key] in modules]
-            modules_dict = OrderedDict(**{key: val for key, val in zip(keys, modules)})
+            modules_dict = collections.OrderedDict(
+                **{key: val for key, val in zip(keys, modules)}
+            )
             return type(self)(modules_dict)
 
     def _run_module(
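
A brief note on the hunk above, with a hypothetical continuation of the construction sketch from the previous file: `select_subsequence` extracts the modules needed to connect the requested keys and rebuilds the name-to-module mapping, now explicitly via `collections.OrderedDict`. The keys below are illustrative only.

# Hypothetical: keep only the modules required to produce "loc" and "scale".
sub = seq_named.select_subsequence(out_keys=["loc", "scale"])
# When the parent was built from an OrderedDict, the surviving modules
# keep their original names in the returned sequential.
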
@@ -565,8 +567,10 @@ def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
         else:
             return type(self)(*self.module.__getitem__(index))
 
-    def __setitem__(self, index: int, tensordict_module: TensorDictModuleBase) -> None:
+    def __setitem__(
+        self, index: int | slice | str, tensordict_module: TensorDictModuleBase
+    ) -> None:
         return self.module.__setitem__(idx=index, module=tensordict_module)
 
-    def __delitem__(self, index: int | slice) -> None:
+    def __delitem__(self, index: int | slice | str) -> None:
         self.module.__delitem__(idx=index)

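Finally, a sketch of the indexing that the widened `__setitem__`/`__delitem__` annotations advertise, continuing the hypothetical `seq_named` example above (string keys presume an OrderedDict-built sequential; this illustrates the type hints, not a snippet from the commit):

sub_head = seq_named["head"]   # __getitem__ already accepted int | slice | str
seq_named["head"] = head       # __setitem__ hint now also covers slice and str
del seq_named["head"]          # __delitem__ hint likewise covers str
first = seq_named[0]           # integer indexing is unchanged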
