From e670caec84a6aca21f0a4f0f56de2a12d9f7ab97 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 30 Aug 2023 22:12:01 -0400 Subject: [PATCH 1/3] Add a Factors handler for inserting new factors --- chirho/observational/handlers/condition.py | 48 +++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index 200988e6..1d63ae0b 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -1,6 +1,7 @@ -from typing import Generic, Hashable, Mapping, TypeVar +from typing import Callable, Generic, Hashable, Mapping, TypeVar import pyro +import torch from chirho.observational.internals import ObserveNameMessenger from chirho.observational.ops import AtomicObservation, observe @@ -8,6 +9,51 @@ T = TypeVar("T") +class Factors(Generic[T], pyro.poutine.messenger.Messenger): + """ + Effect handler that adds new log-factors to the unnormalized + joint log-density of a probabilistic program. + + After a :func:`pyro.sample` site whose name appears in ``factors``, + this handler inserts a new :func:`pyro.factor` site + whose name is prefixed with the string ``prefix`` + and whose log-weight is the result of applying the corresponding function + to the value of the sample site. :: + + >>> with Factors(factors={"x": lambda x: -(x - 1) ** 2}, prefix="__factor_"): + ... with pyro.poutine.trace() as tr: + ... x = pyro.sample("x", dist.Normal(0, 1)) + ... tr.trace.compute_log_prob() + >>> assert {"x", "__factor_x"} <= set(tr.trace.nodes.keys()) + >>> assert tr.trace.log_prob_sum() == tr.trace.nodes["x"]["log_prob"] + \ + ... tr.trace.nodes["x"]["log_prob"] - (tr.trace.nodes["x"]["value"] - 1) ** 2 + + :param factors: A mapping from sample site names to log-factor functions. + :param prefix: The prefix to use for the names of the factor sites. + """ + + factors: Mapping[str, Callable[[T], torch.Tensor]] + prefix: str + + def __init__( + self, + factors: Mapping[str, Callable[[T], torch.Tensor]], + *, + prefix: str = "__factor_", + ): + self.factors = factors + self.prefix = prefix + super().__init__() + + def _pyro_post_sample(self, msg: dict) -> None: + try: + factor = self.factors[msg["name"]] + except KeyError: + return + + pyro.factor(f"{self.prefix}{msg['name']}", factor(msg["value"])) + + class ConditionMessenger(Generic[T], ObserveNameMessenger): """ Condition on values in a probabilistic program. From 332d0aa260a5f1c37e81a4147b69e546c2184e8f Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 30 Aug 2023 22:33:42 -0400 Subject: [PATCH 2/3] docs and test --- chirho/observational/handlers/condition.py | 8 ++--- tests/observational/test_handlers.py | 34 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index 1d63ae0b..21b76701 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -7,6 +7,7 @@ from chirho.observational.ops import AtomicObservation, observe T = TypeVar("T") +R = float | torch.Tensor class Factors(Generic[T], pyro.poutine.messenger.Messenger): @@ -25,19 +26,18 @@ class Factors(Generic[T], pyro.poutine.messenger.Messenger): ... x = pyro.sample("x", dist.Normal(0, 1)) ... tr.trace.compute_log_prob() >>> assert {"x", "__factor_x"} <= set(tr.trace.nodes.keys()) - >>> assert tr.trace.log_prob_sum() == tr.trace.nodes["x"]["log_prob"] + \ - ... tr.trace.nodes["x"]["log_prob"] - (tr.trace.nodes["x"]["value"] - 1) ** 2 + >>> assert torch.all(tr.trace.nodes["x"]["log_prob"] == -(x - 1) ** 2) :param factors: A mapping from sample site names to log-factor functions. :param prefix: The prefix to use for the names of the factor sites. """ - factors: Mapping[str, Callable[[T], torch.Tensor]] + factors: Mapping[str, Callable[[T], R]] prefix: str def __init__( self, - factors: Mapping[str, Callable[[T], torch.Tensor]], + factors: Mapping[str, Callable[[T], R]], *, prefix: str = "__factor_", ): diff --git a/tests/observational/test_handlers.py b/tests/observational/test_handlers.py index 105fa1bb..c8748453 100644 --- a/tests/observational/test_handlers.py +++ b/tests/observational/test_handlers.py @@ -11,6 +11,7 @@ ) from chirho.interventional.handlers import do from chirho.observational.handlers import condition +from chirho.observational.handlers.condition import Factors from chirho.observational.handlers.soft_conditioning import ( AutoSoftConditioning, KernelSoftConditionReparam, @@ -284,3 +285,36 @@ def model(): assert torch.allclose(node["value"], tr3.trace.nodes[name]["value"]) assert torch.allclose(node["log_prob"], tr2.trace.nodes[name]["log_prob"]) assert torch.allclose(node["log_prob"], tr3.trace.nodes[name]["log_prob"]) + + +def test_factors_handler(): + def model(): + z = pyro.sample("z", dist.Normal(0, 1), obs=torch.tensor(0.1)) + with pyro.plate("data", 2): + x = pyro.sample("x", dist.Normal(z, 1)) + y = pyro.sample("y", dist.Normal(x + z, 1)) + return z, x, y + + prefix = "__factor_" + factors = { + "z": lambda z: -((z - 1.5) ** 2), + "x": lambda x: -((x - 1) ** 2), + } + + with Factors[torch.Tensor](factors=factors, prefix=prefix): + with pyro.poutine.trace() as tr: + model() + + tr.trace.compute_log_prob() + + for name in factors: + assert name in tr.trace.nodes + assert f"{prefix}{name}" in tr.trace.nodes + assert ( + tr.trace.nodes[name]["fn"].batch_shape + == tr.trace.nodes[f"{prefix}{name}"]["fn"].batch_shape + ) + assert torch.allclose( + tr.trace.nodes[f"{prefix}{name}"]["log_prob"], + factors[name](tr.trace.nodes[name]["value"]), + ) From 4bb7098934d4f77eca7b2ade456ebb86c447cebf Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 30 Aug 2023 23:17:35 -0400 Subject: [PATCH 3/3] lint --- chirho/observational/handlers/condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index 21b76701..54c0d5ef 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -1,4 +1,4 @@ -from typing import Callable, Generic, Hashable, Mapping, TypeVar +from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union import pyro import torch @@ -7,7 +7,7 @@ from chirho.observational.ops import AtomicObservation, observe T = TypeVar("T") -R = float | torch.Tensor +R = Union[float, torch.Tensor] class Factors(Generic[T], pyro.poutine.messenger.Messenger):