From 0553360c0df9f8ed305d672fa8e4488f7709604e Mon Sep 17 00:00:00 2001
From: PoorvaGarg <40417664+PoorvaGarg@users.noreply.github.com>
Date: Fri, 23 Aug 2024 15:15:23 -0400
Subject: [PATCH] Explainable undo split (#562)

* tests for indepepdent and correct

* added print

* extra case

* debugged reverse

* debug consequen_eq_neq

* fixed test_consequent_eq_neq

* fixed the test with dimensions

* consequent_eq_neq

* three variable model

* testing three dependent

* debugging

* minimal example for three independent variables

* more three variable models

* diverge

* debugged

* notebook tested three variable models

* three variable test cases aded

* clean up

* test for factual log probs

* more clean up

* fixed a lint error

* lint clean

* reverted metadata

* ground truth for conditioning on deterministic node

* responsibility debug

* documentation commit

* responsibility example

* documentation completed

* small typos

* small changes

* improved readability

* fixed links in intro, made ac optional in description

* fixed link failure in TOC

* style changes in Moivation

* style changes in Setup

* small fixes in Causal query 1 description

* causal query 2 small fixes, one oustanding comment

* style changes in contex-sensitive...

* computation of P(f_{m, l}, f'_{m', l'} | m, l)

* double checked things

* clean_notebooks.sh run

* rough

* reveresed accidental commit

* modifed descriptions and arguments a bit

* trying out things

* clean up

* restored tutorial_i

* fix for undo_split

* tests for undo_split added

* cleanup

* lint

* lint and clean up

* lint typing error

---------

Co-authored-by: rfl-urbaniak <rfl.urbaniak@gmail.com>
---
 chirho/explainable/handlers/components.py     | 36 ++++++++++------
 tests/explainable/test_handlers_components.py | 41 +++++++++++++++----
 2 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/chirho/explainable/handlers/components.py b/chirho/explainable/handlers/components.py
index 4db8fb42..22042dbc 100644
--- a/chirho/explainable/handlers/components.py
+++ b/chirho/explainable/handlers/components.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import Callable, Iterable, MutableMapping, Optional, TypeVar
 
 import pyro
@@ -113,26 +112,37 @@ def undo_split(
     """
 
     def _undo_split(value: T) -> T:
-        antecedents_ = [
-            a
-            for a in antecedents
-            if a in indices_of(value, event_dim=support.event_dim)
-        ]
+        antecedents_ = {
+            a: v
+            for a, v in indices_of(value, event_dim=support.event_dim).items()
+            if a in antecedents
+        }
 
         factual_value = gather(
             value,
-            IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
+            IndexSet(**{antecedent: {0} for antecedent in antecedents_.keys()}),
             event_dim=support.event_dim,
         )
 
         # TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
+        index_keys: Iterable[MutableMapping[str, Iterable[int]]] = list()
+        for a, v in antecedents_.items():
+            if index_keys == []:
+                index_keys = [dict({a: {value}}.items()) for value in v]
+            else:
+                temp_index_keys = []
+                for i in index_keys:
+                    temp_index_keys.extend(
+                        [
+                            dict(tuple(dict(i).items()) + tuple({a: {value}}.items()))
+                            for value in v
+                        ]
+                    )
+                index_keys = temp_index_keys
+        index_keys = index_keys if index_keys != [] else [{}]
+
         return scatter_n(
-            {
-                IndexSet(
-                    **{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
-                ): factual_value
-                for inds in itertools.product(*[[0, 1]] * len(antecedents_))
-            },
+            {IndexSet(**ind_key): factual_value for ind_key in index_keys},
             event_dim=support.event_dim,
         )
 
diff --git a/tests/explainable/test_handlers_components.py b/tests/explainable/test_handlers_components.py
index cc3d4ead..d3839d9a 100644
--- a/tests/explainable/test_handlers_components.py
+++ b/tests/explainable/test_handlers_components.py
@@ -83,13 +83,16 @@ def test_random_intervention(support, event_shape):
     assert torch.all(support.check(samples))
 
 
-def test_undo_split():
+@pytest.mark.parametrize("num_splits", [1, 2, 5])
+def test_undo_split(num_splits):
     with MultiWorldCounterfactual():
         x_obs = torch.zeros(10)
         x_cf_1 = torch.ones(10)
         x_cf_2 = 2 * x_cf_1
-        x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1)
-        x_split = split(x_split, (x_cf_2,), name="split2", event_dim=1)
+        x_split = split(x_obs, (x_cf_1,) * num_splits, name="split1", event_dim=1)
+        x_split = split(
+            x_split, (x_cf_2,) * (num_splits + 1), name="split2", event_dim=1
+        )
 
         undo_split2 = undo_split(
             support=constraints.independent(constraints.real, 1), antecedents=["split2"]
@@ -100,9 +103,31 @@ def test_undo_split():
         assert torch.all(gather(x_split, IndexSet(split2={0}), event_dim=1) == x_undone)
 
 
+def test_undo_split_multi_dim():
+    with MultiWorldCounterfactual():
+        x_obs = torch.ones(10)
+        x_cf_1 = 2 * x_obs
+        x_cf_2 = 3 * x_cf_1
+        x_split = split(x_obs, (x_cf_1,), name="split1", event_dim=1)
+        x_split = split(x_split, (x_cf_2, x_cf_1, x_cf_2), name="split2", event_dim=1)
+        x_split = split(x_split, (x_cf_2, x_cf_1), name="split3", event_dim=1)
+
+        undo_split23 = undo_split(
+            support=constraints.independent(constraints.real, 1),
+            antecedents=["split2", "split3"],
+        )
+        x_undone = undo_split23(x_split)
+
+        assert indices_of(x_split, event_dim=1) == indices_of(x_undone, event_dim=1)
+        assert torch.all(
+            gather(x_split, IndexSet(split2={0}, split3={0}), event_dim=1) == x_undone
+        )
+
+
 @pytest.mark.parametrize("plate_size", [4, 50, 200])
 @pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)])
-def test_undo_split_parametrized(event_shape, plate_size):
+@pytest.mark.parametrize("num_splits", [1, 2, 5])
+def test_undo_split_parametrized(event_shape, plate_size, num_splits):
     joint_dims = torch.Size([plate_size, *event_shape])
 
     replace1 = torch.ones(joint_dims)
@@ -114,7 +139,9 @@ def model():
         w = pyro.sample(
             "w", dist.Normal(0, 1).expand(event_shape).to_event(len(event_shape))
         )
-        w = split(w, (replace1,), name="split1", event_dim=len(event_shape))
+        w = split(
+            w, (replace1,) * num_splits, name="split1", event_dim=len(event_shape)
+        )
 
         w = pyro.deterministic(
             "w_preempted",
@@ -146,11 +173,11 @@ def model():
     with mwc:
         assert indices_of(
             nd["w_undone"]["value"], event_dim=len(event_shape)
-        ) == IndexSet(split1={0, 1})
+        ) == IndexSet(split1=set(range(num_splits + 1)))
 
         w_undone_shape = list(nd["w_undone"]["value"].shape)
         desired_shape = list(
-            (2,)
+            (num_splits + 1,)
             + (1,) * (len(w_undone_shape) - len(event_shape) - 2)
             + (plate_size,)
             + event_shape