Skip to content

Commit

Permalink
Merge master into staging-dynamic (#348)
Browse files Browse the repository at this point in the history
* fix typo the the (#300)

* Adding  helper function for generating stochastic interventions with approximately uniform distributions.  (#294)

* added uniform_proposal to explanation.py

* added a test for uniform_proposal

* added _uniform_proposal_indep

* added _uniform_proposal_indep

* added test for uniform_proposal_indep, lint

* added _uniform_proposal_integer and a test, lint

* added random_intervention()

* added test for random_intervention()

* removed redundant logs

* revised uniform prop, random intervention, tests

* SearchForCause effect handler with tests thereof (#297)

* started part of cause

* part of cause in progress

* second test WIP

* fixed small typos in undo_split documentation

* first emulated second test success

* added SearchOfCause

* dealing with the second test WIP

* dealing with the layered test WIP

* two-layer test succeeds

* small lint

* lint after black update

* renamed the handler to SearchForCause

* renamed the handler in the test

* tweak tests

* simplify tests

* revert

* revert

---------

Co-authored-by: eb8680 <[email protected]>
Co-authored-by: Eli <[email protected]>

* Clean up random_intervention tests and docs (#306)

* Clean up random_intervention tests

* docstring

* import

* imports

* fix docstring

* reorder files

* Update chirho/counterfactual/handlers/explanation.py

Co-authored-by: rfl-urbaniak <[email protected]>

---------

Co-authored-by: rfl-urbaniak <[email protected]>

* make ci tests parallel (#345)

---------

Co-authored-by: Zenna Tavares <[email protected]>
Co-authored-by: rfl-urbaniak <[email protected]>
Co-authored-by: eb8680 <[email protected]>
Co-authored-by: Eli <[email protected]>
  • Loading branch information
5 people authored Oct 18, 2023
1 parent a40a2a1 commit ba0c13e
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ jobs:
- name: Test
shell: bash
run: |
pytest tests/ -s --cov=chirho/ --cov-report=term-missing ${@-}
pytest tests/ -s -n auto --cov=chirho/ --cov-report=term-missing ${@-}
cd docs && make html
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ answer the following kinds of causal questions that appear frequently in
practice.

- **Interventional**: *How many COVID-19 hospitalizations will occur if
the the USA imposes a national mask mandate?*
the USA imposes a national mask mandate?*

- **Counterfactual**: *Given that 100,000 people were infected with
COVID-19 in the past month, how many would have been infected if a
Expand Down
132 changes: 128 additions & 4 deletions chirho/counterfactual/handlers/explanation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import contextlib
import functools
import itertools
from typing import Callable, Iterable, TypeVar
from typing import Callable, Iterable, Mapping, TypeVar

import torch # noqa: F401
import pyro
import torch

from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter
from chirho.interventional.handlers import do
from chirho.interventional.ops import Intervention

S = TypeVar("S")
T = TypeVar("T")
Expand All @@ -13,8 +19,8 @@
def undo_split(antecedents: Iterable[str] = [], event_dim: int = 0) -> Callable[[T], T]:
"""
A helper function that undoes an upstream :func:`~chirho.counterfactual.ops.split` operation,
meant to meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
:func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt` .
meant to be used to create arguments to pass to :func:`~chirho.interventional.ops.intervene` ,
:func:`~chirho.counterfactual.ops.split` or :func:`~chirho.counterfactual.ops.preempt`.
Works by gathering the factual value and scattering it back into two alternative cases.
:param antecedents: A list of upstream intervened sites which induced the :func:`split` to be reversed.
Expand Down Expand Up @@ -80,3 +86,121 @@ def _consequent_differs(consequent: T) -> torch.Tensor:
return cond(eps, 0.0, not_eq, event_dim=event_dim)

return _consequent_differs


@functools.singledispatch
def uniform_proposal(
support: pyro.distributions.constraints.Constraint,
**kwargs,
) -> pyro.distributions.Distribution:
"""
This function heuristically constructs a probability distribution over a specified
support. The choice of distribution depends on the type of support provided.
- If the support is `real`, it creates a wide Normal distribution
and standard deviation, defaulting to (0,100).
- If the support is `boolean`, it creates a Bernoulli distribution with a fixed logit of 0,
corresponding to success probability .5.
- If the support is an `interval`, the transformed distribution is centered around the
midpoint of the interval.
:param support: The support used to create the probability distribution.
:param kwargs: Additional keyword arguments.
:return: A uniform probability distribution over the specified support.
"""
if support is pyro.distributions.constraints.real:
return pyro.distributions.Normal(0, 10).mask(False)
elif support is pyro.distributions.constraints.boolean:
return pyro.distributions.Bernoulli(logits=torch.zeros(()))
else:
tfm = pyro.distributions.transforms.biject_to(support)
base = uniform_proposal(pyro.distributions.constraints.real, **kwargs)
return pyro.distributions.TransformedDistribution(base, tfm)


@uniform_proposal.register
def _uniform_proposal_indep(
support: pyro.distributions.constraints.independent,
*,
event_shape: torch.Size = torch.Size([]),
**kwargs,
) -> pyro.distributions.Distribution:
d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs)
return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims)


@uniform_proposal.register
def _uniform_proposal_integer(
support: pyro.distributions.constraints.integer_interval,
**kwargs,
) -> pyro.distributions.Distribution:
if support.lower_bound != 0:
raise NotImplementedError(
"integer_interval with lower_bound > 0 not yet supported"
)
n = support.upper_bound - support.lower_bound + 1
return pyro.distributions.Categorical(probs=torch.ones((n,)))


def random_intervention(
support: pyro.distributions.constraints.Constraint,
name: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Creates a random-valued intervention for a single sample site, determined by
by the distribution support, and site name.
:param support: The support constraint for the sample site.
:param name: The name of the auxiliary sample site.
:return: A function that takes a ``torch.Tensor`` as input
and returns a random sample over the pre-specified support of the same
event shape as the input tensor.
Example::
>>> support = pyro.distributions.constraints.real
>>> intervention_fn = random_intervention(support, name="random_value")
>>> with chirho.interventional.handlers.do(actions={"x": intervention_fn}):
... x = pyro.deterministic("x", torch.tensor(2.))
>>> assert x != 2
"""

def _random_intervention(value: torch.Tensor) -> torch.Tensor:
event_shape = value.shape[len(value.shape) - support.event_dim :]
proposal_dist = uniform_proposal(
support,
event_shape=event_shape,
)
return pyro.sample(name, proposal_dist)

return _random_intervention


@contextlib.contextmanager
def SearchForCause(
actions: Mapping[str, Intervention[T]],
*,
bias: float = 0.0,
prefix: str = "__cause_split_",
):
"""
A context manager used for a stochastic search of minimal but-for causes among potential interventions.
On each run, nodes listed in `actions` are randomly selected and intervened on with probability `.5 + bias`
(that is, preempted with probability `.5-bias`). The sampling is achieved by adding stochastic binary preemption
nodes associated with intervention candidates. If a given preemption node has value `0`, the corresponding
intervention is executed. See tests in `tests/counterfactual/test_handlers_explanation.py` for examples.
:param actions: A mapping of sites to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param prefix: A prefix used for naming additional preemption nodes. Defaults to "__cause_split_".
"""
# TODO support event_dim != 0 propagation in factual_preemption
preemptions = {
antecedent: undo_split(antecedents=[antecedent])
for antecedent in actions.keys()
}

with do(actions=actions):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
yield
Loading

0 comments on commit ba0c13e

Please sign in to comment.