Skip to content

Commit

Permalink
Clean up random_intervention tests and docs (#306)
Browse files Browse the repository at this point in the history
* Clean up random_intervention tests

* docstring

* import

* imports

* fix docstring

* reorder files

* Update chirho/counterfactual/handlers/explanation.py

Co-authored-by: rfl-urbaniak <[email protected]>

---------

Co-authored-by: rfl-urbaniak <[email protected]>
  • Loading branch information
eb8680 and rfl-urbaniak authored Oct 11, 2023
1 parent ef0f5b8 commit 0e017cd
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 157 deletions.
120 changes: 42 additions & 78 deletions chirho/counterfactual/handlers/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Iterable, Mapping, TypeVar

import pyro
import torch # noqa: F401
import torch

from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.selection import get_factual_indices
Expand Down Expand Up @@ -88,35 +88,6 @@ def _consequent_differs(consequent: T) -> torch.Tensor:
return _consequent_differs


@contextlib.contextmanager
def SearchForCause(
actions: Mapping[str, Intervention[T]],
*,
bias: float = 0.0,
prefix: str = "__cause_split_",
):
"""
A context manager used for a stochastic search of minimal but-for causes among potential interventions.
On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias`
(that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption
nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding
intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples.
:param actions: A mapping of sites to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_".
"""
# NOTE(review): the summary above says sites are intervened on with probability `.5 + bias`,
# while `bias` is described as a bias towards *not* intervening -- confirm the direction.
# TODO support event_dim != 0 propagation in factual_preemption
# Pair every candidate site with an `undo_split` preemption over that same site.
preemptions = {
antecedent: undo_split(antecedents=[antecedent])
for antecedent in actions.keys()
}

# `do` is the outer handler; `Preemptions` is nested inside it and receives `bias` and `prefix`.
with do(actions=actions):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
yield


@functools.singledispatch
def uniform_proposal(
support: pyro.distributions.constraints.Constraint,
Expand All @@ -138,7 +109,7 @@ def uniform_proposal(
:return: A uniform probability distribution over the specified support.
"""
if support is pyro.distributions.constraints.real:
return pyro.distributions.Normal(0, 100).mask(False)
return pyro.distributions.Normal(0, 10).mask(False)
elif support is pyro.distributions.constraints.boolean:
return pyro.distributions.Bernoulli(logits=torch.zeros(()))
else:
Expand All @@ -154,26 +125,6 @@ def _uniform_proposal_indep(
event_shape: torch.Size = torch.Size([]),
**kwargs,
) -> pyro.distributions.Distribution:
"""
This constructs a probability distribution with independent dimensions
over a specified support. The choice of distribution depends on the type of support provided
(see the documentation for `uniform_proposal`).
:param support: The support used to create the probability distribution.
:param event_shape: The event shape specifying the dimensions of the distribution.
:param kwargs: Additional keyword arguments.
:return: A probability distribution with independent dimensions over the specified support.
Example:
```
indep_constraint = pyro.distributions.constraints.independent(
pyro.distributions.constraints.real, reinterpreted_batch_ndims=2)
dist = uniform_proposal(indep_constraint, event_shape=torch.Size([2, 3]))
with pyro.plate("data", 3):
samples_indep = pyro.sample("samples_indep", dist.expand([4, 2, 3]))
```
"""

d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs)
return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims)

Expand All @@ -183,22 +134,6 @@ def _uniform_proposal_integer(
support: pyro.distributions.constraints.integer_interval,
**kwargs,
) -> pyro.distributions.Distribution:
"""
This constructs a uniform categorical distribution over an integer_interval support
where the lower bound is 0 and the upper bound is specified by the support.
:param support: The integer_interval support with a lower bound of 0 and a specified upper bound.
:param kwargs: Additional keyword arguments.
:return: A categorical probability distribution over the specified integer_interval support.
Example:
```
constraint = pyro.distributions.constraints.integer_interval(0, 2)
dist = _uniform_proposal_integer(constraint)
samples = dist.sample(torch.Size([100]))
print(dist.probs.tolist())
```
"""
if support.lower_bound != 0:
raise NotImplementedError(
"integer_interval with lower_bound > 0 not yet supported"
Expand All @@ -212,23 +147,23 @@ def random_intervention(
name: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Creates a random `pyro`sample` function for a single sample site, determined by
Creates a random-valued intervention for a single sample site, determined by
the distribution support and site name.
:param support: The support constraint for the sample site..can take.
:param name: The name of the sample site.
:param support: The support constraint for the sample site.
:param name: The name of the auxiliary sample site.
:return: A `pyro.sample` function that takes a torch.Tensor as input
:return: A function that takes a ``torch.Tensor`` as input
and returns a random sample over the pre-specified support of the same
event shape as the input tensor.
Example:
```
support = pyro.distributions.constraints.real
name = "real_sample"
intervention_fn = random_intervention(support, name)
random_sample = intervention_fn(torch.tensor(2.0))
```
Example::
>>> support = pyro.distributions.constraints.real
>>> intervention_fn = random_intervention(support, name="random_value")
>>> with chirho.interventional.handlers.do(actions={"x": intervention_fn}):
... x = pyro.deterministic("x", torch.tensor(2.))
>>> assert x != 2
"""

def _random_intervention(value: torch.Tensor) -> torch.Tensor:
Expand All @@ -240,3 +175,32 @@ def _random_intervention(value: torch.Tensor) -> torch.Tensor:
return pyro.sample(name, proposal_dist)

return _random_intervention


@contextlib.contextmanager
def SearchForCause(
    actions: Mapping[str, Intervention[T]],
    *,
    bias: float = 0.0,
    prefix: str = "__cause_split_",
):
    """
    Context manager for a stochastic search of minimal but-for causes among
    a set of candidate interventions.

    Every site named in ``actions`` is intervened on through ``do`` and paired
    with a stochastic binary preemption node added through ``Preemptions``.
    On each run, sites are randomly selected and intervened on with probability
    ``.5 + bias`` (that is, preempted with probability ``.5 - bias``). If a given
    preemption node has value ``0``, the corresponding intervention is executed.
    See ``tests/counterfactual/test_handlers_explanation.py`` for examples.

    :param actions: A mapping of sites to interventions.
    :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
    :param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_".
    """
    # NOTE(review): the summary says sites are intervened on with probability
    # `.5 + bias`, while `bias` is described as a bias towards *not* intervening
    # -- confirm the direction.
    # TODO support event_dim != 0 propagation in factual_preemption

    # One `undo_split` preemption per candidate site, keyed by that site.
    preemptions = {}
    for site in actions:
        preemptions[site] = undo_split(antecedents=[site])

    # `do` is entered first (outer handler), `Preemptions` second (inner).
    with do(actions=actions), Preemptions(
        actions=preemptions, bias=bias, prefix=prefix
    ):
        yield
118 changes: 39 additions & 79 deletions tests/counterfactual/test_handlers_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pyro.infer
import pytest
import torch
from scipy.stats import spearmanr

from chirho.counterfactual.handlers.counterfactual import (
MultiWorldCounterfactual,
Expand All @@ -18,6 +17,7 @@
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.ops import intervene
from chirho.observational.handlers.condition import Factors, condition


Expand Down Expand Up @@ -177,7 +177,7 @@ def model():


@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_consequent_differs(plate_size, event_shape):
factors = {
"consequent": consequent_differs(
Expand Down Expand Up @@ -220,6 +220,43 @@ def model_cd():
assert nd["__factor_consequent"]["log_prob"].sum() < -1e2


# Constraint supports shared by the parametrized `test_uniform_proposal` and
# `test_random_intervention` cases below.
SUPPORT_CASES = [
pyro.distributions.constraints.real,
pyro.distributions.constraints.boolean,
pyro.distributions.constraints.positive,
pyro.distributions.constraints.interval(0, 10),
pyro.distributions.constraints.interval(-5, 5),
pyro.distributions.constraints.integer_interval(0, 2),
pyro.distributions.constraints.integer_interval(0, 100),
]


@pytest.mark.parametrize("support", SUPPORT_CASES)
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_uniform_proposal(support, event_shape):
    """
    Check that samples from ``uniform_proposal`` fall inside the requested support.

    For a non-empty ``event_shape`` the base support is first wrapped in an
    ``independent`` constraint with one reinterpreted dimension per event dimension.
    """
    if event_shape:
        n_event_dims = len(event_shape)
        support = pyro.distributions.constraints.independent(support, n_event_dims)

    proposal = uniform_proposal(support, event_shape=event_shape)
    draws = proposal.sample((10,))
    in_support = support.check(draws)
    assert torch.all(in_support)


@pytest.mark.parametrize("support", SUPPORT_CASES)
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_random_intervention(support, event_shape):
    """
    Check that applying a ``random_intervention`` through ``intervene`` inside a
    plate of size 10 yields values that satisfy the support constraint.

    For a non-empty ``event_shape`` the base support is first wrapped in an
    ``independent`` constraint with one reinterpreted dimension per event dimension.
    """
    if event_shape:
        n_event_dims = len(event_shape)
        support = pyro.distributions.constraints.independent(support, n_event_dims)

    observed = torch.randn(event_shape)
    randomize = random_intervention(support, "samples")

    with pyro.plate("draws", 10):
        intervened = intervene(observed, randomize)

    in_support = support.check(intervened)
    assert torch.all(in_support)


def stones_bayesian_model():
prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
Expand Down Expand Up @@ -382,80 +419,3 @@ def test_SearchForCause_two_layers():
).item()

assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0


# Named constraint fixtures referenced by the parametrized tests that follow.
support_real = pyro.distributions.constraints.real
support_boolean = pyro.distributions.constraints.boolean
support_positive = pyro.distributions.constraints.positive
support_interval = pyro.distributions.constraints.interval(0, 10)
support_integer_interval = pyro.distributions.constraints.integer_interval(0, 2)
# `real` wrapped so that one batch dimension is reinterpreted as an event dimension.
indep_constraint = pyro.distributions.constraints.independent(
pyro.distributions.constraints.real, reinterpreted_batch_ndims=1
)


# Checks `uniform_proposal` for each support: sampled values must satisfy the
# constraint, and for `indep_constraint` two sampled batches must be only weakly
# rank-correlated.
@pytest.mark.parametrize(
"support",
[
support_real,
support_boolean,
support_positive,
support_interval,
support_integer_interval,
indep_constraint,
],
)
@pytest.mark.parametrize("edges", [(0, 2), (0, 100), (0, 250)])
def test_uniform_proposal(support, edges):
# plug the edges into interval constraints
if support is support_integer_interval:
support = pyro.distributions.constraints.integer_interval(*edges)
elif support is support_interval:
support = pyro.distributions.constraints.interval(*edges)

# test all but the indep_constraint
if support is not indep_constraint:
uniform = uniform_proposal(support)
with pyro.plate("samples", 50):
samples = pyro.sample("samples", uniform)

# with positive constraint, zeros are possible, but
# they don't pass `support.check`. Considered harmless.
if support is support_positive:
samples = samples[samples != 0]

assert torch.all(support.check(samples))

else: # testing the independence constraint requires a bit more work
dist_indep = uniform_proposal(
indep_constraint, event_shape=torch.Size([2, 1000])
)
with pyro.plate("data", 2):
samples_indep = pyro.sample("samples_indep", dist_indep.expand([2]))

# compare the two plate slices: their Spearman correlation must stay below 0.2 in magnitude
batch_1 = samples_indep[0].squeeze().tolist()
batch_2 = samples_indep[1].squeeze().tolist()
assert abs(spearmanr(batch_1, batch_2).correlation) < 0.2


# Checks that values produced by `random_intervention` satisfy the support
# constraint it was built from.
@pytest.mark.parametrize(
"support",
[
support_real,
support_boolean,
support_positive,
support_interval,
support_integer_interval,
indep_constraint,
],
)
def test_random_intervention(support):
intervention = random_intervention(support, "samples")

# call the intervention directly on a length-10 tensor of ones inside a plate of size 1000
with pyro.plate("draws", 1000):
samples = intervention(torch.ones(10))

# zeros are dropped for the positive support before checking the constraint
if support is support_positive:
samples = samples[samples != 0]

assert torch.all(support.check(samples))

0 comments on commit 0e017cd

Please sign in to comment.