Commit 3ea9ce4

Merge branch 'master' of https://github.com/BasisResearch/chirho into sw-dynamic-multi-tutorial

rfl-urbaniak committed Nov 25, 2024
2 parents ca098db + f015a77 commit 3ea9ce4
Showing 13 changed files with 914 additions and 398 deletions.
36 changes: 23 additions & 13 deletions chirho/explainable/handlers/components.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import Callable, Iterable, MutableMapping, Optional, TypeVar

 import pyro
@@ -113,26 +112,37 @@ def undo_split(
"""

def _undo_split(value: T) -> T:
antecedents_ = [
a
for a in antecedents
if a in indices_of(value, event_dim=support.event_dim)
]
antecedents_ = {
a: v
for a, v in indices_of(value, event_dim=support.event_dim).items()
if a in antecedents
}

factual_value = gather(
value,
IndexSet(**{antecedent: {0} for antecedent in antecedents_}),
IndexSet(**{antecedent: {0} for antecedent in antecedents_.keys()}),
event_dim=support.event_dim,
)

# TODO exponential in len(antecedents) - add an indexed.ops.expand to do this cheaply
index_keys: Iterable[MutableMapping[str, Iterable[int]]] = list()
for a, v in antecedents_.items():
if index_keys == []:
index_keys = [dict({a: {value}}.items()) for value in v]
else:
temp_index_keys = []
for i in index_keys:
temp_index_keys.extend(
[
dict(tuple(dict(i).items()) + tuple({a: {value}}.items()))
for value in v
]
)
index_keys = temp_index_keys
index_keys = index_keys if index_keys != [] else [{}]

return scatter_n(
{
IndexSet(
**{antecedent: {ind} for antecedent, ind in zip(antecedents_, inds)}
): factual_value
for inds in itertools.product(*[[0, 1]] * len(antecedents_))
},
{IndexSet(**ind_key): factual_value for ind_key in index_keys},
event_dim=support.event_dim,
)
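Note: the new index_keys loop builds the cross product of the antecedents' actual index sets incrementally, where the old code enumerated a fixed {0, 1} product with itertools.product. A minimal standalone sketch of the enumeration (hypothetical antecedents_, simplified to plain dicts rather than the dict(...items()) form used above):

antecedents_ = {"x": {0, 1}, "y": {0, 1}}

index_keys = []
for a, v in antecedents_.items():
    if index_keys == []:
        index_keys = [{a: {value}} for value in v]
    else:
        index_keys = [{**i, a: {value}} for i in index_keys for value in v]
index_keys = index_keys if index_keys != [] else [{}]

print(index_keys)
# four keyword dicts covering {0, 1} x {0, 1}, e.g.
# [{'x': {0}, 'y': {0}}, {'x': {0}, 'y': {1}}, {'x': {1}, 'y': {0}}, {'x': {1}, 'y': {1}}]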

4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -55,7 +55,8 @@
"myst_parser",
# "sphinx_gallery.gen_gallery",
# "sphinx_search.extension",
"sphinxcontrib.bibtex"
"sphinxcontrib.bibtex",
"sphinxcontrib.jquery",
]

# Point sphinxcontrib.bibtex to the bibtex file.
@@ -102,7 +103,6 @@
 # logo
 html_logo = "_static/img/chirho_logo_wide.png"

-
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
722 changes: 534 additions & 188 deletions docs/source/explainable_categorical.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions setup.py
@@ -57,6 +57,7 @@
"isort",
"sphinx==7.1.2",
"sphinxcontrib-bibtex",
"sphinxcontrib-jquery",
"sphinx_rtd_theme==1.3.0",
"myst_parser",
"nbsphinx",
89 changes: 67 additions & 22 deletions tests/dynamical/dynamical_fixtures.py
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import Mapping, TypeVar

 import pyro
 import torch
@@ -10,30 +10,30 @@

 T = TypeVar("T")

+ATempParams = Mapping[str, T]
+
-class UnifiedFixtureDynamics(pyro.nn.PyroModule):
-    def __init__(self, beta=None, gamma=None):
-        super().__init__()
-
-        self.beta = beta
-        if self.beta is None:
-            self.beta = pyro.param("beta", torch.tensor(0.5), constraints.positive)
+# SIR dynamics written as a pure function of state and parameters.
+def pure_sir_dynamics(
+    state: State[torch.Tensor], atemp_params: ATempParams[torch.Tensor]
+) -> State[torch.Tensor]:
+    beta = atemp_params["beta"]
+    gamma = atemp_params["gamma"]
-
-        self.gamma = gamma
-        if self.gamma is None:
-            self.gamma = pyro.param("gamma", torch.tensor(0.7), constraints.positive)
+
+    dX: State[torch.Tensor] = dict()
-
-    def forward(self, X: State[torch.Tensor]):
-        dX: State[torch.Tensor] = dict()
-        beta = self.beta * (
-            1.0 + 0.1 * torch.sin(0.1 * X["t"])
-        )  # beta oscilates slowly in time.
+
+    beta = beta * (
+        1.0 + 0.1 * torch.sin(0.1 * state["t"])
+    )  # beta oscilates slowly in time.
+
+    dX["S"] = -beta * state["S"] * state["I"]  # noqa
+    dX["I"] = beta * state["S"] * state["I"] - gamma * state["I"]  # noqa
+    dX["R"] = gamma * state["I"]  # noqa
+
-        dX["S"] = -beta * X["S"] * X["I"]
-        dX["I"] = beta * X["S"] * X["I"] - self.gamma * X["I"]  # noqa
-        dX["R"] = self.gamma * X["I"]
-        return dX
+    return dX


+class SIRObservationMixin:
     def _unit_measurement_error(self, name: str, x: torch.Tensor):
         if x.ndim == 0:
             return pyro.sample(name, Normal(x, 1))
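Note: pure_sir_dynamics factors the SIR vector field out of the PyroModule so it can be called with explicit parameters. A minimal usage sketch, assuming State is an ordinary Mapping[str, torch.Tensor] as the fixture treats it:

import torch

state = {
    "S": torch.tensor(0.99),
    "I": torch.tensor(0.01),
    "R": torch.tensor(0.0),
    "t": torch.tensor(0.0),
}
params = {"beta": torch.tensor(0.5), "gamma": torch.tensor(0.7)}

dX = pure_sir_dynamics(state, params)
# dS = -beta*S*I, dI = beta*S*I - gamma*I, dR = gamma*I,
# with beta modulated by 1 + 0.1*sin(0.1*t) (no effect at t = 0).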
@@ -47,9 +47,46 @@ def observation(self, X: State[torch.Tensor]):
         self._unit_measurement_error("R_obs", X["R"])


-def bayes_sir_model():
+class SIRReparamObservationMixin(SIRObservationMixin):
+    def observation(self, X: State[torch.Tensor]):
+
+        # A flight arrives in a country that tests all arrivals for a disease. The number of people infected on the
+        # plane is a noisy function of the number of infected people in the country of origin at that time.
+        u_ip = pyro.sample(
+            "u_ip", Normal(7.0, 2.0).expand(X["I"].shape[-1:]).to_event(1)
+        )
+        pyro.deterministic("infected_passengers", X["I"] + u_ip, event_dim=1)
+
+
+class UnifiedFixtureDynamicsBase(pyro.nn.PyroModule):
+    def __init__(self, beta=None, gamma=None):
+        super().__init__()
+
+        self.beta = beta
+        if self.beta is None:
+            self.beta = pyro.param("beta", torch.tensor(0.5), constraints.positive)
+
+        self.gamma = gamma
+        if self.gamma is None:
+            self.gamma = pyro.param("gamma", torch.tensor(0.7), constraints.positive)
+
+    def forward(self, X: State[torch.Tensor]):
+        atemp_params = dict(beta=self.beta, gamma=self.gamma)
+        return pure_sir_dynamics(X, atemp_params)
+
+
+class UnifiedFixtureDynamics(UnifiedFixtureDynamicsBase, SIRObservationMixin):
+    pass
+
+
+def sir_param_prior():
     beta = pyro.sample("beta", Uniform(0, 1))
     gamma = pyro.sample("gamma", Uniform(0, 1))
+    return beta, gamma
+
+
+def bayes_sir_model():
+    beta, gamma = sir_param_prior()
     sir = UnifiedFixtureDynamics(beta, gamma)
     return sir
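Note: the refactor splits the old monolithic fixture into a parameter prior (sir_param_prior), pure dynamics (pure_sir_dynamics, wrapped by UnifiedFixtureDynamicsBase.forward), and observation mixins. A sketch of how the pieces compose, under the same State-as-Mapping assumption as above:

model = bayes_sir_model()  # samples beta, gamma; wraps them in UnifiedFixtureDynamics
state0 = {"S": torch.tensor(0.99), "I": torch.tensor(0.01),
          "R": torch.tensor(0.0), "t": torch.tensor(0.0)}
dX = model(state0)  # forward() delegates to pure_sir_dynamics with the module's params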

@@ -64,7 +101,8 @@ def check_states_match(state1: State[torch.Tensor], state2: State[torch.Tensor])

     for k in state1.keys():
         assert torch.allclose(
-            state1[k], state2[k]
+            state1[k],
+            state2[k],
         ), f"Trajectories differ in state trajectory of variable {k}, but should be identical."

     return True
@@ -77,7 +115,7 @@ def check_trajectories_match_in_all_but_values(

     for k in traj1.keys():
         assert not torch.allclose(
-            traj2[k], traj1[k]
+            traj2[k], traj1[k], atol=1e-6, rtol=1e-3
         ), f"Trajectories are identical in state trajectory of variable {k}, but should differ."

     return True
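Note: the explicit atol=1e-6, rtol=1e-3 loosen torch.allclose relative to its defaults (atol=1e-8, rtol=1e-5), so the negated assertion now only passes when the trajectories differ by more than numerical noise.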
@@ -98,3 +136,10 @@ def run_svi_inference_torch_direct(model, n_steps=100, verbose=True, **model_kwa
         if (step % 100 == 0) or (step == 1) & verbose:
             print("[iteration %04d] loss: %.4f" % (step, loss))
     return guide
+
+
+def build_event_fn_zero_after_tt(tt: torch.Tensor):
+    def zero_after_tt(t: torch.Tensor, state: State[torch.Tensor]):
+        return torch.where(t < tt, tt - t, 0.0)
+
+    return zero_after_tt
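Note: the returned event function decreases linearly to zero at t == tt and stays at zero afterwards, the root-crossing shape that event-based ODE termination (presumably what chirho's interruption machinery uses here) looks for. A quick sketch of its values:

tt = torch.tensor(5.0)
event_fn = build_event_fn_zero_after_tt(tt)

event_fn(torch.tensor(3.0), {})  # tensor(2.) -- positive, no event yet
event_fn(torch.tensor(5.0), {})  # tensor(0.) -- event fires at t == tt
event_fn(torch.tensor(7.0), {})  # tensor(0.) -- remains zero after tt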