Skip to content

Commit

Permalink
Merge branch 'master' into az_hook_solver_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
azane committed Aug 26, 2024
2 parents 87a3a89 + 0553360 commit 20de146
Show file tree
Hide file tree
Showing 5 changed files with 903 additions and 283 deletions.
64 changes: 28 additions & 36 deletions chirho/explainable/handlers/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from typing import Callable, Iterable, MutableMapping, Optional, TypeVar

import pyro
Expand Down Expand Up @@ -113,26 +112,37 @@ def undo_split(
"""

def _undo_split(value: T) -> T:
antecedents_ = [
a
for a in antecedents
if a in indices_of(value, event_dim=support.event_dim)
]
antecedents_ = {
a: v
for a, v in indices_of(value, event_dim=support.event_dim).items()
if a in antecedents
}

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
IndexSet(**{antecedent: {0} for antecedent in antecedents_.keys()}),
event_dim=support.event_dim,
)

# TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
index_keys: Iterable[MutableMapping[str, Iterable[int]]] = list()
for a, v in antecedents_.items():
if index_keys == []:
index_keys = [dict({a: {value}}.items()) for value in v]
else:
temp_index_keys = []
for i in index_keys:
temp_index_keys.extend(
[
dict(tuple(dict(i).items()) + tuple({a: {value}}.items()))
for value in v
]
)
index_keys = temp_index_keys
index_keys = index_keys if index_keys != [] else [{}]

return scatter_n(
{
IndexSet(
**{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
): factual_value
for inds in itertools.product(*[[0, 1]] * len(antecedents_))
},
{IndexSet(**ind_key): factual_value for ind_key in index_keys},
event_dim=support.event_dim,
)

Expand Down Expand Up @@ -230,15 +240,6 @@ def consequent_eq_neq(
"""

def _consequent_eq_neq(consequent: T) -> torch.Tensor:

factual_indices = IndexSet(
**{
name: ind
for name, ind in get_factual_indices().items()
if name in antecedents
}
)

necessity_world = kwargs.get("necessity_world", 1)
sufficiency_world = kwargs.get("sufficiency_world", 2)

Expand All @@ -249,7 +250,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
if name in antecedents
}
)

sufficiency_indices = IndexSet(
**{
name: {sufficiency_world}
Expand All @@ -265,9 +265,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
consequent, sufficiency_indices, event_dim=support.event_dim
)

# compare to proposed consequent if provided
# as then the sufficiency value can be different
# due to witness preemption
necessity_log_probs = (
soft_neq(
support,
Expand All @@ -283,7 +280,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
**kwargs,
)
)

sufficiency_log_probs = (
soft_eq(support, sufficiency_value, proposed_consequent, **kwargs)
if proposed_consequent is not None
Expand All @@ -292,16 +288,13 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:

FACTUAL_NEC_SUFF = torch.zeros_like(sufficiency_log_probs)

index_keys = set(antecedents)
null_index = IndexSet(**{name: {0} for name in index_keys})

nec_suff_log_probs_partitioned = {
**{null_index: FACTUAL_NEC_SUFF},
**{
factual_indices: FACTUAL_NEC_SUFF,
},
**{
IndexSet(**{antecedent: {ind}}): log_prob
for antecedent in (
set(antecedents)
& set(indices_of(consequent, event_dim=support.event_dim))
)
IndexSet(**{antecedent: {ind} for antecedent in index_keys}): log_prob
for ind, log_prob in zip(
[necessity_world, sufficiency_world],
[necessity_log_probs, sufficiency_log_probs],
Expand All @@ -313,7 +306,6 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
nec_suff_log_probs_partitioned,
event_dim=0,
)

return new_value

return _consequent_eq_neq
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# -- Project information -----------------------------------------------------

project = 'chirho'
copyright = '2023, Basis'
copyright = '2024, Basis'
author = 'Basis'


Expand Down
Loading

0 comments on commit 20de146

Please sign in to comment.