Commit 2a6262e

Update

[ghstack-poisoned]

vmoens committed Nov 29, 2024
2 parents 6c98571 + a0124fb commit 2a6262e

Showing 4 changed files with 351 additions and 26 deletions.
51 changes: 47 additions & 4 deletions tensordict/nn/common.py
@@ -773,11 +773,20 @@ class TensorDictModule(TensorDictModuleBase):
order given by the in_keys iterable.
If ``in_keys`` is a dictionary, its keys must correspond to the key
to be read in the tensordict and its values must match the name of
the keyword argument in the function signature.
the keyword argument in the function signature. If `out_to_in_map` is ``True``,
the mapping gets inverted so that the keys correspond to the keyword
arguments in the function signature.
out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the
number of tensors returned by the embedded module. Using ``"_"`` as a key avoids writing the tensor to the output.
Keyword Args:
out_to_in_map (bool, optional): if ``True``, `in_keys` is read as if the keys are the argument names of
the :meth:`~.forward` method and the values are the keys in the input :class:`~tensordict.TensorDict`. If
``False`` or ``None`` (default), the keys are considered to be the input keys and the values the method's argument names.
.. warning::
The default value of `out_to_in_map` will change from ``False`` to ``True`` in the v0.9 release.
inplace (bool or string, optional): if ``True`` (default), the outputs of the module are written in the tensordict
provided to the :meth:`~.forward` method. If ``False``, a new :class:`~tensordict.TensorDict` with an empty
batch-size and no device is created. If ``"empty"``, :meth:`~tensordict.TensorDict.empty` will be used to
@@ -865,12 +874,24 @@ class TensorDictModule(TensorDictModuleBase):
Examples:
>>> module = TensorDictModule(lambda x, *, y: x+y,
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
If `out_to_in_map` is set to ``True``, then the `in_keys` mapping is reversed. This way,
one can use the same input key for different keyword arguments.
Examples:
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
... in_keys={'x': '1', 'y': '2', 'z': '2'}, out_keys=['t'], out_to_in_map=True
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)
Functional calls to a tensordict module are easy:
Examples:
@@ -923,17 +944,39 @@ def __init__(
in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str],
out_keys: NestedKey | List[NestedKey],
*,
out_to_in_map: bool | None = None,
inplace: bool | str = True,
) -> None:
super().__init__()

if out_to_in_map is not None and not isinstance(in_keys, dict):
warnings.warn(
"out_to_in_map is not None but is only used when in_key is a dictionary."
)

if isinstance(in_keys, dict):
if out_to_in_map is None:
out_to_in_map = False
warnings.warn(
"Using a dictionary in_keys without specifying out_to_in_map is deprecated. "
"By default, out_to_in_map is `False` (`in_keys` keys as tensordict pointers, "
"values as kwarg name), but from version>=0.9, default will be `True` "
"(`in_keys` keys as func kwarg name, values as tensordict pointers). "
"Please use explicit out_to_in_map to indicate the ordering of the input keys. ",
DeprecationWarning,
stacklevel=2,
)

# write the kwargs and create a list instead
_in_keys = []
self._kwargs = []
for key, value in in_keys.items():
self._kwargs.append(value)
_in_keys.append(key)
if out_to_in_map: # arg: td_key
self._kwargs.append(key)
_in_keys.append(value)
else: # td_key: arg
self._kwargs.append(value)
_in_keys.append(key)
in_keys = _in_keys
else:
if isinstance(in_keys, (str, tuple)):
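As a quick illustration of the new keyword argument, the sketch below contrasts the two `in_keys` conventions selected by `out_to_in_map`. It is adapted from the docstring examples added in this diff; the keys `'1'`/`'2'` and the lambdas are illustrative only.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# out_to_in_map=False (current default): in_keys maps {tensordict_key: kwarg_name}.
mod = TensorDictModule(
    lambda x, *, y: x + y,
    in_keys={"1": "x", "2": "y"},
    out_keys=["z"],
    out_to_in_map=False,
)
td = mod(TensorDict({"1": torch.ones(()), "2": torch.ones(()) * 2}, []))
assert td["z"] == 3.0

# out_to_in_map=True (default from v0.9): in_keys maps {kwarg_name: tensordict_key},
# which also lets a single tensordict entry feed several keyword arguments.
mod = TensorDictModule(
    lambda x, *, y, z: x + y + z,
    in_keys={"x": "1", "y": "2", "z": "2"},
    out_keys=["t"],
    out_to_in_map=True,
)
td = mod(TensorDict({"1": torch.ones(()), "2": torch.ones(()) * 2}, []))
assert td["t"] == 5.0
```

Omitting `out_to_in_map` while passing a dictionary `in_keys` still works but, per the new `__init__` branch above, emits a `DeprecationWarning` and falls back to the `False` behaviour.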
175 changes: 158 additions & 17 deletions tensordict/nn/probabilistic.py
@@ -580,17 +580,23 @@ def _dist_sample(


class ProbabilisticTensorDictSequential(TensorDictSequential):
"""A sequence of :class:`~tensordict.nn.TensorDictModules` ending in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`.
"""A sequence of :class:`~tensordict.nn.TensorDictModules` containing at least one :class:`~tensordict.nn.ProbabilisticTensorDictModule`.
This class extends :class:`~tensordict.nn.TensorDictSequential` by enforcing that the final module
in the sequence is an instance of :class:`~tensordict.nn.ProbabilisticTensorDictModule`. It also
exposes the :meth:`~.get_dist` method to recover the distribution object from the
:class:`~tensordict.nn.ProbabilisticTensorDictModule`.
This class extends :class:`~tensordict.nn.TensorDictSequential` and is typically configured with a sequence of
modules where the final module is an instance of :class:`~tensordict.nn.ProbabilisticTensorDictModule`.
However, it also supports configurations where one or more intermediate modules are instances of
:class:`~tensordict.nn.ProbabilisticTensorDictModule`, while the last module may or may not be probabilistic.
In all cases, it exposes the :meth:`~.get_dist` method to recover the distribution object from the
:class:`~tensordict.nn.ProbabilisticTensorDictModule` instances in the sequence.
Multiple probabilistic modules can co-exist in a single ``ProbabilisticTensorDictSequential``.
If `return_composite` is ``False`` (default), only the last one will produce a distribution and the others
will be executed as regular :class:`~tensordict.nn.TensorDictModule` instances. If ``True``,
intermediate distributions will be grouped in a single :class:`~tensordict.nn.CompositeDistribution`.
If `return_composite` is ``False`` (default), only the last one will produce a distribution and the others
will be executed as regular :class:`~tensordict.nn.TensorDictModule` instances.
However, if a `ProbabilisticTensorDictModule` is not the last module in the sequence and `return_composite=False`,
a `ValueError` will be raised when trying to query the module. If `return_composite=True`,
all intermediate `ProbabilisticTensorDictModule` instances will contribute to a single
:class:`~tensordict.nn.CompositeDistribution` instance.
Resulting log-probabilities will be conditional probabilities if samples are interdependent:
whenever
@@ -652,6 +658,125 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
:obj:`ProbabilisticTensorDictModule` or
:obj:`ProbabilisticTensorDictSequential`.
Examples:
>>> # Typical usage: a single distribution is computed last in the sequence
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq, \
... TensorDictModule as Mod
>>> torch.manual_seed(0)
>>>
>>> module = Seq(
... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
... Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
... distribution_kwargs={"scale": 1}),
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions are ignored when return_composite=False
>>> module = Seq(
... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
... distribution_kwargs={"scale": 1}),
... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
... distribution_kwargs={"scale": 1}),
... return_composite=False,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
fields={
loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions produce a CompositeDistribution when return_composite=True
>>> module = Seq(
... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
... distribution_kwargs={"scale": 1}),
... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
... Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
... distribution_kwargs={"scale": 1}),
... return_composite=True,
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
fields={
loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td, aggregate_probabilities=False))
TensorDict(
fields={
sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when
>>> # return_composite=True
>>> module = Seq(
... Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
... Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
... distribution_kwargs={"scale": 1}),
... Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]),
... return_composite=True,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td, aggregate_probabilities=False, inplace=False, include_sum=False))
TensorDict(
fields={
sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""

def __init__(
@@ -668,14 +793,14 @@ def __init__(
"ProbabilisticTensorDictSequential must consist of zero or more "
"TensorDictModules followed by a ProbabilisticTensorDictModule"
)
if not isinstance(
if not return_composite and not isinstance(
modules[-1],
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
):
raise TypeError(
"The final module passed to ProbabilisticTensorDictSequential must be "
"an instance of ProbabilisticTensorDictModule or another "
"ProbabilisticTensorDictSequential"
"ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
)
# if the modules other than the final probabilistic module return the sampled
# key, we won't be sampling it again; in that case
@@ -724,7 +849,12 @@ def get_dist_params(
tds = self.det_part
type = interaction_type()
if type is None:
type = self.module[-1].default_interaction_type
for m in reversed(self.module):
if hasattr(m, "default_interaction_type"):
type = m.default_interaction_type
break
else:
raise ValueError("Could not find a default interaction in the modules.")
with set_interaction_type(type):
return tds(tensordict, tensordict_out, **kwargs)

@@ -784,8 +914,8 @@ def get_dist(
td_copy = tdm(td_copy)
if len(dists) == 0:
raise RuntimeError(f"No distribution module found in {self}.")
elif len(dists) == 1:
return dist
# elif len(dists) == 1:
# return dist
return CompositeDistribution.from_distributions(
td_copy,
dists,
@@ -968,10 +1098,21 @@ def forward(
tensordict_exec = tensordict.copy()
else:
tensordict_exec = tensordict
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
tensordict_exec = self.module[-1](
tensordict_exec, _requires_sample=self._requires_sample
)
if self.return_composite:
for m in self.module:
if isinstance(
m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
):
tensordict_exec = m(
tensordict_exec, _requires_sample=self._requires_sample
)
else:
tensordict_exec = m(tensordict_exec, **kwargs)
else:
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
tensordict_exec = self.module[-1](
tensordict_exec, _requires_sample=self._requires_sample
)
if tensordict_out is not None:
result = tensordict_out
result.update(tensordict_exec, keys_to_update=self.out_keys)
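As a sketch of what the relaxed constructor check permits (assuming the behaviour described in the new docstring above), a probabilistic module no longer has to be the last module when `return_composite=True`; `get_dist` then wraps the intermediate distribution in a `CompositeDistribution`, and `get_dist_params` resolves the default interaction type by scanning the modules in reverse rather than assuming the last one is probabilistic.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import (
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as Seq,
    TensorDictModule as Mod,
)

module = Seq(
    Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
    Prob(
        in_keys=["loc"],
        out_keys=["sample0"],
        distribution_class=torch.distributions.Normal,
        distribution_kwargs={"scale": 1},
    ),
    # A deterministic module placed *after* the probabilistic one: only
    # accepted because return_composite=True.
    Mod(lambda s: s + 1, in_keys=["sample0"], out_keys=["y"]),
    return_composite=True,
)
data = TensorDict(x=torch.ones(3))
out = module(data.copy())          # writes "loc", "sample0" and "y"
dist = module.get_dist(data)       # CompositeDistribution({"sample0": Normal(...)})
lp = module.log_prob(
    out, aggregate_probabilities=False, inplace=False, include_sum=False
)
print(lp["sample0_log_prob"])
```

With `return_composite=False`, the same layout would raise a `TypeError` at construction time, as enforced by the updated check in `__init__`.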
5 changes: 4 additions & 1 deletion tensordict/nn/sequence.py
Expand Up @@ -489,7 +489,10 @@ def forward(
else:
result = tensordict_exec
if self._select_before_return:
return tensordict.update(result.select(*self.out_keys))
# We must also carry over any value from the input data that was updated
# during execution, not only the selected out_keys.
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
return tensordict.update(result, keys_to_update=keys)
return result

def __len__(self) -> int:
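A minimal sketch of the behaviour this change preserves, under the assumption that `select_out_keys()` activates the `_select_before_return` branch patched above (modules and key names are illustrative): when the outputs are restricted, input entries that were overwritten during execution are now written back alongside the selected out_keys.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

seq = Seq(
    Mod(lambda x: x + 1, in_keys=["x"], out_keys=["x"]),  # overwrites an input key
    Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
    Mod(lambda y: y - 1, in_keys=["y"], out_keys=["z"]),
)
seq.select_out_keys("z")  # only "z" is requested as an output

td = TensorDict(x=torch.zeros(3))
out = seq(td)
# "z" is returned as requested, and the refreshed "x" written by the first
# module is carried back into the input tensordict as well.
print(out["z"], out["x"])
```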