Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainable undo split #562

Merged
merged 57 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
080e39d
tests for indepepdent and correct
PoorvaGarg Aug 1, 2024
7dd59ad
added print
rfl-urbaniak Aug 1, 2024
03516b7
extra case
rfl-urbaniak Aug 1, 2024
a7b3a8a
debugged reverse
rfl-urbaniak Aug 2, 2024
6ee8651
debug consequen_eq_neq
PoorvaGarg Aug 2, 2024
513ec6e
fixed test_consequent_eq_neq
rfl-urbaniak Aug 2, 2024
ce96b9f
fixed the test with dimensions
rfl-urbaniak Aug 2, 2024
1f8e72a
consequent_eq_neq
PoorvaGarg Aug 2, 2024
bc5ffb6
three variable model
PoorvaGarg Aug 2, 2024
5aabfc6
testing three dependent
rfl-urbaniak Aug 5, 2024
776208f
debugging
rfl-urbaniak Aug 5, 2024
c0a22c0
minimal example for three independent variables
PoorvaGarg Aug 5, 2024
1911de2
more three variable models
PoorvaGarg Aug 5, 2024
84381e2
diverge
PoorvaGarg Aug 5, 2024
bbc121c
debugged
rfl-urbaniak Aug 5, 2024
498f070
notebook tested three variable models
PoorvaGarg Aug 6, 2024
1e65c7d
three variable test cases aded
PoorvaGarg Aug 6, 2024
4058660
clean up
PoorvaGarg Aug 6, 2024
34d0faf
test for factual log probs
PoorvaGarg Aug 6, 2024
9326a52
more clean up
PoorvaGarg Aug 6, 2024
45e75d6
fixed a lint error
PoorvaGarg Aug 6, 2024
075c33a
lint clean
PoorvaGarg Aug 6, 2024
dde4d36
reverted metadata
PoorvaGarg Aug 7, 2024
75e9f05
ground truth for conditioning on deterministic node
PoorvaGarg Aug 8, 2024
7e2501a
responsibility debug
PoorvaGarg Aug 8, 2024
2a38798
documentation commit
PoorvaGarg Aug 9, 2024
9254529
responsibility example
PoorvaGarg Aug 12, 2024
61aa26b
documentation completed
PoorvaGarg Aug 12, 2024
58842f6
small typos
rfl-urbaniak Aug 12, 2024
9ee3068
small changes
PoorvaGarg Aug 13, 2024
b5148e2
Merge branch 'documentation_explainability' of https://github.com/Bas…
PoorvaGarg Aug 13, 2024
55f1782
improved readability
PoorvaGarg Aug 13, 2024
8676cd7
fixed links in intro, made ac optional in description
rfl-urbaniak Aug 14, 2024
773d085
fixed link failure in TOC
rfl-urbaniak Aug 14, 2024
a0158f4
style changes in Moivation
rfl-urbaniak Aug 14, 2024
c4d8ed4
style changes in Setup
rfl-urbaniak Aug 14, 2024
ff37278
small fixes in Causal query 1 description
rfl-urbaniak Aug 14, 2024
1269aa1
causal query 2 small fixes, one oustanding comment
rfl-urbaniak Aug 14, 2024
411a501
style changes in contex-sensitive...
rfl-urbaniak Aug 14, 2024
61bb476
computation of P(f_{m, l}, f'_{m', l'} | m, l)
PoorvaGarg Aug 14, 2024
228eebc
double checked things
PoorvaGarg Aug 14, 2024
fc199cc
clean_notebooks.sh run
PoorvaGarg Aug 15, 2024
3034bd6
rough
PoorvaGarg Aug 15, 2024
75ddfaf
reveresed accidental commit
PoorvaGarg Aug 15, 2024
2a4b3a9
modifed descriptions and arguments a bit
rfl-urbaniak Aug 15, 2024
78c9aae
trying out things
PoorvaGarg Aug 16, 2024
225077c
Merge branch 'master' into documentation_explainability
PoorvaGarg Aug 16, 2024
de604b0
clean up
PoorvaGarg Aug 16, 2024
e40e6c1
restored tutorial_i
PoorvaGarg Aug 16, 2024
1c36d31
fix for undo_split
PoorvaGarg Aug 19, 2024
a5f20aa
tests for undo_split added
PoorvaGarg Aug 19, 2024
3e36b49
Merge branch 'documentation_explainability' into explainable-continuo…
PoorvaGarg Aug 19, 2024
97b5ffa
cleanup
PoorvaGarg Aug 19, 2024
7e1344c
lint
PoorvaGarg Aug 19, 2024
05dcf16
lint and clean up
PoorvaGarg Aug 19, 2024
c0f05e0
lint typing error
PoorvaGarg Aug 19, 2024
b5e6926
Merge branch 'master' into explainable-undo_split
PoorvaGarg Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions chirho/explainable/handlers/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from typing import Callable, Iterable, MutableMapping, Optional, TypeVar

import pyro
Expand Down Expand Up @@ -113,26 +112,37 @@ def undo_split(
"""

def _undo_split(value: T) -> T:
antecedents_ = [
a
for a in antecedents
if a in indices_of(value, event_dim=support.event_dim)
]
antecedents_ = {
a: v
for a, v in indices_of(value, event_dim=support.event_dim).items()
if a in antecedents
}

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
IndexSet(**{antecedent: {0} for antecedent in antecedents_.keys()}),
event_dim=support.event_dim,
)

# TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
index_keys: Iterable[MutableMapping[str, Iterable[int]]] = list()
for a, v in antecedents_.items():
if index_keys == []:
index_keys = [dict({a: {value}}.items()) for value in v]
else:
temp_index_keys = []
for i in index_keys:
temp_index_keys.extend(
[
dict(tuple(dict(i).items()) + tuple({a: {value}}.items()))
for value in v
]
)
index_keys = temp_index_keys
index_keys = index_keys if index_keys != [] else [{}]

return scatter_n(
{
IndexSet(
**{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
): factual_value
for inds in itertools.product(*[[0, 1]] * len(antecedents_))
},
{IndexSet(**ind_key): factual_value for ind_key in index_keys},
event_dim=support.event_dim,
)

Expand Down
41 changes: 34 additions & 7 deletions tests/explainable/test_handlers_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,16 @@ def test_random_intervention(support, event_shape):
assert torch.all(support.check(samples))


def test_undo_split():
@pytest.mark.parametrize("num_splits", [1, 2, 5])
def test_undo_split(num_splits):
with MultiWorldCounterfactual():
x_obs = torch.zeros(10)
x_cf_1 = torch.ones(10)
x_cf_2 = 2 * x_cf_1
x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1)
x_split = split(x_split, (x_cf_2,), name="split2", event_dim=1)
x_split = split(x_obs, (x_cf_1,) * num_splits, name="split1", event_dim=1)
x_split = split(
x_split, (x_cf_2,) * (num_splits + 1), name="split2", event_dim=1
)

undo_split2 = undo_split(
support=constraints.independent(constraints.real, 1), antecedents=["split2"]
Expand All @@ -100,9 +103,31 @@ def test_undo_split():
assert torch.all(gather(x_split, IndexSet(split2={0}), event_dim=1) == x_undone)


def test_undo_split_multi_dim():
with MultiWorldCounterfactual():
x_obs = torch.ones(10)
x_cf_1 = 2 * x_obs
x_cf_2 = 3 * x_cf_1
x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1)
x_split = split(x_split, (x_cf_2, x_cf_1, x_cf_2), name="split2", event_dim=1)
x_split = split(x_split, (x_cf_2, x_cf_1), name="split3", event_dim=1)

undo_split23 = undo_split(
support=constraints.independent(constraints.real, 1),
antecedents=["split2", "split3"],
)
x_undone = undo_split23(x_split)

assert indices_of(x_split, event_dim=1) == indices_of(x_undone, event_dim=1)
assert torch.all(
gather(x_split, IndexSet(split2={0}, split3={0}), event_dim=1) == x_undone
)


@pytest.mark.parametrize("plate_size", [4, 50, 200])
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
def test_undo_split_parametrized(event_shape, plate_size):
@pytest.mark.parametrize("num_splits", [1, 2, 5])
def test_undo_split_parametrized(event_shape, plate_size, num_splits):
joint_dims = torch.Size([plate_size, *event_shape])

replace1 = torch.ones(joint_dims)
Expand All @@ -114,7 +139,9 @@ def model():
w = pyro.sample(
"w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape))
)
w = split(w, (replace1,), name="split1", event_dim=len(event_shape))
w = split(
w, (replace1,) * num_splits, name="split1", event_dim=len(event_shape)
)

w = pyro.deterministic(
"w_preempted",
Expand Down Expand Up @@ -146,11 +173,11 @@ def model():
with mwc:
assert indices_of(
nd["w_undone"]["value"], event_dim=len(event_shape)
) == IndexSet(split1={0, 1})
) == IndexSet(split1=set(range(num_splits + 1)))

w_undone_shape = list(nd["w_undone"]["value"].shape)
desired_shape = list(
(2,)
(num_splits + 1,)
+ (1,) * (len(w_undone_shape) - len(event_shape) - 2)
+ (plate_size,)
+ event_shape
Expand Down
Loading