diff --git a/chirho/interventional/handlers.py b/chirho/interventional/handlers.py index ac2fb03b..5be449cf 100644 --- a/chirho/interventional/handlers.py +++ b/chirho/interventional/handlers.py @@ -92,13 +92,14 @@ def __init__(self, actions: Mapping[Hashable, AtomicIntervention[T]]): super().__init__() def _pyro_post_sample(self, msg): - try: - action = self.actions[msg["name"]] - except KeyError: + if msg["name"] not in self.actions or msg["infer"].get( + "_do_not_intervene", None + ): return + msg["value"] = intervene( msg["value"], - action, + self.actions[msg["name"]], event_dim=len(msg["fn"].event_shape), name=msg["name"], ) diff --git a/chirho/observational/__init__.py b/chirho/observational/__init__.py index e69de29b..f5970f59 100644 --- a/chirho/observational/__init__.py +++ b/chirho/observational/__init__.py @@ -0,0 +1 @@ +import chirho.observational.internals # noqa: F401 diff --git a/chirho/observational/handlers/__init__.py b/chirho/observational/handlers/__init__.py new file mode 100644 index 00000000..892fc055 --- /dev/null +++ b/chirho/observational/handlers/__init__.py @@ -0,0 +1 @@ +from .condition import condition # noqa: F401 diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py new file mode 100644 index 00000000..200988e6 --- /dev/null +++ b/chirho/observational/handlers/condition.py @@ -0,0 +1,63 @@ +from typing import Generic, Hashable, Mapping, TypeVar + +import pyro + +from chirho.observational.internals import ObserveNameMessenger +from chirho.observational.ops import AtomicObservation, observe + +T = TypeVar("T") + + +class ConditionMessenger(Generic[T], ObserveNameMessenger): + """ + Condition on values in a probabilistic program. + + Can be used as a drop-in replacement for :func:`pyro.condition` that supports + a richer set of observational data types and enables counterfactual inference. + """ + + def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]): + self.data = data + super().__init__() + + def _pyro_sample(self, msg): + if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor( + msg + ): + return + + if msg["name"] not in self.data or msg["infer"].get("_do_not_observe", None): + if ( + "_markov_scope" in msg["infer"] + and getattr(self, "_current_site", None) is not None + ): + msg["infer"]["_markov_scope"].pop(self._current_site, None) + return + + msg["stop"] = True + msg["done"] = True + + # flags to guarantee commutativity of condition, intervene, trace + msg["mask"] = False + msg["is_observed"] = False + msg["infer"]["is_auxiliary"] = True + msg["infer"]["_do_not_trace"] = True + msg["infer"]["_do_not_intervene"] = True + msg["infer"]["_do_not_observe"] = True + + with pyro.poutine.infer_config( + config_fn=lambda msg_: { + "_do_not_observe": msg["name"] == msg_["name"] + or msg_["infer"].get("_do_not_observe", False) + } + ): + try: + self._current_site = msg["name"] + msg["value"] = observe( + msg["fn"], self.data[msg["name"]], name=msg["name"], **msg["kwargs"] + ) + finally: + self._current_site = None + + +condition = pyro.poutine.handlers._make_handler(ConditionMessenger)[1] diff --git a/chirho/observational/handlers.py b/chirho/observational/handlers/soft_conditioning.py similarity index 99% rename from chirho/observational/handlers.py rename to chirho/observational/handlers/soft_conditioning.py index c19bdbe6..b79f3978 100644 --- a/chirho/observational/handlers.py +++ b/chirho/observational/handlers/soft_conditioning.py @@ -8,6 +8,7 @@ T = TypeVar("T") + Kernel = Callable[[T, T], torch.Tensor] diff --git a/chirho/observational/internals.py b/chirho/observational/internals.py new file mode 100644 index 00000000..36a7ce1f --- /dev/null +++ b/chirho/observational/internals.py @@ -0,0 +1,47 @@ +from typing import Optional, TypeVar + +import pyro +import pyro.distributions +import torch + +from chirho.observational.ops import AtomicObservation, observe + +T = TypeVar("T") + + +@observe.register(int) +@observe.register(float) +@observe.register(bool) +@observe.register(torch.Tensor) +def _observe_deterministic(rv: T, obs: Optional[AtomicObservation[T]] = None, **kwargs): + """ + Observe a tensor in a probabilistic program. + """ + rv_dist = pyro.distributions.Delta( + torch.as_tensor(rv), event_dim=kwargs.pop("event_dim", 0) + ) + return observe(rv_dist, obs, **kwargs) + + +@observe.register(pyro.distributions.Distribution) +@pyro.poutine.runtime.effectful(type="observe") +def _observe_distribution( + rv: pyro.distributions.Distribution, + obs: Optional[AtomicObservation[T]] = None, + *, + name: Optional[str] = None, + **kwargs, +) -> T: + if name is None: + raise ValueError("name must be specified when observing a distribution") + + if callable(obs): + raise NotImplementedError("Dependent observations are not yet supported") + + return pyro.sample(name, rv, obs=obs, **kwargs) + + +class ObserveNameMessenger(pyro.poutine.messenger.Messenger): + def _pyro_observe(self, msg): + if "name" not in msg["kwargs"]: + msg["kwargs"]["name"] = msg["name"] diff --git a/chirho/observational/ops.py b/chirho/observational/ops.py new file mode 100644 index 00000000..a32b357f --- /dev/null +++ b/chirho/observational/ops.py @@ -0,0 +1,18 @@ +import functools +from typing import Callable, Hashable, Mapping, Optional, TypeVar, Union + +T = TypeVar("T") + +AtomicObservation = Union[T, Callable[..., T]] # TODO add support for more atomic types +CompoundObservation = Union[ + Mapping[Hashable, AtomicObservation[T]], Callable[..., AtomicObservation[T]] +] +Observation = Union[AtomicObservation[T], CompoundObservation[T]] + + +@functools.singledispatch +def observe(rv, obs: Optional[Observation[T]] = None, **kwargs) -> T: + """ + Observe a random value in a probabilistic program. + """ + raise NotImplementedError(f"observe not implemented for type {type(rv)}") diff --git a/tests/observational/test_handlers_soft_conditioning.py b/tests/observational/test_handlers.py similarity index 60% rename from tests/observational/test_handlers_soft_conditioning.py rename to tests/observational/test_handlers.py index 6b393870..591ceaa0 100644 --- a/tests/observational/test_handlers_soft_conditioning.py +++ b/tests/observational/test_handlers.py @@ -10,7 +10,8 @@ TwinWorldCounterfactual, ) from chirho.interventional.handlers import do -from chirho.observational.handlers import ( +from chirho.observational.handlers import condition +from chirho.observational.handlers.soft_conditioning import ( AutoSoftConditioning, KernelSoftConditionReparam, RBFKernel, @@ -69,7 +70,7 @@ def test_soft_conditioning_smoke_continuous_1( } with pyro.poutine.trace() as tr, pyro.poutine.reparam( config=reparam_config - ), pyro.condition(data=data): + ), condition(data=data): continuous_scm_1() tr.trace.compute_log_prob() @@ -110,7 +111,7 @@ def test_soft_conditioning_smoke_discrete_1( } with pyro.poutine.trace() as tr, pyro.poutine.reparam( config=reparam_config - ), pyro.condition(data=data): + ), condition(data=data): discrete_scm_1() tr.trace.compute_log_prob() @@ -154,7 +155,7 @@ def test_soft_conditioning_counterfactual_continuous_1( with pyro.poutine.trace() as tr, pyro.poutine.reparam( config=reparam_config - ), cf_class(cf_dim), do(actions=actions), pyro.condition(data=data): + ), cf_class(cf_dim), do(actions=actions), condition(data=data): continuous_scm_1() tr.trace.compute_log_prob() @@ -174,3 +175,109 @@ def test_soft_conditioning_counterfactual_continuous_1( else: assert AutoSoftConditioning.site_is_deterministic(tr.trace.nodes[name]) assert f"{name}_approx_log_prob" not in tr.trace.nodes + + +def hmm_model(data): + transition_probs = pyro.param( + "transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=dist.constraints.simplex, + ) + emission_probs = pyro.sample( + "emission_probs", + dist.Dirichlet(torch.tensor([0.5, 0.5])).expand([2]).to_event(1), + ) + x = pyro.sample("x", dist.Categorical(torch.tensor([0.5, 0.5]))) + logger.debug(f"-1\t{tuple(x.shape)}") + for t, y in pyro.markov(enumerate(data)): + x = pyro.sample( + f"x_{t}", + dist.Categorical(pyro.ops.indexing.Vindex(transition_probs)[..., x, :]), + ) + + pyro.sample( + f"y_{t}", + dist.Categorical(pyro.ops.indexing.Vindex(emission_probs)[..., x, :]), + ) + logger.debug(f"{t}\t{tuple(x.shape)}") + + +@pytest.mark.parametrize("num_particles", [1, 10]) +@pytest.mark.parametrize("max_plate_nesting", [3, float("inf")]) +@pytest.mark.parametrize("use_guide", [False, True]) +@pytest.mark.parametrize("num_steps", [2, 3, 4, 5, 6]) +@pytest.mark.parametrize("Elbo", [pyro.infer.TraceEnum_ELBO, pyro.infer.TraceTMC_ELBO]) +def test_smoke_condition_enumerate_hmm_elbo( + num_steps, Elbo, use_guide, max_plate_nesting, num_particles +): + data = dist.Categorical(torch.tensor([0.5, 0.5])).sample((num_steps,)) + + assert issubclass(Elbo, pyro.infer.elbo.ELBO) + elbo = Elbo( + max_plate_nesting=max_plate_nesting, + num_particles=num_particles, + vectorize_particles=(num_particles > 1), + ) + + model = condition(data={f"y_{t}": y for t, y in enumerate(data)})(hmm_model) + + if use_guide: + guide = pyro.infer.config_enumerate(default="parallel")( + pyro.infer.autoguide.AutoDiscreteParallel( + pyro.poutine.block(expose=["x"])(condition(data={})(model)) + ) + ) + model = pyro.infer.config_enumerate(default="parallel")(model) + else: + model = pyro.infer.config_enumerate(default="parallel")(model) + model = condition(model, data={"x": torch.as_tensor(0)}) + + def guide(data): + pass + + # smoke test + elbo.differentiable_loss(model, guide, data) + + +def test_condition_commutes(): + def model(): + z = pyro.sample("z", dist.Normal(0, 1), obs=torch.tensor(0.1)) + with pyro.plate("data", 2): + x = pyro.sample("x", dist.Normal(z, 1)) + y = pyro.sample("y", dist.Normal(x + z, 1)) + return z, x, y + + h_cond = condition( + data={"x": torch.tensor([0.0, 1.0]), "y": torch.tensor([1.0, 2.0])} + ) + h_do = do(actions={"z": torch.tensor(0.0), "x": torch.tensor([0.3, 0.4])}) + + # case 1 + with pyro.poutine.trace() as tr1: + with h_cond, h_do: + model() + + # case 2 + with pyro.poutine.trace() as tr2: + with h_do, h_cond: + model() + + # case 3 + with h_cond, pyro.poutine.trace() as tr3: + with h_do: + model() + + tr1.trace.compute_log_prob() + tr2.trace.compute_log_prob() + tr3.trace.compute_log_prob() + + assert set(tr1.trace.nodes) == set(tr2.trace.nodes) == set(tr3.trace.nodes) + assert ( + tr1.trace.log_prob_sum() == tr2.trace.log_prob_sum() == tr3.trace.log_prob_sum() + ) + for name, node in tr1.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample(node): + assert torch.allclose(node["value"], tr2.trace.nodes[name]["value"]) + assert torch.allclose(node["value"], tr3.trace.nodes[name]["value"]) + assert torch.allclose(node["log_prob"], tr2.trace.nodes[name]["log_prob"]) + assert torch.allclose(node["log_prob"], tr3.trace.nodes[name]["log_prob"])