Skip to content

Commit

Permalink
SearchForCause effect handler with tests thereof (#297)
Browse files Browse the repository at this point in the history
* started part of cause

* part of cause in progress

* second test WIP

* fixed small typos in undo_split documentation

* first emulated second test success

* added SearchOfCause

* dealing with the second test WIP

* dealing with the layered test WIP

* two-layer test succeeds

* small lint

* lint after black update

* renamed the handler to SearchForCause

* renamed the handler in the test

* tweak tests

* simplify tests

* revert

* revert

---------

Co-authored-by: eb8680 <[email protected]>
Co-authored-by: Eli <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2023
1 parent 72ada27 commit ef0f5b8
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 11 deletions.
39 changes: 36 additions & 3 deletions chirho/counterfactual/handlers/explanation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import contextlib
import functools
import itertools
from typing import Callable, Iterable, TypeVar
from typing import Callable, Iterable, Mapping, TypeVar

import pyro
import torch # noqa: F401

from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter
from chirho.interventional.handlers import do
from chirho.interventional.ops import Intervention

S = TypeVar("S")
T = TypeVar("T")
Expand All @@ -15,8 +19,8 @@
def undo_split(antecedents: Iterable[str] = [], event_dim: int = 0) -> Callable[[T], T]:
"""
A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation,
meant to meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
:func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt` .
meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
:func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt`.
Works by gathering the factual value and scattering it back into two alternative cases.
:param antecedents: A list of upstream intervened sites which induced the :func:`split` to be reversed.
Expand Down Expand Up @@ -84,6 +88,35 @@ def _consequent_differs(consequent: T) -> torch.Tensor:
return _consequent_differs


@contextlib.contextmanager
def SearchForCause(
actions: Mapping[str, Intervention[T]],
*,
bias: float = 0.0,
prefix: str = "__cause_split_",
):
"""
A context manager used for a stochastic search of minimal but-for causes among potential interventions.
On each run, nodes listed in `actions` are randomly seleted and intervened on with probability `.5 + bias`
(that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption
nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding
intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples.
:param actions: A mapping of sites to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_".
"""
# TODO support event_dim != 0 propagation in factual_preemption
preemptions = {
antecedent: undo_split(antecedents=[antecedent])
for antecedent in actions.keys()
}

with do(actions=actions):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
yield


@functools.singledispatch
def uniform_proposal(
support: pyro.distributions.constraints.Constraint,
Expand Down
177 changes: 169 additions & 8 deletions tests/counterfactual/test_handlers_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
import torch
from scipy.stats import spearmanr

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.counterfactual.handlers.counterfactual import (
MultiWorldCounterfactual,
Preemptions,
)
from chirho.counterfactual.handlers.explanation import (
SearchForCause,
consequent_differs,
random_intervention,
undo_split,
uniform_proposal,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors
from chirho.observational.handlers.condition import Factors, condition


def test_undo_split():
Expand Down Expand Up @@ -216,8 +220,169 @@ def model_cd():
assert nd["__factor_consequent"]["log_prob"].sum() < -1e2


# ________________________________________________
# testing uniform proposal and random intervention
def stones_bayesian_model():
prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
prob_bottle_shatters_if_sally = pyro.sample(
"prob_bottle_shatters_if_sally", dist.Beta(1, 1)
)
prob_bottle_shatters_if_bill = pyro.sample(
"prob_bottle_shatters_if_bill", dist.Beta(1, 1)
)

sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))

new_shp = torch.where(sally_throws == 1, prob_sally_hits, 0.0)

sally_hits = pyro.sample("sally_hits", dist.Bernoulli(new_shp))

new_bhp = torch.where(
(bill_throws.bool() & (~sally_hits.bool())) == 1,
prob_bill_hits,
torch.tensor(0.0),
)

bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

new_bsp = torch.where(
bill_hits.bool() == 1,
prob_bottle_shatters_if_bill,
torch.where(
sally_hits.bool() == 1,
prob_bottle_shatters_if_sally,
torch.tensor(0.0),
),
)

bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp))

return {
"sally_throws": sally_throws,
"bill_throws": bill_throws,
"sally_hits": sally_hits,
"bill_hits": bill_hits,
"bottle_shatters": bottle_shatters,
}


def test_SearchForCause_single_layer():
observations = {
"prob_sally_throws": 1.0,
"prob_bill_throws": 1.0,
"prob_sally_hits": 1.0,
"prob_bill_hits": 1.0,
"prob_bottle_shatters_if_sally": 1.0,
"prob_bottle_shatters_if_bill": 1.0,
}

observations_conditioning = condition(
data={k: torch.as_tensor(v) for k, v in observations.items()}
)

with MultiWorldCounterfactual() as mwc:
with SearchForCause({"sally_throws": 0.0}, bias=0.0):
with observations_conditioning:
with pyro.poutine.trace() as tr:
stones_bayesian_model()

tr = tr.trace.nodes

with mwc:
preempt_sally_throws = gather(
tr["__cause_split_sally_throws"]["value"],
IndexSet(**{"sally_throws": {0}}),
event_dim=0,
)

int_sally_hits = gather(
tr["sally_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0
)

obs_bill_hits = gather(
tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {0}}), event_dim=0
)

int_bill_hits = gather(
tr["bill_hits"]["value"], IndexSet(**{"sally_throws": {1}}), event_dim=0
)

int_bottle_shatters = gather(
tr["bottle_shatters"]["value"],
IndexSet(**{"sally_throws": {1}}),
event_dim=0,
)

outcome = {
"preempt_sally_throws": preempt_sally_throws.item(),
"int_sally_hits": int_sally_hits.item(),
"obs_bill_hits": obs_bill_hits.item(),
"int_bill_hits": int_bill_hits.item(),
"intervened_bottle_shatters": int_bottle_shatters.item(),
}

assert list(outcome.values()) == [0, 0.0, 0.0, 1.0, 1.0] or list(
outcome.values()
) == [1, 1.0, 0.0, 0.0, 1.0]


def test_SearchForCause_two_layers():
observations = {
"prob_sally_throws": 1.0,
"prob_bill_throws": 1.0,
"prob_sally_hits": 1.0,
"prob_bill_hits": 1.0,
"prob_bottle_shatters_if_sally": 1.0,
"prob_bottle_shatters_if_bill": 1.0,
}

observations_conditioning = condition(
data={k: torch.as_tensor(v) for k, v in observations.items()}
)

actions = {"sally_throws": 0.0}

pinned_preemption_variables = {
"preempt_sally_throws": torch.tensor(0),
"witness_preempt_bill_hits": torch.tensor(1),
}
preemption_conditioning = condition(data=pinned_preemption_variables)

witness_preemptions = {"bill_hits": undo_split(antecedents=actions.keys())}
witness_preemptions_handler: Preemptions = Preemptions(
actions=witness_preemptions, prefix="witness_preempt_"
)

with MultiWorldCounterfactual() as mwc:
with SearchForCause(actions=actions, bias=0.1, prefix="preempt_"):
with preemption_conditioning, witness_preemptions_handler:
with observations_conditioning:
with pyro.poutine.trace() as tr:
stones_bayesian_model()

tr = tr.trace.nodes

with mwc:
obs_bill_hits = gather(
tr["bill_hits"]["value"],
IndexSet(**{"sally_throws": {0}}),
event_dim=0,
).item()
int_bill_hits = gather(
tr["bill_hits"]["value"],
IndexSet(**{"sally_throws": {1}}),
event_dim=0,
).item()
int_bottle_shatters = gather(
tr["bottle_shatters"]["value"],
IndexSet(**{"sally_throws": {1}}),
event_dim=0,
).item()

assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0


support_real = pyro.distributions.constraints.real
support_boolean = pyro.distributions.constraints.boolean
Expand Down Expand Up @@ -294,7 +459,3 @@ def test_random_intervention(support):
samples = samples[samples != 0]

assert torch.all(support.check(samples))


# tests of uniform proposal and random intervention end here
# ___________________________________________________________

0 comments on commit ef0f5b8

Please sign in to comment.