From ef0f5b819e65dca46fa5acf473b220f436b5e01c Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Mon, 9 Oct 2023 15:50:09 +0200 Subject: [PATCH] SearchForCause effect handler with tests thereof (#297) * started part of cause * part of cause in progress * second test WIP * fixed small typos in undo_split documentation * first emulated second test success * added SearchOfCause * dealing with the second test WIP * dealing with the layered test WIP * two-layer test succeeds * small lint * lint after black update * renamed the handler to SearchForCause * renamed the handler in the test * tweak tests * simplify tests * revert * revert --------- Co-authored-by: eb8680 Co-authored-by: Eli --- chirho/counterfactual/handlers/explanation.py | 39 +++- .../test_handlers_explanation.py | 177 +++++++++++++++++- 2 files changed, 205 insertions(+), 11 deletions(-) diff --git a/chirho/counterfactual/handlers/explanation.py b/chirho/counterfactual/handlers/explanation.py index ac41cf8e..ae90f6ef 100644 --- a/chirho/counterfactual/handlers/explanation.py +++ b/chirho/counterfactual/handlers/explanation.py @@ -1,12 +1,16 @@ +import contextlib import functools import itertools -from typing import Callable, Iterable, TypeVar +from typing import Callable, Iterable, Mapping, TypeVar import pyro import torch # noqa: F401 +from chirho.counterfactual.handlers.counterfactual import Preemptions from chirho.counterfactual.handlers.selection import get_factual_indices from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter +from chirho.interventional.handlers import do +from chirho.interventional.ops import Intervention S = TypeVar("S") T = TypeVar("T") @@ -15,8 +19,8 @@ def undo_split(antecedents: Iterable[str] = [], event_dim: int = 0) -> Callable[[T], T]: """ A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation, - meant to meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` , - :func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt` . + meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` , + :func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt`. Works by gathering the factual value and scattering it back into two alternative cases. :param antecedents: A list of upstream intervened sites which induced the :func:`split` to be reversed. @@ -84,6 +88,35 @@ def _consequent_differs(consequent: T) -> torch.Tensor: return _consequent_differs +@contextlib.contextmanager +def SearchForCause( + actions: Mapping[str, Intervention[T]], + *, + bias: float = 0.0, + prefix: str = "__cause_split_", +): + """ + A context manager used for a stochastic search of minimal but-for causes among potential interventions. + On each run, nodes listed in `actions` are randomly seleted and intervened on with probability `.5 + bias` + (that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption + nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding + intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples. + + :param actions: A mapping of sites to interventions. + :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. + :param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_". + """ + # TODO support event_dim != 0 propagation in factual_preemption + preemptions = { + antecedent: undo_split(antecedents=[antecedent]) + for antecedent in actions.keys() + } + + with do(actions=actions): + with Preemptions(actions=preemptions, bias=bias, prefix=prefix): + yield + + @functools.singledispatch def uniform_proposal( support: pyro.distributions.constraints.Constraint, diff --git a/tests/counterfactual/test_handlers_explanation.py b/tests/counterfactual/test_handlers_explanation.py index 602c4dd8..1eb71b9c 100644 --- a/tests/counterfactual/test_handlers_explanation.py +++ b/tests/counterfactual/test_handlers_explanation.py @@ -5,8 +5,12 @@ import torch from scipy.stats import spearmanr -from chirho.counterfactual.handlers import MultiWorldCounterfactual +from chirho.counterfactual.handlers.counterfactual import ( + MultiWorldCounterfactual, + Preemptions, +) from chirho.counterfactual.handlers.explanation import ( + SearchForCause, consequent_differs, random_intervention, undo_split, @@ -14,7 +18,7 @@ ) from chirho.counterfactual.ops import preempt, split from chirho.indexed.ops import IndexSet, gather, indices_of -from chirho.observational.handlers.condition import Factors +from chirho.observational.handlers.condition import Factors, condition def test_undo_split(): @@ -216,8 +220,169 @@ def model_cd(): assert nd["__factor_consequent"]["log_prob"].sum() < -1e2 -# ________________________________________________ -# testing uniform proposal and random intervention +def stones_bayesian_model(): + prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1)) + prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1)) + prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1)) + prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1)) + prob_bottle_shatters_if_sally = pyro.sample( + "prob_bottle_shatters_if_sally", dist.Beta(1, 1) + ) + prob_bottle_shatters_if_bill = pyro.sample( + "prob_bottle_shatters_if_bill", dist.Beta(1, 1) + ) + + sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws)) + bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws)) + + new_shp = torch.where(sally_throws == 1, prob_sally_hits, 0.0) + + sally_hits = pyro.sample("sally_hits", dist.Bernoulli(new_shp)) + + new_bhp = torch.where( + (bill_throws.bool() & (~sally_hits.bool())) == 1, + prob_bill_hits, + torch.tensor(0.0), + ) + + bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp)) + + new_bsp = torch.where( + bill_hits.bool() == 1, + prob_bottle_shatters_if_bill, + torch.where( + sally_hits.bool() == 1, + prob_bottle_shatters_if_sally, + torch.tensor(0.0), + ), + ) + + bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp)) + + return { + "sally_throws": sally_throws, + "bill_throws": bill_throws, + "sally_hits": sally_hits, + "bill_hits": bill_hits, + "bottle_shatters": bottle_shatters, + } + + +def test_SearchForCause_single_layer(): + observations = { + "prob_sally_throws": 1.0, + "prob_bill_throws": 1.0, + "prob_sally_hits": 1.0, + "prob_bill_hits": 1.0, + "prob_bottle_shatters_if_sally": 1.0, + "prob_bottle_shatters_if_bill": 1.0, + } + + observations_conditioning = condition( + data={k: torch.as_tensor(v) for k, v in observations.items()} + ) + + with MultiWorldCounterfactual() as mwc: + with SearchForCause({"sally_throws": 0.0}, bias=0.0): + with observations_conditioning: + with pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr = tr.trace.nodes + + with mwc: + preempt_sally_throws = gather( + tr["__cause_split_sally_throws"]["value"], + IndexSet(**{"sally_throws": {0}}), + event_dim=0, + ) + + int_sally_hits = gather( + tr["sally_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0 + ) + + obs_bill_hits = gather( + tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {0}}), event_dim=0 + ) + + int_bill_hits = gather( + tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0 + ) + + int_bottle_shatters = gather( + tr["bottle_shatters"]["value"], + IndexSet(**{"sally_throws": {1}}), + event_dim=0, + ) + + outcome = { + "preempt_sally_throws": preempt_sally_throws.item(), + "int_sally_hits": int_sally_hits.item(), + "obs_bill_hits": obs_bill_hits.item(), + "int_bill_hits": int_bill_hits.item(), + "intervened_bottle_shatters": int_bottle_shatters.item(), + } + + assert list(outcome.values()) == [0, 0.0, 0.0, 1.0, 1.0] or list( + outcome.values() + ) == [1, 1.0, 0.0, 0.0, 1.0] + + +def test_SearchForCause_two_layers(): + observations = { + "prob_sally_throws": 1.0, + "prob_bill_throws": 1.0, + "prob_sally_hits": 1.0, + "prob_bill_hits": 1.0, + "prob_bottle_shatters_if_sally": 1.0, + "prob_bottle_shatters_if_bill": 1.0, + } + + observations_conditioning = condition( + data={k: torch.as_tensor(v) for k, v in observations.items()} + ) + + actions = {"sally_throws": 0.0} + + pinned_preemption_variables = { + "preempt_sally_throws": torch.tensor(0), + "witness_preempt_bill_hits": torch.tensor(1), + } + preemption_conditioning = condition(data=pinned_preemption_variables) + + witness_preemptions = {"bill_hits": undo_split(antecedents=actions.keys())} + witness_preemptions_handler: Preemptions = Preemptions( + actions=witness_preemptions, prefix="witness_preempt_" + ) + + with MultiWorldCounterfactual() as mwc: + with SearchForCause(actions=actions, bias=0.1, prefix="preempt_"): + with preemption_conditioning, witness_preemptions_handler: + with observations_conditioning: + with pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr = tr.trace.nodes + + with mwc: + obs_bill_hits = gather( + tr["bill_hits"]["value"], + IndexSet(**{"sally_throws": {0}}), + event_dim=0, + ).item() + int_bill_hits = gather( + tr["bill_hits"]["value"], + IndexSet(**{"sally_throws": {1}}), + event_dim=0, + ).item() + int_bottle_shatters = gather( + tr["bottle_shatters"]["value"], + IndexSet(**{"sally_throws": {1}}), + event_dim=0, + ).item() + + assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0 + support_real = pyro.distributions.constraints.real support_boolean = pyro.distributions.constraints.boolean @@ -294,7 +459,3 @@ def test_random_intervention(support): samples = samples[samples != 0] assert torch.all(support.check(samples)) - - -# tests of uniform proposal and random intervention end here -# ___________________________________________________________