Commit a40a2a1

replaced explicit State with Dict (#346)

SamWitty authored Oct 18, 2023
1 parent 772f824 commit a40a2a1
Showing 9 changed files with 89 additions and 101 deletions.
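The gist of the change: `State` stops being a hand-rolled attribute-access container and becomes a `Dict[str, T]` subclass, so every call site switches from `getattr`/`setattr` to ordinary item access. A minimal sketch of the before/after usage, written against the class definition from `chirho/dynamical/ops.py` in this diff (the tensor values are made up for illustration):

```python
from typing import Dict, Generic, TypeVar

import torch

T = TypeVar("T")


class State(Generic[T], Dict[str, T]):
    # As of this commit, State is an ordinary dict with string keys.
    pass


s: State[torch.Tensor] = State(S=torch.tensor(99.0), I=torch.tensor(1.0))

# Before: value = getattr(s, "S"); setattr(s, "R", torch.tensor(0.0))
# After:  plain dict indexing.
value = s["S"]
s["R"] = torch.tensor(0.0)

# Missing keys now use dict fallbacks, e.g. _deriv's
# getattr(ddt, var, torch.tensor(0.0)) becomes ddt.get(var, torch.tensor(0.0)).
dxdt = s.get("dz", torch.tensor(0.0))
```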
22 changes: 9 additions & 13 deletions chirho/dynamical/internals/_utils.py
@@ -15,10 +15,7 @@
 @indices_of.register(State)
 def _indices_of_state(state: State, *, event_dim: int = 0, **kwargs) -> IndexSet:
     return union(
-        *(
-            indices_of(getattr(state, k), event_dim=event_dim, **kwargs)
-            for k in get_keys(state)
-        )
+        *(indices_of(state[k], event_dim=event_dim, **kwargs) for k in get_keys(state))
     )


@@ -28,7 +25,7 @@ def _gather_state(
 ) -> State[T]:
     return type(state)(
         **{
-            k: gather(getattr(state, k), indices, event_dim=event_dim, **kwargs)
+            k: gather(state[k], indices, event_dim=event_dim, **kwargs)
             for k in get_keys(state)
         }
     )
@@ -38,9 +35,7 @@ def _gather_state(
 def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]:
     new_state: State[T] = State()
     for k in get_keys(obs):
-        setattr(
-            new_state, k, intervene(getattr(obs, k), getattr(act, k, None), **kwargs)
-        )
+        new_state[k] = intervene(obs[k], act[k] if k in act else None, **kwargs)
     return new_state


@@ -64,8 +59,7 @@ def _append_trajectory(traj1: State[T], traj2: State[T]) -> State[T]:

     result: State[T] = State()
     for k in get_keys(traj1):
-        setattr(result, k, append(getattr(traj1, k), getattr(traj2, k)))
-
+        result[k] = append(traj1[k], traj2[k])
     return result


@@ -83,8 +77,8 @@ def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]:
     return tuple(sorted(varnames))


-def _squeeze_time_dim(traj: State[T]) -> State[T]:
-    return State(**{k: getattr(traj, k).squeeze(-1) for k in get_keys(traj)})
+def _squeeze_time_dim(traj: State[torch.Tensor]) -> State[torch.Tensor]:
+    return State(**{k: traj[k].squeeze(-1) for k in get_keys(traj)})


 @observe.register(State)
@@ -103,9 +97,11 @@ def _observe_state(
     if obs is rv or obs is None:
         return rv

+    assert isinstance(obs, State)
+
     return State(
         **{
-            k: observe(getattr(rv, k), getattr(obs, k), name=f"{name}__{k}", **kwargs)
+            k: observe(rv[k], obs[k], name=f"{name}__{k}", **kwargs)
             for k in get_keys(rv)
         }
     )
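One subtlety in `_state_intervene` above: the attribute fallback `getattr(act, k, None)` becomes `act[k] if k in act else None`. Since `State` is now a dict, this is equivalent to `dict.get`, as the following check (with a plain dict standing in for `State`) illustrates:

```python
act = {"S": 1.0}  # stand-in for a State, which is now a Dict[str, T] subclass
assert (act["S"] if "S" in act else None) == act.get("S") == 1.0
assert (act["I"] if "I" in act else None) is act.get("I") is None
```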
12 changes: 6 additions & 6 deletions chirho/dynamical/internals/backends/torchdiffeq.py
@@ -31,13 +31,13 @@ def _deriv(
 ) -> Tuple[torch.Tensor, ...]:
     env: State[torch.Tensor] = State()
     for var, value in zip(var_order, state):
-        setattr(env, var, value)
+        env[var] = value

     assert "t" not in get_keys(env), "variable name t is reserved for time"
-    env.t = time
+    env["t"] = time

     ddt: State[torch.Tensor] = dynamics(env)
-    return tuple(getattr(ddt, var, torch.tensor(0.0)) for var in var_order)
+    return tuple(ddt.get(var, torch.tensor(0.0)) for var in var_order)


 def _torchdiffeq_ode_simulate_inner(
@@ -50,14 +50,14 @@ def _torchdiffeq_ode_simulate_inner(

     solns = _batched_odeint(  # torchdiffeq.odeint(
         functools.partial(_deriv, dynamics, var_order),
-        tuple(getattr(initial_state, v) for v in var_order),
+        tuple(initial_state[v] for v in var_order),
         timespan,
         **odeint_kwargs,
     )

     trajectory: State[torch.Tensor] = State()
     for var, soln in zip(var_order, solns):
-        setattr(trajectory, var, soln)
+        trajectory[var] = soln

     return trajectory

@@ -160,7 +160,7 @@ def torchdiffeq_get_next_interruptions_dynamic(
     # Simulate to the event execution.
     event_time, event_solutions = _batched_odeint(  # torchdiffeq.odeint_event(
         functools.partial(_deriv, dynamics, var_order),
-        tuple(getattr(start_state, v) for v in var_order),
+        tuple(start_state[v] for v in var_order),
         start_time,
         event_fn=combined_event_f,
         **solver.odeint_kwargs,
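For context on what `_deriv` is bridging: `torchdiffeq.odeint` works on positional tuples of tensors, while chirho's dynamics functions consume and produce the dict-style `State`. The sketch below replays that round trip with a hypothetical one-variable dynamics (`decay`) and plain `torchdiffeq.odeint` in place of chirho's `_batched_odeint`; everything except the torchdiffeq API is made up for illustration:

```python
import functools
from typing import Callable, Dict, Tuple

import torch
import torchdiffeq


def toy_deriv(
    dynamics: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]],
    var_order: Tuple[str, ...],
    time: torch.Tensor,
    state: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    # Rebuild the dict-style environment from the positional tuple.
    env = dict(zip(var_order, state))
    env["t"] = time
    ddt = dynamics(env)
    # Flatten back into var_order, defaulting to 0.0 for variables the
    # dynamics does not update (mirroring ddt.get(var, torch.tensor(0.0))).
    return tuple(ddt.get(var, torch.tensor(0.0)) for var in var_order)


def decay(env: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    return {"x": -env["x"]}  # dx/dt = -x


var_order = ("x",)
solns = torchdiffeq.odeint(
    functools.partial(toy_deriv, decay, var_order),
    (torch.tensor(1.0),),         # initial state, packed in var_order
    torch.linspace(0.0, 1.0, 5),  # timespan
)
trajectory = dict(zip(var_order, solns))  # unpack back into a dict
```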
26 changes: 4 additions & 22 deletions chirho/dynamical/ops.py
@@ -1,6 +1,6 @@
 import numbers
 import typing
-from typing import Callable, FrozenSet, Generic, Optional, TypeVar, Union
+from typing import Callable, Dict, FrozenSet, Generic, Optional, TypeVar, Union

 import pyro
 import torch
@@ -10,30 +10,12 @@
 T = TypeVar("T")


-class State(Generic[T]):
-    def __init__(self, **values: T):
-        self.__dict__["_values"] = {}
-        for k, v in values.items():
-            setattr(self, k, v)
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}({self.__dict__['_values']})"
-
-    def __str__(self) -> str:
-        return f"{self.__class__.__name__}({self.__dict__['_values']})"
-
-    def __setattr__(self, __name: str, __value: T) -> None:
-        self.__dict__["_values"][__name] = __value
-
-    def __getattr__(self, __name: str) -> T:
-        if __name in self.__dict__["_values"]:
-            return self.__dict__["_values"][__name]
-        else:
-            raise AttributeError(f"{__name} not in {self.__dict__['_values']}")
+class State(Generic[T], Dict[str, T]):
+    pass


 def get_keys(state: State[T]) -> FrozenSet[str]:
-    return frozenset(state.__dict__["_values"].keys())
+    return frozenset(state.keys())


 Dynamics = Callable[[State[T]], State[T]]
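Taken together with `get_keys` and the `Dynamics` alias above, the new `State` needs no special protocol at all. A short hedged sketch (assuming only the `ops.py` definitions in this hunk; `exp_decay` is a made-up example) of a dynamics function written against the dict-style API:

```python
import torch


def exp_decay(X: State[torch.Tensor]) -> State[torch.Tensor]:
    # A Dynamics = Callable[[State[T]], State[T]]: maps a state to its time derivative.
    dX: State[torch.Tensor] = State()
    dX["x"] = -0.5 * X["x"]
    return dX


X0 = State(x=torch.tensor(2.0))
assert get_keys(X0) == frozenset({"x"})
assert torch.isclose(exp_decay(X0)["x"], torch.tensor(-1.0))
```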
18 changes: 9 additions & 9 deletions tests/dynamical/dynamical_fixtures.py
@@ -26,12 +26,12 @@ def __init__(self, beta=None, gamma=None):
     def forward(self, X: State[torch.Tensor]):
         dX: State[torch.Tensor] = State()
         beta = self.beta * (
-            1.0 + 0.1 * torch.sin(0.1 * X.t)
+            1.0 + 0.1 * torch.sin(0.1 * X["t"])
         )  # beta oscillates slowly in time.

-        dX.S = -beta * X.S * X.I
-        dX.I = beta * X.S * X.I - self.gamma * X.I  # noqa
-        dX.R = self.gamma * X.I
+        dX["S"] = -beta * X["S"] * X["I"]
+        dX["I"] = beta * X["S"] * X["I"] - self.gamma * X["I"]  # noqa
+        dX["R"] = self.gamma * X["I"]
         return dX

     def _unit_measurement_error(self, name: str, x: torch.Tensor):
@@ -42,9 +42,9 @@ def _unit_measurement_error(self, name: str, x: torch.Tensor):

     @pyro.nn.pyro_method
     def observation(self, X: State[torch.Tensor]):
-        self._unit_measurement_error("S_obs", X.S)
-        self._unit_measurement_error("I_obs", X.I)
-        self._unit_measurement_error("R_obs", X.R)
+        self._unit_measurement_error("S_obs", X["S"])
+        self._unit_measurement_error("I_obs", X["I"])
+        self._unit_measurement_error("R_obs", X["R"])


 def bayes_sir_model():
@@ -64,7 +64,7 @@ def check_states_match(state1: State[torch.Tensor], state2: State[torch.Tensor])

     for k in get_keys(state1):
         assert torch.allclose(
-            getattr(state1, k), getattr(state2, k)
+            state1[k], state2[k]
         ), f"Trajectories differ in state trajectory of variable {k}, but should be identical."

     return True
@@ -77,7 +77,7 @@ def check_trajectories_match_in_all_but_values(

     for k in get_keys(traj1):
         assert not torch.allclose(
-            getattr(traj2, k), getattr(traj1, k)
+            traj2[k], traj1[k]
         ), f"Trajectories are identical in state trajectory of variable {k}, but should differ."

     return True
71 changes: 40 additions & 31 deletions tests/dynamical/test_dynamic_interventions.py
@@ -40,7 +40,7 @@

 def get_state_reached_event_f(target_state: State[torch.tensor], event_dim: int = 0):
     def event_f(t: torch.tensor, state: State[torch.tensor]):
-        actual, target = state.R, target_state.R
+        actual, target = state["R"], target_state["R"]
         cf_indices = IndexSet(
             **{
                 k: {1}
@@ -96,14 +96,14 @@ def test_nested_dynamic_intervention_causes_change(
         model, init_state, start_time, end_time, solver=TorchDiffEq()
     )

-    preint_total = init_state.S + init_state.I + init_state.R
+    preint_total = init_state["S"] + init_state["I"] + init_state["R"]

     # Each intervention just adds a certain amount of susceptible people after the recovered count exceeds some amount

     trajectory = dt.trajectory

-    postint_mask1 = trajectory.R > ts1.R
-    postint_mask2 = trajectory.R > ts2.R
+    postint_mask1 = trajectory["R"] > ts1["R"]
+    postint_mask2 = trajectory["R"] > ts2["R"]
     preint_mask = ~(postint_mask1 | postint_mask2)

     # TODO support dim != -1
@@ -114,16 +114,18 @@

     # Make sure all points before the intervention maintain the same total population.
     preint_traj = gather(trajectory, preint_idx, name_to_dim=name_to_dim)
-    assert torch.allclose(preint_total, preint_traj.S + preint_traj.I + preint_traj.R)
+    assert torch.allclose(
+        preint_total, preint_traj["S"] + preint_traj["I"] + preint_traj["R"]
+    )

     # Make sure all points after the first intervention, but before the second, include the added population of that
     # first intervention.
     postfirst_int_mask, postsec_int_mask = (
         (postint_mask1, postint_mask2)
-        if ts1.R < ts2.R
+        if ts1["R"] < ts2["R"]
         else (postint_mask2, postint_mask1)
     )
-    firstis, secondis = (is1, is2) if ts1.R < ts2.R else (is2, is1)
+    firstis, secondis = (is1, is2) if ts1["R"] < ts2["R"] else (is2, is1)

     postfirst_int_presec_int_mask = postfirst_int_mask & ~postsec_int_mask

@@ -143,10 +145,10 @@ def test_nested_dynamic_intervention_causes_change(
         trajectory, postfirst_int_presec_int_idx, name_to_dim=name_to_dim
     )
     assert torch.all(
-        postfirst_int_presec_int_traj.S
-        + postfirst_int_presec_int_traj.I
-        + postfirst_int_presec_int_traj.R
-        > (preint_total + firstis.S) * 0.95
+        postfirst_int_presec_int_traj["S"]
+        + postfirst_int_presec_int_traj["I"]
+        + postfirst_int_presec_int_traj["R"]
+        > (preint_total + firstis["S"]) * 0.95
     )

     postsec_int_idx = IndexSet(
@@ -155,8 +157,8 @@ def test_nested_dynamic_intervention_causes_change(

     postsec_int_traj = gather(trajectory, postsec_int_idx, name_to_dim=name_to_dim)
     assert torch.all(
-        postsec_int_traj.S + postsec_int_traj.I + postsec_int_traj.R
-        > (preint_total + firstis.S + secondis.S) * 0.95
+        postsec_int_traj["S"] + postsec_int_traj["I"] + postsec_int_traj["R"]
+        > (preint_total + firstis["S"] + secondis["S"]) * 0.95
     )


@@ -186,15 +188,15 @@ def test_dynamic_intervention_causes_change(
 ):
     simulate(model, init_state, start_time, end_time, solver=TorchDiffEq())

-    preint_total = init_state.S + init_state.I + init_state.R
+    preint_total = init_state["S"] + init_state["I"] + init_state["R"]

     trajectory = dt.trajectory

     # The intervention just "adds" (sets) 50 "people" to the susceptible population.
     # It happens that the susceptible population is roughly 0 at the intervention point,
     # so this serves to make sure the intervention actually causes that population influx.

-    postint_mask = trajectory.R > trigger_state.R
+    postint_mask = trajectory["R"] > trigger_state["R"]

     # TODO support dim != -1
     name_to_dim = {"__time": -1}
@@ -210,13 +212,15 @@ def test_dynamic_intervention_causes_change(
     preint_traj = gather(trajectory, preint_idx, name_to_dim=name_to_dim)

     # Make sure all points before the intervention maintain the same total population.
-    assert torch.allclose(preint_total, preint_traj.S + preint_traj.I + preint_traj.R)
+    assert torch.allclose(
+        preint_total, preint_traj["S"] + preint_traj["I"] + preint_traj["R"]
+    )

     # Make sure all points after the intervention include the added population.
     # noinspection PyTypeChecker
     assert torch.all(
-        postint_traj.S + postint_traj.I + postint_traj.R
-        > (preint_total + intervene_state.S) * 0.95
+        postint_traj["S"] + postint_traj["I"] + postint_traj["R"]
+        > (preint_total + intervene_state["S"]) * 0.95
     )


@@ -271,8 +275,8 @@ def test_split_twinworld_dynamic_intervention(
     cf_trajectory = dt.trajectory
     for k in get_keys(cf_trajectory):
         # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory.
-        assert cf.default_name in indices_of(getattr(cf_state, k))
-        assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1)
+        assert cf.default_name in indices_of(cf_state[k])
+        assert cf.default_name in indices_of(cf_trajectory[k], event_dim=1)


 @pytest.mark.parametrize("model", [UnifiedFixtureDynamics()])
@@ -319,8 +323,8 @@ def test_split_multiworld_dynamic_intervention(
     cf_trajectory = dt.trajectory
     for k in get_keys(cf_trajectory):
         # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory.
-        assert cf.default_name in indices_of(getattr(cf_state, k))
-        assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1)
+        assert cf.default_name in indices_of(cf_state[k])
+        assert cf.default_name in indices_of(cf_trajectory[k], event_dim=1)


 @pytest.mark.parametrize("model", [UnifiedFixtureDynamics()])
@@ -393,30 +397,35 @@ def test_split_twinworld_dynamic_matches_output(

     for k in get_keys(cf_result):
         assert torch.allclose(
-            getattr(cf_actual, k), getattr(cf_expected, k), atol=1e-3, rtol=0
+            cf_actual[k], cf_expected[k], atol=1e-3, rtol=0
         ), f"Trajectories differ in state result of variable {k}, but should be identical."

     for k in get_keys(cf_result):
         assert torch.allclose(
-            getattr(factual_actual, k), getattr(factual_expected, k), atol=1e-3, rtol=0
+            factual_actual[k],
+            factual_expected[k],
+            atol=1e-3,
+            rtol=0,
         ), f"Trajectories differ in state result of variable {k}, but should be identical."


 def test_grad_of_dynamic_intervention_event_f_params():
     def model(X: State[torch.Tensor]):
         dX = State()
-        dX.x = tt(1.0)
-        dX.z = X.dz
-        dX.dz = tt(0.0)  # also a constant, this gets set by interventions.
-        dX.param = tt(0.0)  # this is a constant event function parameter, so no change.
+        dX["x"] = tt(1.0)
+        dX["z"] = X["dz"]
+        dX["dz"] = tt(0.0)  # also a constant, this gets set by interventions.
+        dX["param"] = tt(
+            0.0
+        )  # this is a constant event function parameter, so no change.
         return dX

     param = torch.nn.Parameter(tt(5.0))
     # Param has to be part of the state in order to take gradients with respect to it.
     s0 = State(x=tt(0.0), z=tt(0.0), dz=tt(0.0), param=param)

     dynamic_intervention = DynamicIntervention(
-        event_f=lambda t, s: t - s.param,
+        event_f=lambda t, s: t - s["param"],
         intervention=State(dz=tt(1.0)),
     )

@@ -426,13 +435,13 @@ def model(X: State[torch.Tensor]):
     result = simulate(model, s0, start_time, end_time, solver=TorchDiffEq())

     (dxdparam,) = torch.autograd.grad(
-        outputs=(result.x,), inputs=(param,), create_graph=True
+        outputs=(result["x"],), inputs=(param,), create_graph=True
     )
     assert torch.isclose(dxdparam, tt(0.0), atol=1e-5)

     # Z begins accruing dz=1 at t=param, so dzdparam should be -1.0.
     (dzdparam,) = torch.autograd.grad(
-        outputs=(result.z,), inputs=(param,), create_graph=True
+        outputs=(result["z"],), inputs=(param,), create_graph=True
     )
     assert torch.isclose(dzdparam, tt(-1.0), atol=1e-5)
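A quick sanity check on those expected gradients, consistent with the comments in the test: before the event at t* = param the state follows dx/dt = 1 and dz/dt = dz = 0, and the intervention sets dz = 1 at t*, after which dz/dt = 1. So x(T) = T does not depend on param (hence dxdparam ≈ 0), while z(T) = T - param for T > param (hence dzdparam ≈ -1).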
