Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding undo_split and a test thereof (+small lint) #264

Merged
merged 23 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions chirho/counterfactual/handlers/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import pyro
import torch

from chirho.counterfactual.handlers.ambiguity import FactualConditioningMessenger
from chirho.counterfactual.handlers.ambiguity import (
FactualConditioningMessenger,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.indexed.ops import get_index_plates
Expand Down Expand Up @@ -67,7 +69,9 @@ def _pyro_split(msg: Dict[str, Any]) -> None:
msg["stop"] = True


class MultiWorldCounterfactual(IndexPlatesMessenger, BaseCounterfactualMessenger):
class MultiWorldCounterfactual(
IndexPlatesMessenger, BaseCounterfactualMessenger
):
default_name: str = "intervened"

@classmethod
Expand All @@ -79,7 +83,9 @@ def _pyro_split(cls, msg: Dict[str, Any]) -> None:
msg["kwargs"]["name"] = msg["name"] = name


class TwinWorldCounterfactual(IndexPlatesMessenger, BaseCounterfactualMessenger):
class TwinWorldCounterfactual(
IndexPlatesMessenger, BaseCounterfactualMessenger
):
default_name: str = "intervened"

@classmethod
Expand Down Expand Up @@ -112,7 +118,10 @@ class Preemptions(Generic[T], pyro.poutine.messenger.Messenger):
prefix: str

def __init__(
self, actions: Mapping[str, Intervention[T]], *, prefix: str = "__split_"
self,
actions: Mapping[str, Intervention[T]],
*,
prefix: str = "__split_",
):
self.actions = actions
self.prefix = prefix
Expand Down Expand Up @@ -191,7 +200,8 @@ def _pyro_post_sample(self, msg):
action = (action,) if not isinstance(action, tuple) else action
num_actions = len(action) if isinstance(action, tuple) else 1
weights = torch.tensor(
[0.5 - self.bias] + ([(0.5 + self.bias) / num_actions] * num_actions),
[0.5 - self.bias]
+ ([(0.5 + self.bias) / num_actions] * num_actions),
device=msg["value"].device,
)
case_dist = pyro.distributions.Categorical(probs=weights)
Expand Down
59 changes: 56 additions & 3 deletions chirho/counterfactual/ops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional, Tuple, TypeVar
from typing import Optional, Tuple, TypeVar, Iterable, Callable


import pyro

from chirho.indexed.ops import IndexSet, cond, scatter
from chirho.indexed.ops import IndexSet, cond, scatter, gather, indices_of
from chirho.interventional.ops import Intervention, intervene


S = TypeVar("S")
T = TypeVar("T")

Expand All @@ -23,10 +25,61 @@ def split(obs: T, acts: Tuple[Intervention[T], ...], **kwargs) -> T:
return scatter(act_values, event_dim=kwargs.get("event_dim", 0))


@pyro.poutine.runtime.effectful(type="undo_split")
rfl-urbaniak marked this conversation as resolved.
Show resolved Hide resolved
def undo_split(
rfl-urbaniak marked this conversation as resolved.
Show resolved Hide resolved
antecedents: Iterable[str] = None, event_dim: int = 0
) -> Callable[[T], T]:
"""
A helper function that undoes an upstream `chirho.counterfactual.ops.split`
rfl-urbaniak marked this conversation as resolved.
Show resolved Hide resolved
operation by gathering the factual value and scattering it back into
two alternative cases.

:param antecedents: A list of upstream intervened sites which induced
the `split` to be reversed.
:param event_dim: The event dimension.
rfl-urbaniak marked this conversation as resolved.
Show resolved Hide resolved

:return: A callable that applied to a site value object returns
a site value object in which the factual value has been
scattered back into two alternative cases.
"""
if antecedents is None:
antecedents = []

def _undo_split(value: T) -> T:
antecedents_ = [
a
for a in antecedents
if a in indices_of(value, event_dim=event_dim)
]

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
event_dim=event_dim,
)

return scatter(
{
IndexSet(
**{antecedent: {0} for antecedent in antecedents_}
): factual_value,
IndexSet(
**{antecedent: {1} for antecedent in antecedents_}
): factual_value,
},
event_dim=event_dim,
)

return _undo_split


@pyro.poutine.runtime.effectful(type="preempt")
@pyro.poutine.block(hide_types=["intervene"])
def preempt(
obs: T, acts: Tuple[Intervention[T], ...], case: Optional[S] = None, **kwargs
obs: T,
acts: Tuple[Intervention[T], ...],
case: Optional[S] = None,
**kwargs
) -> T:
"""
Effectful primitive operation for "preempting" values in a probabilistic program.
Expand Down
Loading