Clean up random_intervention tests and docs #306

Merged Oct 11, 2023 · 8 commits
120 changes: 42 additions & 78 deletions chirho/counterfactual/handlers/explanation.py
@@ -4,7 +4,7 @@
from typing import Callable, Iterable, Mapping, TypeVar

import pyro
import torch # noqa: F401
import torch

from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.selection import get_factual_indices
@@ -88,35 +88,6 @@ def _consequent_differs(consequent: T) -> torch.Tensor:
return _consequent_differs


@contextlib.contextmanager
def SearchForCause(
actions: Mapping[str, Intervention[T]],
*,
bias: float = 0.0,
prefix: str = "__cause_split_",
):
"""
A context manager used for a stochastic search of minimal but-for causes among potential interventions.
    On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias`
(that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption
nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding
intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples.

:param actions: A mapping of sites to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_".
"""
# TODO support event_dim != 0 propagation in factual_preemption
preemptions = {
antecedent: undo_split(antecedents=[antecedent])
for antecedent in actions.keys()
}

with do(actions=actions):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
yield


@functools.singledispatch
def uniform_proposal(
support: pyro.distributions.constraints.Constraint,
@@ -138,7 +109,7 @@ def uniform_proposal(
:return: A uniform probability distribution over the specified support.
"""
if support is pyro.distributions.constraints.real:
return pyro.distributions.Normal(0, 100).mask(False)
return pyro.distributions.Normal(0, 10).mask(False)
elif support is pyro.distributions.constraints.boolean:
return pyro.distributions.Bernoulli(logits=torch.zeros(()))
else:
@@ -154,26 +125,6 @@ def _uniform_proposal_indep(
event_shape: torch.Size = torch.Size([]),
**kwargs,
) -> pyro.distributions.Distribution:
"""
This constructs a probability distribution with independent dimensions
over a specified support. The choice of distribution depends on the type of support provided
(see the documentation for `uniform_proposal`).

:param support: The support used to create the probability distribution.
:param event_shape: The event shape specifying the dimensions of the distribution.
:param kwargs: Additional keyword arguments.
:return: A probability distribution with independent dimensions over the specified support.

Example:
```
indep_constraint = pyro.distributions.constraints.independent(
pyro.distributions.constraints.real, reinterpreted_batch_ndims=2)
dist = uniform_proposal(indep_constraint, event_shape=torch.Size([2, 3]))
with pyro.plate("data", 3):
samples_indep = pyro.sample("samples_indep", dist.expand([4, 2, 3]))
```
"""

d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs)
return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims)
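
For context, a minimal sketch of calling `uniform_proposal` directly (illustrative only, not part of this diff; the constraints, shapes, and the uniform-categorical behavior of the integer case are assumptions based on the code shown above):

```python
import pyro
import torch

from chirho.counterfactual.handlers.explanation import uniform_proposal

# Real-valued support: a wide Normal whose log-density is masked out,
# i.e. an effectively flat proposal over the reals.
real_dist = uniform_proposal(pyro.distributions.constraints.real)

# Bounded integer support with lower bound 0: a categorical over {0, 1, 2}.
int_dist = uniform_proposal(pyro.distributions.constraints.integer_interval(0, 2))

# Independent constraints reuse the base proposal and expand it to the
# requested event shape, as in `_uniform_proposal_indep` above.
indep_constraint = pyro.distributions.constraints.independent(
    pyro.distributions.constraints.real, reinterpreted_batch_ndims=2
)
indep_dist = uniform_proposal(indep_constraint, event_shape=torch.Size([2, 3]))

samples = indep_dist.sample(torch.Size([10]))
assert samples.shape == torch.Size([10, 2, 3])
```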

@@ -183,22 +134,6 @@ def _uniform_proposal_integer(
support: pyro.distributions.constraints.integer_interval,
**kwargs,
) -> pyro.distributions.Distribution:
"""
This constructs a uniform categorical distribution over an integer_interval support
where the lower bound is 0 and the upper bound is specified by the support.

:param support: The integer_interval support with a lower bound of 0 and a specified upper bound.
:param kwargs: Additional keyword arguments.
:return: A categorical probability distribution over the specified integer_interval support.

Example:
```
constraint = pyro.distributions.constraints.integer_interval(0, 2)
dist = _uniform_proposal_integer(constraint)
samples = dist.sample(torch.Size([100]))
print(dist.probs.tolist())
```
"""
if support.lower_bound != 0:
raise NotImplementedError(
"integer_interval with lower_bound > 0 not yet supported"
@@ -212,23 +147,23 @@ def random_intervention(
name: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Creates a random `pyro`sample` function for a single sample site, determined by
Creates a random-valued intervention for a single sample site, determined by
    the distribution support and the site name.

:param support: The support constraint for the sample site..can take.
:param name: The name of the sample site.
:param support: The support constraint for the sample site.
:param name: The name of the auxiliary sample site.

:return: A `pyro.sample` function that takes a torch.Tensor as input
:return: A function that takes a ``torch.Tensor`` as input
and returns a random sample over the pre-specified support of the same
event shape as the input tensor.

Example:
```
support = pyro.distributions.constraints.real
name = "real_sample"
intervention_fn = random_intervention(support, name)
random_sample = intervention_fn(torch.tensor(2.0))
```
Example::

>>> support = pyro.distributions.constraints.real
>>> intervention_fn = random_intervention(support, name="random_value")
>>> with chirho.interventional.handlers.do(actions={"x": intervention_fn}):
... x = pyro.deterministic("x", torch.tensor(2.))
>>> assert x != 2
"""

def _random_intervention(value: torch.Tensor) -> torch.Tensor:
@@ -240,3 +175,32 @@ def _random_intervention(value: torch.Tensor) -> torch.Tensor:
return pyro.sample(name, proposal_dist)

return _random_intervention


@contextlib.contextmanager
def SearchForCause(
actions: Mapping[str, Intervention[T]],
*,
bias: float = 0.0,
prefix: str = "__cause_split_",
):
"""
A context manager used for a stochastic search of minimal but-for causes among potential interventions.
    On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias`
(that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption
nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding
intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples.

:param actions: A mapping of sites to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_".
"""
# TODO support event_dim != 0 propagation in factual_preemption
preemptions = {
antecedent: undo_split(antecedents=[antecedent])
for antecedent in actions.keys()
}

with do(actions=actions):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
yield
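
To make the intended composition concrete, here is a minimal usage sketch of `SearchForCause` together with `random_intervention` (illustrative only, not part of this diff; the toy model, site names, and `bias` value are assumptions — see `tests/counterfactual/test_handlers_explanation.py` for the canonical end-to-end usage):

```python
import pyro
import pyro.distributions as dist
import torch

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.explanation import SearchForCause, random_intervention


def toy_model():
    # Toy causal model: whether Sally throws influences whether the bottle shatters.
    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(0.7))
    return pyro.sample("bottle_shatters", dist.Bernoulli(0.05 + 0.9 * sally_throws))


# Candidate but-for cause: replace "sally_throws" with a random value drawn
# uniformly from its boolean support.
antecedent = random_intervention(
    pyro.distributions.constraints.boolean, name="sally_throws_proposal"
)

with MultiWorldCounterfactual():
    # On each run the intervention is either executed or preempted (with
    # probability .5 - bias) via an auxiliary "__cause_split_..." preemption site.
    with SearchForCause(actions={"sally_throws": antecedent}, bias=0.1):
        with pyro.poutine.trace() as tr:
            toy_model()

# tr.trace now holds factual and counterfactual values for downstream sites.
```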
118 changes: 39 additions & 79 deletions tests/counterfactual/test_handlers_explanation.py
@@ -3,7 +3,6 @@
import pyro.infer
import pytest
import torch
from scipy.stats import spearmanr

from chirho.counterfactual.handlers.counterfactual import (
MultiWorldCounterfactual,
@@ -18,6 +17,7 @@
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.ops import intervene
from chirho.observational.handlers.condition import Factors, condition


@@ -177,7 +177,7 @@ def model():


@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_consequent_differs(plate_size, event_shape):
factors = {
"consequent": consequent_differs(
Expand Down Expand Up @@ -220,6 +220,43 @@ def model_cd():
assert nd["__factor_consequent"]["log_prob"].sum() < -1e2


SUPPORT_CASES = [
pyro.distributions.constraints.real,
pyro.distributions.constraints.boolean,
pyro.distributions.constraints.positive,
pyro.distributions.constraints.interval(0, 10),
pyro.distributions.constraints.interval(-5, 5),
pyro.distributions.constraints.integer_interval(0, 2),
pyro.distributions.constraints.integer_interval(0, 100),
]


@pytest.mark.parametrize("support", SUPPORT_CASES)
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_uniform_proposal(support, event_shape):
if event_shape:
support = pyro.distributions.constraints.independent(support, len(event_shape))

uniform = uniform_proposal(support, event_shape=event_shape)
samples = uniform.sample((10,))
assert torch.all(support.check(samples))


@pytest.mark.parametrize("support", SUPPORT_CASES)
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_random_intervention(support, event_shape):
if event_shape:
support = pyro.distributions.constraints.independent(support, len(event_shape))

obs_value = torch.randn(event_shape)
intervention = random_intervention(support, "samples")

with pyro.plate("draws", 10):
samples = intervene(obs_value, intervention)

assert torch.all(support.check(samples))


def stones_bayesian_model():
prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
@@ -382,80 +419,3 @@ def test_SearchForCause_two_layers():
).item()

assert obs_bill_hits == 0.0 and int_bill_hits == 0.0 and int_bottle_shatters == 0.0


support_real = pyro.distributions.constraints.real
support_boolean = pyro.distributions.constraints.boolean
support_positive = pyro.distributions.constraints.positive
support_interval = pyro.distributions.constraints.interval(0, 10)
support_integer_interval = pyro.distributions.constraints.integer_interval(0, 2)
indep_constraint = pyro.distributions.constraints.independent(
pyro.distributions.constraints.real, reinterpreted_batch_ndims=1
)


@pytest.mark.parametrize(
"support",
[
support_real,
support_boolean,
support_positive,
support_interval,
support_integer_interval,
indep_constraint,
],
)
@pytest.mark.parametrize("edges", [(0, 2), (0, 100), (0, 250)])
def test_uniform_proposal(support, edges):
# plug the edges into interval constraints
if support is support_integer_interval:
support = pyro.distributions.constraints.integer_interval(*edges)
elif support is support_interval:
support = pyro.distributions.constraints.interval(*edges)

# test all but the indep_constraint
if support is not indep_constraint:
uniform = uniform_proposal(support)
with pyro.plate("samples", 50):
samples = pyro.sample("samples", uniform)

# with positive constraint, zeros are possible, but
# they don't pass `support.check`. Considered harmless.
if support is support_positive:
samples = samples[samples != 0]

assert torch.all(support.check(samples))

    else:  # testing the independence constraint requires a bit more work
dist_indep = uniform_proposal(
indep_constraint, event_shape=torch.Size([2, 1000])
)
with pyro.plate("data", 2):
samples_indep = pyro.sample("samples_indep", dist_indep.expand([2]))

batch_1 = samples_indep[0].squeeze().tolist()
batch_2 = samples_indep[1].squeeze().tolist()
assert abs(spearmanr(batch_1, batch_2).correlation) < 0.2


@pytest.mark.parametrize(
"support",
[
support_real,
support_boolean,
support_positive,
support_interval,
support_integer_interval,
indep_constraint,
],
)
def test_random_intervention(support):
intervention = random_intervention(support, "samples")

with pyro.plate("draws", 1000):
samples = intervention(torch.ones(10))

if support is support_positive:
samples = samples[samples != 0]

assert torch.all(support.check(samples))