diff --git a/chirho/dynamical/internals/_utils.py b/chirho/dynamical/internals/_utils.py index 7e4d6233..96ff4c2f 100644 --- a/chirho/dynamical/internals/_utils.py +++ b/chirho/dynamical/internals/_utils.py @@ -1,6 +1,8 @@ import functools -from typing import FrozenSet, Optional, Tuple, TypeVar +import typing +from typing import Any, Callable, Dict, FrozenSet, Optional, Tuple, TypeVar +import pyro import torch from chirho.dynamical.ops import State @@ -102,3 +104,45 @@ def _observe_state( return State( **{k: observe(rv[k], obs[k], name=f"{name}__{k}", **kwargs) for k in rv.keys()} ) + + +class ShallowMessenger(pyro.poutine.messenger.Messenger): + """ + Base class for so-called "shallow" effect handlers that uninstall themselves + after handling a single operation. + + .. warning:: + + Does not support post-processing or overriding generic ``_process_message`` + """ + + used: bool + + def __enter__(self): + self.used = False + return super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + if self in pyro.poutine.runtime._PYRO_STACK: + super().__exit__(exc_type, exc_value, traceback) + + @typing.final + def _process_message(self, msg: Dict[str, Any]) -> None: + if not self.used and hasattr(self, f"_pyro_{msg['type']}"): + self.used = True + super()._process_message(msg) + + prev_cont: Optional[Callable[[Dict[str, Any]], None]] = msg["continuation"] + + def cont(msg: Dict[str, Any]) -> None: + ix = pyro.poutine.runtime._PYRO_STACK.index(self) + pyro.poutine.runtime._PYRO_STACK.pop(ix) + if prev_cont is not None: + prev_cont(msg) + + msg["continuation"] = cont + + @typing.final + def _postprocess_message(self, msg: Dict[str, Any]) -> None: + if hasattr(self, f"_pyro_post_{msg['type']}"): + raise NotImplementedError("ShallowHandler does not support postprocessing") diff --git a/tests/dynamical/test_handler_composition.py b/tests/dynamical/test_handler_composition.py index 7447e510..db47b12a 100644 --- a/tests/dynamical/test_handler_composition.py +++ b/tests/dynamical/test_handler_composition.py @@ -1,6 +1,7 @@ import logging import pyro +import pytest import torch from pyro.distributions import Normal @@ -12,6 +13,7 @@ StaticIntervention, ) from chirho.dynamical.handlers.solver import TorchDiffEq +from chirho.dynamical.internals._utils import ShallowMessenger from chirho.dynamical.ops import State, simulate from chirho.observational.handlers import condition from chirho.observational.handlers.soft_conditioning import AutoSoftConditioning @@ -135,3 +137,56 @@ def test_smoke_inference_twincounterfactual_observation_intervention_commutes(): ) # Noise is shared between factual and counterfactual worlds. assert pred["u_ip"].squeeze().shape == (num_samples, len(flight_landing_times)) + + +class ShallowLogSample(ShallowMessenger): + log: set + + def __enter__(self): + self.log = set() + return super().__enter__() + + def _pyro_sample(self, msg): + self.log.add(msg["name"]) + pyro.sample(f"{msg['name']}__2", pyro.distributions.Bernoulli(0.5)) + + +def test_shallow_handler_stack(): + with pyro.poutine.trace() as tr1: + with ShallowLogSample() as h1, ShallowLogSample() as h2: + pyro.sample("a", pyro.distributions.Bernoulli(0.5)) + with pyro.poutine.trace() as tr2, ShallowLogSample() as h3: + pyro.sample("b", pyro.distributions.Bernoulli(0.5)) + + assert h1.log == {"a__2"} + assert h2.log == {"a"} + assert h3.log == {"b"} + + assert set(tr1.trace.nodes.keys()) == {"a", "a__2", "a__2__2", "b", "b__2"} + assert set(tr2.trace.nodes.keys()) == {"b", "b__2"} + + +def test_shallow_handler_block(): + with ShallowLogSample() as h1, pyro.poutine.block(hide_types=["sample"]): + pyro.sample("a", pyro.distributions.Bernoulli(0.5)) + with ShallowLogSample() as h2: + pyro.sample("b", pyro.distributions.Bernoulli(0.5)) + + assert h1.log == set() + assert h2.log == {"b"} + + +@pytest.mark.parametrize("error_before_sample", [True, False]) +def test_shallow_handler_error(error_before_sample): + class ShallowLogSampleError(ShallowLogSample): + def _pyro_sample(self, msg): + super()._pyro_sample(msg) + raise RuntimeError("error") + + with pytest.raises(RuntimeError, match="error"): + with ShallowLogSampleError(): + if error_before_sample: + raise RuntimeError("error") + pyro.sample("a", pyro.distributions.Bernoulli(0.5)) + + assert not pyro.poutine.runtime._PYRO_STACK