diff --git a/chirho/counterfactual/handlers/explanation.py b/chirho/counterfactual/handlers/explanation.py index ae90f6ef..4fe40de9 100644 --- a/chirho/counterfactual/handlers/explanation.py +++ b/chirho/counterfactual/handlers/explanation.py @@ -4,7 +4,7 @@ from typing import Callable, Iterable, Mapping, TypeVar import pyro -import torch # noqa: F401 +import torch from chirho.counterfactual.handlers.counterfactual import Preemptions from chirho.counterfactual.handlers.selection import get_factual_indices @@ -88,35 +88,6 @@ def _consequent_differs(consequent: T) -> torch.Tensor: return _consequent_differs -@contextlib.contextmanager -def SearchForCause( - actions: Mapping[str, Intervention[T]], - *, - bias: float = 0.0, - prefix: str = "__cause_split_", -): - """ - A context manager used for a stochastic search of minimal but-for causes among potential interventions. - On each run, nodes listed in `actions` are randomly seleted and intervened on with probability `.5 + bias` - (that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption - nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding - intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples. - - :param actions: A mapping of sites to interventions. - :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. - :param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_". - """ - # TODO support event_dim != 0 propagation in factual_preemption - preemptions = { - antecedent: undo_split(antecedents=[antecedent]) - for antecedent in actions.keys() - } - - with do(actions=actions): - with Preemptions(actions=preemptions, bias=bias, prefix=prefix): - yield - - @functools.singledispatch def uniform_proposal( support: pyro.distributions.constraints.Constraint, @@ -138,7 +109,7 @@ def uniform_proposal( :return: A uniform probability distribution over the specified support. """ if support is pyro.distributions.constraints.real: - return pyro.distributions.Normal(0, 100).mask(False) + return pyro.distributions.Normal(0, 10).mask(False) elif support is pyro.distributions.constraints.boolean: return pyro.distributions.Bernoulli(logits=torch.zeros(())) else: @@ -154,26 +125,6 @@ def _uniform_proposal_indep( event_shape: torch.Size = torch.Size([]), **kwargs, ) -> pyro.distributions.Distribution: - """ - This constructs a probability distribution with independent dimensions - over a specified support. The choice of distribution depends on the type of support provided - (see the documentation for `uniform_proposal`). - - :param support: The support used to create the probability distribution. - :param event_shape: The event shape specifying the dimensions of the distribution. - :param kwargs: Additional keyword arguments. - :return: A probability distribution with independent dimensions over the specified support. - - Example: - ``` - indep_constraint = pyro.distributions.constraints.independent( - pyro.distributions.constraints.real, reinterpreted_batch_ndims=2) - dist = uniform_proposal(indep_constraint, event_shape=torch.Size([2, 3])) - with pyro.plate("data", 3): - samples_indep = pyro.sample("samples_indep", dist.expand([4, 2, 3])) - ``` - """ - d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs) return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims) @@ -183,22 +134,6 @@ def _uniform_proposal_integer( support: pyro.distributions.constraints.integer_interval, **kwargs, ) -> pyro.distributions.Distribution: - """ - This constructs a uniform categorical distribution over an integer_interval support - where the lower bound is 0 and the upper bound is specified by the support. - - :param support: The integer_interval support with a lower bound of 0 and a specified upper bound. - :param kwargs: Additional keyword arguments. - :return: A categorical probability distribution over the specified integer_interval support. - - Example: - ``` - constraint = pyro.distributions.constraints.integer_interval(0, 2) - dist = _uniform_proposal_integer(constraint) - samples = dist.sample(torch.Size([100])) - print(dist.probs.tolist()) - ``` - """ if support.lower_bound != 0: raise NotImplementedError( "integer_interval with lower_bound > 0 not yet supported" @@ -212,23 +147,23 @@ def random_intervention( name: str, ) -> Callable[[torch.Tensor], torch.Tensor]: """ - Creates a random `pyro`sample` function for a single sample site, determined by + Creates a random-valued intervention for a single sample site, determined by by the distribution support, and site name. - :param support: The support constraint for the sample site..can take. - :param name: The name of the sample site. + :param support: The support constraint for the sample site. + :param name: The name of the auxiliary sample site. - :return: A `pyro.sample` function that takes a torch.Tensor as input + :return: A function that takes a ``torch.Tensor`` as input and returns a random sample over the pre-specified support of the same event shape as the input tensor. - Example: - ``` - support = pyro.distributions.constraints.real - name = "real_sample" - intervention_fn = random_intervention(support, name) - random_sample = intervention_fn(torch.tensor(2.0)) - ``` + Example:: + + >>> support = pyro.distributions.constraints.real + >>> intervention_fn = random_intervention(support, name="random_value") + >>> with chirho.interventional.handlers.do(actions={"x": intervention_fn}): + ... x = pyro.deterministic("x", torch.tensor(2.)) + >>> assert x != 2 """ def _random_intervention(value: torch.Tensor) -> torch.Tensor: @@ -240,3 +175,32 @@ def _random_intervention(value: torch.Tensor) -> torch.Tensor: return pyro.sample(name, proposal_dist) return _random_intervention + + +@contextlib.contextmanager +def SearchForCause( + actions: Mapping[str, Intervention[T]], + *, + bias: float = 0.0, + prefix: str = "__cause_split_", +): + """ + A context manager used for a stochastic search of minimal but-for causes among potential interventions. + On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias` + (that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption + nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding + intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples. + + :param actions: A mapping of sites to interventions. + :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. + :param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_". + """ + # TODO support event_dim != 0 propagation in factual_preemption + preemptions = { + antecedent: undo_split(antecedents=[antecedent]) + for antecedent in actions.keys() + } + + with do(actions=actions): + with Preemptions(actions=preemptions, bias=bias, prefix=prefix): + yield diff --git a/tests/counterfactual/test_handlers_explanation.py b/tests/counterfactual/test_handlers_explanation.py index 1eb71b9c..2348fa5c 100644 --- a/tests/counterfactual/test_handlers_explanation.py +++ b/tests/counterfactual/test_handlers_explanation.py @@ -3,7 +3,6 @@ import pyro.infer import pytest import torch -from scipy.stats import spearmanr from chirho.counterfactual.handlers.counterfactual import ( MultiWorldCounterfactual, @@ -18,6 +17,7 @@ ) from chirho.counterfactual.ops import preempt, split from chirho.indexed.ops import IndexSet, gather, indices_of +from chirho.interventional.ops import intervene from chirho.observational.handlers.condition import Factors, condition @@ -177,7 +177,7 @@ def model(): @pytest.mark.parametrize("plate_size", [4, 50, 200]) -@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)]) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) def test_consequent_differs(plate_size, event_shape): factors = { "consequent": consequent_differs( @@ -220,6 +220,43 @@ def model_cd(): assert nd["__factor_consequent"]["log_prob"].sum() < -1e2 +SUPPORT_CASES = [ + pyro.distributions.constraints.real, + pyro.distributions.constraints.boolean, + pyro.distributions.constraints.positive, + pyro.distributions.constraints.interval(0, 10), + pyro.distributions.constraints.interval(-5, 5), + pyro.distributions.constraints.integer_interval(0, 2), + pyro.distributions.constraints.integer_interval(0, 100), +] + + +@pytest.mark.parametrize("support", SUPPORT_CASES) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_uniform_proposal(support, event_shape): + if event_shape: + support = pyro.distributions.constraints.independent(support, len(event_shape)) + + uniform = uniform_proposal(support, event_shape=event_shape) + samples = uniform.sample((10,)) + assert torch.all(support.check(samples)) + + +@pytest.mark.parametrize("support", SUPPORT_CASES) +@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str) +def test_random_intervention(support, event_shape): + if event_shape: + support = pyro.distributions.constraints.independent(support, len(event_shape)) + + obs_value = torch.randn(event_shape) + intervention = random_intervention(support, "samples") + + with pyro.plate("draws", 10): + samples = intervene(obs_value, intervention) + + assert torch.all(support.check(samples)) + + def stones_bayesian_model(): prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1)) prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1)) @@ -382,80 +419,3 @@ def test_SearchForCause_two_layers(): ).item() assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0 - - -support_real = pyro.distributions.constraints.real -support_boolean = pyro.distributions.constraints.boolean -support_positive = pyro.distributions.constraints.positive -support_interval = pyro.distributions.constraints.interval(0, 10) -support_integer_interval = pyro.distributions.constraints.integer_interval(0, 2) -indep_constraint = pyro.distributions.constraints.independent( - pyro.distributions.constraints.real, reinterpreted_batch_ndims=1 -) - - -@pytest.mark.parametrize( - "support", - [ - support_real, - support_boolean, - support_positive, - support_interval, - support_integer_interval, - indep_constraint, - ], -) -@pytest.mark.parametrize("edges", [(0, 2), (0, 100), (0, 250)]) -def test_uniform_proposal(support, edges): - # plug the edges into interval constraints - if support is support_integer_interval: - support = pyro.distributions.constraints.integer_interval(*edges) - elif support is support_interval: - support = pyro.distributions.constraints.interval(*edges) - - # test all but the indep_constraint - if support is not indep_constraint: - uniform = uniform_proposal(support) - with pyro.plate("samples", 50): - samples = pyro.sample("samples", uniform) - - # with positive constraint, zeros are possible, but - # they don't pass `support.check`. Considered harmless. - if support is support_positive: - samples = samples[samples != 0] - - assert torch.all(support.check(samples)) - - else: # testing the idependence constraint requires a bit more work - dist_indep = uniform_proposal( - indep_constraint, event_shape=torch.Size([2, 1000]) - ) - with pyro.plate("data", 2): - samples_indep = pyro.sample("samples_indep", dist_indep.expand([2])) - - batch_1 = samples_indep[0].squeeze().tolist() - batch_2 = samples_indep[1].squeeze().tolist() - assert abs(spearmanr(batch_1, batch_2).correlation) < 0.2 - - -@pytest.mark.parametrize( - "support", - [ - support_real, - support_boolean, - support_positive, - support_interval, - support_integer_interval, - indep_constraint, - ], -) -def test_random_intervention(support): - intervention = random_intervention(support, "samples") - - with pyro.plate("draws", 1000): - samples = intervention(torch.ones(10)) - - if support is support_positive: - samples = samples[samples != 0] - - assert torch.all(support.check(samples))