From 0553360c0df9f8ed305d672fa8e4488f7709604e Mon Sep 17 00:00:00 2001 From: PoorvaGarg <40417664+PoorvaGarg@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:15:23 -0400 Subject: [PATCH] Explainable undo split (#562) * tests for indepepdent and correct * added print * extra case * debugged reverse * debug consequen_eq_neq * fixed test_consequent_eq_neq * fixed the test with dimensions * consequent_eq_neq * three variable model * testing three dependent * debugging * minimal example for three independent variables * more three variable models * diverge * debugged * notebook tested three variable models * three variable test cases aded * clean up * test for factual log probs * more clean up * fixed a lint error * lint clean * reverted metadata * ground truth for conditioning on deterministic node * responsibility debug * documentation commit * responsibility example * documentation completed * small typos * small changes * improved readability * fixed links in intro, made ac optional in description * fixed link failure in TOC * style changes in Moivation * style changes in Setup * small fixes in Causal query 1 description * causal query 2 small fixes, one oustanding comment * style changes in contex-sensitive... * computation of P(f_{m, l}, f'_{m', l'} | m, l) * double checked things * clean_notebooks.sh run * rough * reveresed accidental commit * modifed descriptions and arguments a bit * trying out things * clean up * restored tutorial_i * fix for undo_split * tests for undo_split added * cleanup * lint * lint and clean up * lint typing error --------- Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com> --- chirho/explainable/handlers/components.py | 36 ++++++++++------ tests/explainable/test_handlers_components.py | 41 +++++++++++++++---- 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/chirho/explainable/handlers/components.py b/chirho/explainable/handlers/components.py index 4db8fb42..22042dbc 100644 --- a/chirho/explainable/handlers/components.py +++ b/chirho/explainable/handlers/components.py @@ -1,4 +1,3 @@ -import itertools from typing import Callable, Iterable, MutableMapping, Optional, TypeVar import pyro @@ -113,26 +112,37 @@ def undo_split( """ def _undo_split(value: T) -> T: - antecedents_ = [ - a - for a in antecedents - if a in indices_of(value, event_dim=support.event_dim) - ] + antecedents_ = { + a: v + for a, v in indices_of(value, event_dim=support.event_dim).items() + if a in antecedents + } factual_value = gather( value, - IndexSet(**{antecedent: {0} for antecedent in antecedents_}), + IndexSet(**{antecedent: {0} for antecedent in antecedents_.keys()}), event_dim=support.event_dim, ) # TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply + index_keys: Iterable[MutableMapping[str, Iterable[int]]] = list() + for a, v in antecedents_.items(): + if index_keys == []: + index_keys = [dict({a: {value}}.items()) for value in v] + else: + temp_index_keys = [] + for i in index_keys: + temp_index_keys.extend( + [ + dict(tuple(dict(i).items()) + tuple({a: {value}}.items())) + for value in v + ] + ) + index_keys = temp_index_keys + index_keys = index_keys if index_keys != [] else [{}] + return scatter_n( - { - IndexSet( - **{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)} - ): factual_value - for inds in itertools.product(*[[0, 1]] * len(antecedents_)) - }, + {IndexSet(**ind_key): factual_value for ind_key in index_keys}, event_dim=support.event_dim, ) diff --git a/tests/explainable/test_handlers_components.py b/tests/explainable/test_handlers_components.py index cc3d4ead..d3839d9a 100644 --- a/tests/explainable/test_handlers_components.py +++ b/tests/explainable/test_handlers_components.py @@ -83,13 +83,16 @@ def test_random_intervention(support, event_shape): assert torch.all(support.check(samples)) -def test_undo_split(): +@pytest.mark.parametrize("num_splits", [1, 2, 5]) +def test_undo_split(num_splits): with MultiWorldCounterfactual(): x_obs = torch.zeros(10) x_cf_1 = torch.ones(10) x_cf_2 = 2 * x_cf_1 - x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1) - x_split = split(x_split, (x_cf_2,), name="split2", event_dim=1) + x_split = split(x_obs, (x_cf_1,) * num_splits, name="split1", event_dim=1) + x_split = split( + x_split, (x_cf_2,) * (num_splits + 1), name="split2", event_dim=1 + ) undo_split2 = undo_split( support=constraints.independent(constraints.real, 1), antecedents=["split2"] @@ -100,9 +103,31 @@ def test_undo_split(): assert torch.all(gather(x_split, IndexSet(split2={0}), event_dim=1) == x_undone) +def test_undo_split_multi_dim(): + with MultiWorldCounterfactual(): + x_obs = torch.ones(10) + x_cf_1 = 2 * x_obs + x_cf_2 = 3 * x_cf_1 + x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1) + x_split = split(x_split, (x_cf_2, x_cf_1, x_cf_2), name="split2", event_dim=1) + x_split = split(x_split, (x_cf_2, x_cf_1), name="split3", event_dim=1) + + undo_split23 = undo_split( + support=constraints.independent(constraints.real, 1), + antecedents=["split2", "split3"], + ) + x_undone = undo_split23(x_split) + + assert indices_of(x_split, event_dim=1) == indices_of(x_undone, event_dim=1) + assert torch.all( + gather(x_split, IndexSet(split2={0}, split3={0}), event_dim=1) == x_undone + ) + + @pytest.mark.parametrize("plate_size", [4, 50, 200]) @pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)]) -def test_undo_split_parametrized(event_shape, plate_size): +@pytest.mark.parametrize("num_splits", [1, 2, 5]) +def test_undo_split_parametrized(event_shape, plate_size, num_splits): joint_dims = torch.Size([plate_size, *event_shape]) replace1 = torch.ones(joint_dims) @@ -114,7 +139,9 @@ def model(): w = pyro.sample( "w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape)) ) - w = split(w, (replace1,), name="split1", event_dim=len(event_shape)) + w = split( + w, (replace1,) * num_splits, name="split1", event_dim=len(event_shape) + ) w = pyro.deterministic( "w_preempted", @@ -146,11 +173,11 @@ def model(): with mwc: assert indices_of( nd["w_undone"]["value"], event_dim=len(event_shape) - ) == IndexSet(split1={0, 1}) + ) == IndexSet(split1=set(range(num_splits + 1))) w_undone_shape = list(nd["w_undone"]["value"].shape) desired_shape = list( - (2,) + (num_splits + 1,) + (1,) * (len(w_undone_shape) - len(event_shape) - 2) + (plate_size,) + event_shape