Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into sw-dynamic-multi-tutorial
  • Loading branch information
SamWitty committed Aug 21, 2024
2 parents ba4dee4 + 7f4fde3 commit 165bc56
Show file tree
Hide file tree
Showing 35 changed files with 5,115 additions and 1,275 deletions.
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ approximations that power much of the modern probabilistic machine learning land
- Causal inference with deep models and proxy variables
- `Example: Mediation analysis and (in)direct effects <https://basisresearch.github.io/chirho/mediation.html>`_
- Mediation analysis for path specific effects
- `Example: Estimating causal effects using instrumental variables <https://basisresearch.github.io/chirho/instrumental_var.html>`_
- Causal effect estimation with instrumental variables
- `Example: Deep structural causal model counterfactuals <https://basisresearch.github.io/chirho/deepscm.html>`_
- Counterfactuals with normalizing flows
- `Example: Structured Latent Confounders <https://basisresearch.github.io/chirho/slc.html>`_
Expand Down
8 changes: 6 additions & 2 deletions chirho/explainable/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .components import random_intervention # noqa: F401
from .components import ExtractSupports, undo_split # noqa: F401
from .components import ( # noqa: F401
ExtractSupports,
random_intervention,
sufficiency_intervention,
undo_split,
)
from .explanation import SearchForExplanation, SplitSubsets # noqa: F401
from .preemptions import Preemptions # noqa: F401
189 changes: 181 additions & 8 deletions chirho/explainable/handlers/components.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Callable, Iterable, MutableMapping, TypeVar
from typing import Callable, Iterable, MutableMapping, Optional, TypeVar

import pyro
import pyro.distributions.constraints as constraints
Expand All @@ -8,16 +8,61 @@
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.internals import uniform_proposal
from chirho.indexed.ops import IndexSet, gather, indices_of, scatter_n
from chirho.observational.handlers import soft_neq
from chirho.observational.handlers import soft_eq, soft_neq
from chirho.observational.ops import Observation

S = TypeVar("S")
T = TypeVar("T")


def sufficiency_intervention(
support: constraints.Constraint,
antecedents: Iterable[str] = [],
sufficiency_world=2,
) -> Callable[[T], T]:
"""
Creates a sufficiency intervention for a single sample site, determined by
the site name, intervening to keep the value as in the factual world with
respect to the antecedents.
:param support: The support constraint for the site.
:param name: The sample site name.
:return: A function that takes a `torch.Tensor` as input
and returns the factual value at the named site as a tensor.
Example::
>>> with MultiWorldCounterfactual() as mwc:
>>> value = pyro.sample("value", proposal_dist)
>>> intervention = sufficiency_intervention(support)
>>> value = intervene(value, intervention)
"""

def _sufficiency_intervention(value: T) -> T:

indices = IndexSet(
**{
name: sufficiency_world
for name, ind in get_factual_indices().items()
if name in antecedents
}
)

sufficiency_value = gather(
value,
indices,
event_dim=support.event_dim,
)
return sufficiency_value

return _sufficiency_intervention


def random_intervention(
support: constraints.Constraint,
name: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
) -> Callable[[T], T]:
"""
Creates a random-valued intervention for a single sample site, determined by
by the distribution support, and site name.
Expand All @@ -38,8 +83,10 @@ def random_intervention(
>>> assert x != 2
"""

def _random_intervention(value: torch.Tensor) -> torch.Tensor:
event_shape = value.shape[len(value.shape) - support.event_dim :]
def _random_intervention(value: T) -> T:

event_shape = value.shape[len(value.shape) - support.event_dim :] # type: ignore

