Skip to content

Commit

Permalink
Explainable undo split (#562)
Browse files Browse the repository at this point in the history
* tests for indepepdent and correct

* added print

* extra case

* debugged reverse

* debug consequen_eq_neq

* fixed test_consequent_eq_neq

* fixed the test with dimensions

* consequent_eq_neq

* three variable model

* testing three dependent

* debugging

* minimal example for three independent variables

* more three variable models

* diverge

* debugged

* notebook tested three variable models

* three variable test cases aded

* clean up

* test for factual log probs

* more clean up

* fixed a lint error

* lint clean

* reverted metadata

* ground truth for conditioning on deterministic node

* responsibility debug

* documentation commit

* responsibility example

* documentation completed

* small typos

* small changes

* improved readability

* fixed links in intro, made ac optional in description

* fixed link failure in TOC

* style changes in Moivation

* style changes in Setup

* small fixes in Causal query 1 description

* causal query 2 small fixes, one oustanding comment

* style changes in contex-sensitive...

* computation of P(f_{m, l}, f'_{m', l'} | m, l)

* double checked things

* clean_notebooks.sh run

* rough

* reveresed accidental commit

* modifed descriptions and arguments a bit

* trying out things

* clean up

* restored tutorial_i

* fix for undo_split

* tests for undo_split added

* cleanup

* lint

* lint and clean up

* lint typing error

---------

Co-authored-by: rfl-urbaniak <[email protected]>
  • Loading branch information
PoorvaGarg and rfl-urbaniak authored Aug 23, 2024
1 parent 66a8f39 commit 0553360
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 20 deletions.
36 changes: 23 additions & 13 deletions chirho/explainable/handlers/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from typing import Callable, Iterable, MutableMapping, Optional, TypeVar

import pyro
Expand Down Expand Up @@ -113,26 +112,37 @@ def undo_split(
"""

def _undo_split(value: T) -> T:
antecedents_ = [
a
for a in antecedents
if a in indices_of(value, event_dim=support.event_dim)
]
antecedents_ = {
a: v
for a, v in indices_of(value, event_dim=support.event_dim).items()
if a in antecedents
}

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
IndexSet(**{antecedent: {0} for antecedent in antecedents_.keys()}),
event_dim=support.event_dim,
)

# TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
index_keys: Iterable[MutableMapping[str, Iterable[int]]] = list()
for a, v in antecedents_.items():
if index_keys == []:
index_keys = [dict({a: {value}}.items()) for value in v]
else:
temp_index_keys = []
for i in index_keys:
temp_index_keys.extend(
[
dict(tuple(dict(i).items()) + tuple({a: {value}}.items()))
for value in v
]
)
index_keys = temp_index_keys
index_keys = index_keys if index_keys != [] else [{}]

return scatter_n(
{
IndexSet(
**{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
): factual_value
for inds in itertools.product(*[[0, 1]] * len(antecedents_))
},
{IndexSet(**ind_key): factual_value for ind_key in index_keys},
event_dim=support.event_dim,
)

Expand Down
41 changes: 34 additions & 7 deletions tests/explainable/test_handlers_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,16 @@ def test_random_intervention(support, event_shape):
assert torch.all(support.check(samples))


def test_undo_split():
@pytest.mark.parametrize("num_splits", [1, 2, 5])
def test_undo_split(num_splits):
with MultiWorldCounterfactual():
x_obs = torch.zeros(10)
x_cf_1 = torch.ones(10)
x_cf_2 = 2 * x_cf_1
x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1)
x_split = split(x_split, (x_cf_2,), name="split2", event_dim=1)
x_split = split(x_obs, (x_cf_1,) * num_splits, name="split1", event_dim=1)
x_split = split(
x_split, (x_cf_2,) * (num_splits + 1), name="split2", event_dim=1
)

undo_split2 = undo_split(
support=constraints.independent(constraints.real, 1), antecedents=["split2"]
Expand All @@ -100,9 +103,31 @@ def test_undo_split():
assert torch.all(gather(x_split, IndexSet(split2={0}), event_dim=1) == x_undone)


def test_undo_split_multi_dim():
with MultiWorldCounterfactual():
x_obs = torch.ones(10)
x_cf_1 = 2 * x_obs
x_cf_2 = 3 * x_cf_1
x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1)
x_split = split(x_split, (x_cf_2, x_cf_1, x_cf_2), name="split2", event_dim=1)
x_split = split(x_split, (x_cf_2, x_cf_1), name="split3", event_dim=1)

undo_split23 = undo_split(
support=constraints.independent(constraints.real, 1),
antecedents=["split2", "split3"],
)
x_undone = undo_split23(x_split)

assert indices_of(x_split, event_dim=1) == indices_of(x_undone, event_dim=1)
assert torch.all(
gather(x_split, IndexSet(split2={0}, split3={0}), event_dim=1) == x_undone
)


@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
def test_undo_split_parametrized(event_shape, plate_size):
@pytest.mark.parametrize("num_splits", [1, 2, 5])
def test_undo_split_parametrized(event_shape, plate_size, num_splits):
joint_dims = torch.Size([plate_size, *event_shape])

replace1 = torch.ones(joint_dims)
Expand All @@ -114,7 +139,9 @@ def model():
w = pyro.sample(
"w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape))
)
w = split(w, (replace1,), name="split1", event_dim=len(event_shape))
w = split(
w, (replace1,) * num_splits, name="split1", event_dim=len(event_shape)
)

w = pyro.deterministic(
"w_preempted",
Expand Down Expand Up @@ -146,11 +173,11 @@ def model():
with mwc:
assert indices_of(
nd["w_undone"]["value"], event_dim=len(event_shape)
) == IndexSet(split1={0, 1})
) == IndexSet(split1=set(range(num_splits + 1)))

w_undone_shape = list(nd["w_undone"]["value"].shape)
desired_shape = list(
(2,)
(num_splits + 1,)
+ (1,) * (len(w_undone_shape) - len(event_shape) - 2)
+ (plate_size,)
+ event_shape
Expand Down

0 comments on commit 0553360

Please sign in to comment.