Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 2, 2024
2 parents 17dbbb1 + 5be188f commit 48bf06a
Show file tree
Hide file tree
Showing 8 changed files with 980 additions and 172 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample
TensorDictSequential
TensorDictModuleWrapper
CudaGraphModule
WrapModule

Ensembles
---------
Expand Down
1 change: 1 addition & 0 deletions tensordict/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TensorDictModule,
TensorDictModuleBase,
TensorDictModuleWrapper,
WrapModule,
)
from tensordict.nn.distributions import (
AddStateIndependentNormalScale,
Expand Down
41 changes: 37 additions & 4 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,12 +1278,45 @@ def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase:


class WrapModule(TensorDictModuleBase):
    """A wrapper around any callable that processes TensorDict instances.

    This wrapper is useful when building :class:`~tensordict.nn.TensorDictSequential` stacks and when a transform
    requires the entire TensorDict instance to be visible.

    Args:
        func (Callable[[TensorDictBase], TensorDictBase]): A callable function that takes in a TensorDictBase instance
            and returns a transformed TensorDictBase instance.

    Keyword Args:
        inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict
            will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``.

    Examples:
        >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
        >>> seq = Seq(
        ...     Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
        ...     WrapModule(lambda td: td.reshape(-1)),
        ... )
        >>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
        >>> td = seq(td)
        >>> assert td.shape == (12,)
        >>> assert (td["y"] == 2).all()
        >>> assert td["y"].shape == (12, 5)
    """

    # The wrapped callable sees the entire TensorDict, so no specific
    # entry/exit keys are declared.
    in_keys = []
    out_keys = []

    def __init__(
        self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False
    ) -> None:
        super().__init__()
        self.func = func
        self.inplace = inplace

    def forward(self, data: TensorDictBase) -> TensorDictBase:
        result = self.func(data)
        if self.inplace and result is not data:
            # func returned a distinct TensorDict: fold its content back
            # into the input so callers observe an in-place update.
            return data.update(result)
        return result
26 changes: 7 additions & 19 deletions tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,6 @@ def from_distributions(
self.inplace = inplace
return self

@property
def aggregate_probabilities(self):
aggregate_probabilities = self._aggregate_probabilities
if aggregate_probabilities is None:
warnings.warn(
"The default value of `aggregate_probabilities` will change from `False` to `True` in v0.7. "
"Please pass this value explicitly to avoid this warning.",
FutureWarning,
)
aggregate_probabilities = self._aggregate_probabilities = False
return aggregate_probabilities

@aggregate_probabilities.setter
def aggregate_probabilities(self, value):
self._aggregate_probabilities = value

def sample(self, shape=None) -> TensorDictBase:
if shape is None:
shape = torch.Size([])
Expand Down Expand Up @@ -337,7 +321,7 @@ def log_prob(
aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
from the class.
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
Has no effect if ``aggregate_probabilities`` is set to ``True``.
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Expand All @@ -356,6 +340,8 @@ def log_prob(
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if aggregate_probabilities is None:
aggregate_probabilities = False
if not aggregate_probabilities:
return self.log_prob_composite(
sample, include_sum=include_sum, inplace=inplace
Expand All @@ -382,7 +368,7 @@ def log_prob_composite(
Keyword Args:
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Expand Down Expand Up @@ -451,7 +437,7 @@ def entropy(
setting from the class. Determines whether to return a single summed entropy tensor or a TensorDict
with individual entropies. Defaults to ``False`` if not set in the class.
include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
Defaults to `self.inplace`, which is set through the class constructor. Has no effect if
Defaults to `self.include_sum`, which is set through the class constructor. Has no effect if
`aggregate_probabilities` is set to `True`.
.. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
Expand All @@ -466,6 +452,8 @@ def entropy(
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if aggregate_probabilities is None:
aggregate_probabilities = False
if not aggregate_probabilities:
return self.entropy_composite(samples_mc, include_sum=include_sum)
se = 0.0
Expand Down
Loading

0 comments on commit 48bf06a

Please sign in to comment.