[Feature] CompositeDistribution.from_distributions
ghstack-source-id: 04a62439b0fe60422fbc901172df46306e161cc5
Pull Request resolved: #1113
vmoens committed Nov 27, 2024
1 parent b4b8b31 commit a45c7e3
Showing 2 changed files with 123 additions and 0 deletions.
100 changes: 100 additions & 0 deletions tensordict/nn/distributions/composite.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import warnings
from typing import Dict

import torch
from tensordict import TensorDict, TensorDictBase
@@ -121,6 +122,105 @@ def __init__(
        self.include_sum = include_sum
        self.inplace = inplace

    @classmethod
    def from_distributions(
        cls,
        params,
        distributions: Dict[NestedKey, d.Distribution],
        *,
        name_map: dict | None = None,
        aggregate_probabilities: bool | None = None,
        log_prob_key: NestedKey = "sample_log_prob",
        entropy_key: NestedKey = "entropy",
        inplace: bool | None = None,
        include_sum: bool | None = None,
    ) -> CompositeDistribution:
"""Create a `CompositeDistribution` instance from existing distribution objects.
This class method allows for the creation of a `CompositeDistribution` by directly providing
a dictionary of distribution instances, rather than specifying distribution types and parameters separately.
Args:
params (TensorDictBase): A TensorDict that defines the batch shape for the composite distribution.
The params will not be used by this method, but the tensordict will be used to gather the key names of
the distributions.
distributions (Dict[NestedKey, d.Distribution]): A dictionary mapping nested keys to distribution instances.
These distributions will be used directly in the composite distribution.
Keyword Args:
name_map (Dict[NestedKey, NestedKey], optional): A mapping of where each sample should be written. If not provided,
the key names from `distribution_map` will be used.
aggregate_probabilities (bool, optional): If `True`, the `log_prob` and `entropy` methods will sum the probabilities and entropies
of the individual distributions and return a single tensor. If `False`, individual log-probabilities will be stored in the input
TensorDict (for `log_prob`) or returned as leaves of the output TensorDict (for `entropy`). This can be overridden at runtime
by passing the `aggregate_probabilities` argument to `log_prob` and `entropy`. Defaults to `False`.
log_prob_key (NestedKey, optional): The key where the log probability will be stored. Defaults to `'sample_log_prob'`.
entropy_key (NestedKey, optional): The key where the entropy will be stored. Defaults to `'entropy'`.
inplace (bool, optional): Whether to modify the input TensorDict in-place. Defaults to `True`.
.. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. Defaults to `True`.
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Returns:
CompositeDistribution: An instance of `CompositeDistribution` initialized with the provided distributions.
Raises:
KeyError: If a key in `name_map` cannot be found in the provided distributions.
.. note:: The batch size of the `params` TensorDict determines the batch shape of the composite distribution.
Example:
>>> from tensordict.nn import CompositeDistribution, ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule, TensorDictModule
>>> import torch
>>> from tensordict import TensorDict
>>>
>>> # Values are not used to build the dists
>>> params = TensorDict({("0", "loc"): None, ("1", "loc"): None, ("0", "scale"): None, ("1", "scale"): None})
>>> d0 = torch.distributions.Normal(0, 1)
>>> d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))
>>>
>>> d = CompositeDistribution.from_distributions(params, {"0": d0, "1": d1})
>>> print(d.sample())
TensorDict(
fields={
0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""
        self = cls.__new__(cls)
        self._batch_shape = params.shape
        dists = {}
        if name_map is not None:
            # Normalize the mapping so that all keys are in their unraveled form.
            name_map = {
                unravel_key(key): unravel_key(other_key)
                for key, other_key in name_map.items()
            }
        for name, dist in distributions.items():
            name_unravel = unravel_key(name)
            if name_map:
                try:
                    # Look up the unraveled name, since the map keys were unraveled above.
                    write_name = unravel_key(name_map.get(name_unravel, name_unravel))
                except KeyError:
                    raise KeyError(
                        f"Failed to retrieve the key {name} from the name_map with keys {name_map.keys()}."
                    )
            else:
                write_name = name_unravel
            dists[write_name] = dist
        self.dists = dists
        self.log_prob_key = log_prob_key
        self.entropy_key = entropy_key

        self.aggregate_probabilities = aggregate_probabilities
        self.include_sum = include_sum
        self.inplace = inplace
        return self

    @property
    def aggregate_probabilities(self):
        aggregate_probabilities = self._aggregate_probabilities
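
For context, a minimal usage sketch (not part of the commit) showing how `name_map` redirects where each sample is written; the keys and shapes below are illustrative assumptions, and only the API documented in the docstring above is relied upon:

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

# `params` values are unused here; the TensorDict only supplies the batch shape.
params = TensorDict({("raw", "loc"): None, ("raw", "scale"): None})
d0 = torch.distributions.Normal(torch.zeros(3), torch.ones(3))

# Write the sample of "raw" under the nested key ("agent", "action") instead.
dist = CompositeDistribution.from_distributions(
    params,
    {"raw": d0},
    name_map={"raw": ("agent", "action")},
)
sample = dist.sample()
assert sample["agent", "action"].shape == (3,)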
23 changes: 23 additions & 0 deletions test/test_nn.py
@@ -2166,6 +2166,29 @@ def test_sample(self):
        sample = dist.sample((4,))
        assert sample.shape == torch.Size((4,) + params.shape)

    def test_from_distributions(self):
        # Values are not used to build the dists
        params = TensorDict(
            {
                ("0", "loc"): None,
                ("1", "nested", "loc"): None,
                ("0", "scale"): None,
                ("1", "nested", "scale"): None,
            }
        )
        d0 = torch.distributions.Normal(0, 1)
        d1 = torch.distributions.Normal(torch.zeros(1, 2), torch.ones(1, 2))

        d = CompositeDistribution.from_distributions(
            params, {"0": d0, ("1", "nested"): d1}
        )
        s = d.sample()
        assert s["0"].shape == ()
        assert s["1", "nested"].shape == (1, 2)
        assert isinstance(s["0"], torch.Tensor)
        assert isinstance(s["1", "nested"], torch.Tensor)

    def test_sample_named(self):
        params = TensorDict(
            {
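
As a complement to the test above, a minimal sketch (not part of the commit) of the `aggregate_probabilities` behaviour described in the docstring: when set to `True`, `log_prob` sums the per-distribution log-probabilities into a single tensor. The distributions below are illustrative assumptions.

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

params = TensorDict({("0", "loc"): None, ("1", "loc"): None})
dist = CompositeDistribution.from_distributions(
    params,
    {"0": torch.distributions.Normal(0.0, 1.0), "1": torch.distributions.Normal(1.0, 2.0)},
    aggregate_probabilities=True,
)
sample = dist.sample()
lp = dist.log_prob(sample)  # a single tensor: the sum of both log-probs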