proposal_dist = uniform_proposal(
support,
event_shape=event_shape,
Expand Down Expand Up @@ -92,7 +139,43 @@ def _undo_split(value: T) -> T:
return _undo_split


def consequent_differs(
def consequent_eq(
support: constraints.Constraint,
antecedents: Iterable[str] = [],
**kwargs,
) -> Callable[[T], torch.Tensor]:
"""
A helper function for assessing whether values at a site are close to their observed values, assigning
a small negative value close to zero if a value is close to its observed state and a large negative value otherwise.
:param support: The support constraint for the consequent site.
:param antecedents: A list of names of upstream intervened sites to consider when assessing similarity.
:return: A callable which applied to a site value object (``consequent``), returns a tensor where each
element indicates the extent to which the corresponding element of ``consequent``
is close to its factual value.
"""

def _consequent_eq(consequent: T) -> torch.Tensor:
indices = IndexSet(
**{
name: ind
for name, ind in get_factual_indices().items()
if name in antecedents
}
)
eq = soft_eq(
support,
consequent,
gather(consequent, indices, event_dim=support.event_dim),
**kwargs,
)
return eq

return _consequent_eq


def consequent_neq(
support: constraints.Constraint,
antecedents: Iterable[str] = [],
**kwargs,
Expand All @@ -109,7 +192,7 @@ def consequent_differs(
element indicates whether the corresponding element of ``consequent`` differs from its factual value.
"""

def _consequent_differs(consequent: T) -> torch.Tensor:
def _consequent_neq(consequent: T) -> torch.Tensor:
indices = IndexSet(
**{
name: ind
Expand All @@ -125,7 +208,97 @@ def _consequent_differs(consequent: T) -> torch.Tensor:
)
return diff

return _consequent_differs
return _consequent_neq


def consequent_eq_neq(
support: constraints.Constraint,
proposed_consequent: Optional[Observation[T]],
antecedents: Iterable[str] = [],
**kwargs,
) -> Callable[[T], torch.Tensor]:
"""
A helper function for obtaining joint log prob of necessity and sufficiency. Assumes that
the necessity intervention has been applied in counterfactual world 1 and sufficiency intervention in
counterfactual world 2 (these can be passed as kwargs).
:param support: The support constraint for the consequent site.
:param antecedents: A list of names of upstream intervened sites to consider when composing the joint log prob.
:return: A callable which applied to a site value object (``consequent``), returns a tensor with log prob sums
of values resulting from necessity and sufficiency interventions, in appropriate counterfactual worlds.
"""

def _consequent_eq_neq(consequent: T) -> torch.Tensor:
necessity_world = kwargs.get("necessity_world", 1)
sufficiency_world = kwargs.get("sufficiency_world", 2)

necessity_indices = IndexSet(
**{
name: {necessity_world}
for name in indices_of(consequent, event_dim=support.event_dim).keys()
if name in antecedents
}
)
sufficiency_indices = IndexSet(
**{
name: {sufficiency_world}
for name in indices_of(consequent, event_dim=support.event_dim).keys()
if name in antecedents
}
)

necessity_value = gather(
consequent, necessity_indices, event_dim=support.event_dim
)
sufficiency_value = gather(
consequent, sufficiency_indices, event_dim=support.event_dim
)

necessity_log_probs = (
soft_neq(
support,
necessity_value,
proposed_consequent,
**kwargs,
)
if proposed_consequent is not None
else soft_neq(
support,
necessity_value,
sufficiency_value,
**kwargs,
)
)
sufficiency_log_probs = (
soft_eq(support, sufficiency_value, proposed_consequent, **kwargs)
if proposed_consequent is not None
else torch.zeros_like(necessity_log_probs)
)

FACTUAL_NEC_SUFF = torch.zeros_like(sufficiency_log_probs)

index_keys = set(antecedents)
null_index = IndexSet(**{name: {0} for name in index_keys})

nec_suff_log_probs_partitioned = {
**{null_index: FACTUAL_NEC_SUFF},
**{
IndexSet(**{antecedent: {ind} for antecedent in index_keys}): log_prob
for ind, log_prob in zip(
[necessity_world, sufficiency_world],
[necessity_log_probs, sufficiency_log_probs],
)
},
}

new_value = scatter_n(
nec_suff_log_probs_partitioned,
event_dim=0,
)
return new_value

return _consequent_eq_neq


class ExtractSupports(pyro.poutine.messenger.Messenger):
Expand Down
Loading

0 comments on commit 165bc56

Please sign in to comment.