Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move keys property of State into a helper function get_keys #331

Merged
merged 4 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _pyro_post_simulate(self, msg) -> None:
name_to_dim["__time"] = -1
len_traj = (
0
if not self.trajectory.keys
if not get_keys(self.trajectory)
else 1 + max(indices_of(self.trajectory, name_to_dim=name_to_dim)["__time"])
)

Expand Down
20 changes: 10 additions & 10 deletions chirho/dynamical/internals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from chirho.dynamical.ops import State, Trajectory
from chirho.dynamical.ops import State, Trajectory, get_keys
from chirho.indexed.ops import IndexSet, gather, indices_of, union
from chirho.interventional.handlers import intervene

Expand All @@ -16,7 +16,7 @@ def _indices_of_state(state: State, *, event_dim: int = 0, **kwargs) -> IndexSet
return union(
*(
indices_of(getattr(state, k), event_dim=event_dim, **kwargs)
for k in state.keys
for k in get_keys(state)
)
)

Expand All @@ -28,15 +28,15 @@ def _gather_state(
return type(state)(
**{
k: gather(getattr(state, k), indices, event_dim=event_dim, **kwargs)
for k in state.keys
for k in get_keys(state)
}
)


@intervene.register(State)
def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]:
new_state: State[T] = State()
for k in obs.keys:
for k in get_keys(obs):
setattr(
new_state, k, intervene(getattr(obs, k), getattr(act, k, None), **kwargs)
)
Expand All @@ -50,19 +50,19 @@ def append(fst, rest: T) -> T:

@append.register(Trajectory)
def _append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]:
if len(traj1.keys) == 0:
if len(get_keys(traj1)) == 0:
return traj2

if len(traj2.keys) == 0:
if len(get_keys(traj2)) == 0:
return traj1

if traj1.keys != traj2.keys:
if get_keys(traj1) != get_keys(traj2):
raise ValueError(
f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}."
f"Trajectories must have the same keys to be appended, but got {get_keys(traj1)} and {get_keys(traj2)}."
)

result: Trajectory[T] = Trajectory()
for k in traj1.keys:
for k in get_keys(traj1):
setattr(result, k, append(getattr(traj1, k), getattr(traj2, k)))

return result
Expand All @@ -83,4 +83,4 @@ def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]:


def _trajectory_to_state(traj: Trajectory[T]) -> State[T]:
return State(**{k: getattr(traj, k).squeeze(-1) for k in traj.keys})
return State(**{k: getattr(traj, k).squeeze(-1) for k in get_keys(traj)})
8 changes: 4 additions & 4 deletions chirho/dynamical/internals/backends/torchdiffeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
simulate_point,
simulate_trajectory,
)
from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory
from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory, get_keys
from chirho.indexed.ops import IndexSet, gather, get_index_plates

S = TypeVar("S")
Expand All @@ -35,7 +35,7 @@ def _deriv(
for var, value in zip(var_order, state):
setattr(env, var, value)

assert "t" not in env.keys, "variable name t is reserved for time"
assert "t" not in get_keys(env), "variable name t is reserved for time"
env.t = time

dynamics.diff(ddt, env)
Expand All @@ -48,7 +48,7 @@ def _torchdiffeq_ode_simulate_inner(
timespan,
**odeint_kwargs,
):
var_order = _var_order(initial_state.keys) # arbitrary, but fixed
var_order = _var_order(get_keys(initial_state)) # arbitrary, but fixed

solns = _batched_odeint( # torchdiffeq.odeint(
functools.partial(_deriv, dynamics, var_order),
Expand Down Expand Up @@ -152,7 +152,7 @@ def torchdiffeq_get_next_interruptions_dynamic(
dynamic_interruptions: List[DynamicInterruption],
**kwargs,
) -> Tuple[Tuple[Interruption, ...], torch.Tensor]:
var_order = _var_order(start_state.keys) # arbitrary, but fixed
var_order = _var_order(get_keys(start_state)) # arbitrary, but fixed

# Create the event function combining all dynamic events and the terminal (next) static interruption.
combined_event_f = torchdiffeq_combined_event_f(
Expand Down
9 changes: 4 additions & 5 deletions chirho/dynamical/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,10 @@

class State(Generic[T]):
def __init__(self, **values: T):
# self.class_name =
self.__dict__["_values"] = {}
for k, v in values.items():
setattr(self, k, v)

@property
def keys(self) -> FrozenSet[str]:
return frozenset(self.__dict__["_values"].keys())

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.__dict__['_values']})"

Expand All @@ -42,6 +37,10 @@ class Trajectory(Generic[T], State[T]):
pass


def get_keys(state: State[T]) -> FrozenSet[str]:
return frozenset(state.__dict__["_values"].keys())


@typing.runtime_checkable
class Observable(Protocol[S]):
def observation(self, __state: Union[State[S], Trajectory[S]]) -> None:
Expand Down
12 changes: 6 additions & 6 deletions docs/source/dynamical_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
" NonInterruptingPointObservationArray\n",
")\n",
"from chirho.dynamical.handlers.interruption import _InterventionMixin\n",
"from chirho.dynamical.ops.dynamical import State, Trajectory, simulate\n",
"from chirho.dynamical.ops.dynamical import State, Trajectory, get_keys, simulate\n",
"\n",
"from chirho.dynamical.ops.ODE import ODEDynamics\n",
"from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq\n",
Expand Down Expand Up @@ -413,7 +413,7 @@
" simulate(sir, init_state, logging_times[0], logging_times[-1] + 1e-3, solver=TorchDiffEq())\n",
" trajectory = dt.trace\n",
" # This is a small trick to make the solution variables available to pyro\n",
" [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n",
" [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n",
" return trajectory\n",
"\n",
"\n",
Expand Down Expand Up @@ -891,7 +891,7 @@
" \n",
" trajectory = dt.trace\n",
" # This is a small trick to make the solution variables available to pyro\n",
" [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n",
" [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n",
" return trajectory"
]
},
Expand Down Expand Up @@ -1251,7 +1251,7 @@
" trajectory = dt.trace\n",
" \n",
" # This is a small trick to make the solution variables available to pyro\n",
" [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n",
" [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n",
" return trajectory"
]
},
Expand Down Expand Up @@ -1740,7 +1740,7 @@
" cf_traj = gather(trajectory, cf_indices, event_dim=0)\n",
" \n",
" # This is a small trick to make the trajectory variables available to pyro \n",
" for k in trajectory.keys:\n",
" for k in get_keys(trajectory):\n",
" pyro.deterministic(k + '_factual', getattr(factual_traj, k))\n",
" pyro.deterministic(k + '_cf', getattr(cf_traj, k))"
]
Expand Down Expand Up @@ -2178,7 +2178,7 @@
" trajectory = dt.trace\n",
" solutions.append(trajectory)\n",
" # This is a small trick to make the trajectory variables available to pyro\n",
" [pyro.deterministic(f\"{k}_{unit_ix}\", getattr(trajectory, k))for k in trajectory.keys]\n",
" [pyro.deterministic(f\"{k}_{unit_ix}\", getattr(trajectory, k))for k in get_keys(trajectory)]\n",
" return solutions\n",
"\n",
"\n",
Expand Down
12 changes: 6 additions & 6 deletions tests/dynamical/dynamical_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from pyro.distributions import Normal, Uniform, constraints

from chirho.dynamical.ops import State, Trajectory
from chirho.dynamical.ops import State, Trajectory, get_keys

T = TypeVar("T")

Expand Down Expand Up @@ -54,14 +54,14 @@ def bayes_sir_model():
def check_keys_match(
obj1: Union[Trajectory[T], State[T]], obj2: Union[Trajectory[T], State[T]]
):
assert obj1.keys == obj2.keys, "Objects have different variables."
assert get_keys(obj1) == get_keys(obj2), "Objects have different variables."
return True


def check_trajectory_length_match(
traj1: Trajectory[torch.tensor], traj2: Trajectory[torch.tensor]
):
for k in traj1.keys:
for k in get_keys(traj1):
assert len(getattr(traj2, k)) == len(
getattr(traj1, k)
), f"Trajectories have different lengths for variable {k}."
Expand All @@ -75,7 +75,7 @@ def check_trajectories_match(

assert check_trajectory_length_match(traj1, traj2)

for k in traj1.keys:
for k in get_keys(traj1):
assert torch.allclose(
getattr(traj2, k), getattr(traj1, k)
), f"Trajectories differ in state trajectory of variable {k}, but should be identical."
Expand All @@ -86,7 +86,7 @@ def check_trajectories_match(
def check_states_match(state1: State[torch.tensor], state2: State[torch.tensor]):
assert check_keys_match(state1, state2)

for k in state1.keys:
for k in get_keys(state1):
assert torch.allclose(
getattr(state1, k), getattr(state2, k)
), f"Trajectories differ in state trajectory of variable {k}, but should be identical."
Expand All @@ -101,7 +101,7 @@ def check_trajectories_match_in_all_but_values(

assert check_trajectory_length_match(traj1, traj2)

for k in traj1.keys:
for k in get_keys(traj1):
assert not torch.allclose(
getattr(traj2, k), getattr(traj1, k)
), f"Trajectories are identical in state trajectory of variable {k}, but should differ."
Expand Down
14 changes: 7 additions & 7 deletions tests/dynamical/test_dynamic_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
LogTrajectory,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import InPlaceDynamics, State, simulate
from chirho.dynamical.ops import InPlaceDynamics, State, get_keys, simulate
from chirho.indexed.ops import IndexSet, gather, indices_of, union

from .dynamical_fixtures import UnifiedFixtureDynamics
Expand Down Expand Up @@ -269,7 +269,7 @@ def test_split_twinworld_dynamic_intervention(

with cf:
cf_trajectory = dt.trajectory
for k in cf_trajectory.keys:
for k in get_keys(cf_trajectory):
# TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory.
assert cf.default_name in indices_of(getattr(cf_state, k))
assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1)
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_split_multiworld_dynamic_intervention(

with cf:
cf_trajectory = dt.trajectory
for k in cf_trajectory.keys:
for k in get_keys(cf_trajectory):
# TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory.
assert cf.default_name in indices_of(getattr(cf_state, k))
assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1)
Expand Down Expand Up @@ -388,15 +388,15 @@ def test_split_twinworld_dynamic_matches_output(
assert not set(indices_of(cf_actual, event_dim=0))
assert not set(indices_of(factual_actual, event_dim=0))

assert set(cf_result.keys) == set(cf_actual.keys) == set(cf_expected.keys)
assert set(cf_result.keys) == set(factual_actual.keys) == set(factual_expected.keys)
assert get_keys(cf_result) == get_keys(cf_actual) == get_keys(cf_expected)
assert get_keys(cf_result) == get_keys(factual_actual) == get_keys(factual_expected)

for k in cf_result.keys:
for k in get_keys(cf_result):
assert torch.allclose(
getattr(cf_actual, k), getattr(cf_expected, k), atol=1e-3, rtol=0
), f"Trajectories differ in state result of variable {k}, but should be identical."

for k in cf_result.keys:
for k in get_keys(cf_result):
assert torch.allclose(
getattr(factual_actual, k), getattr(factual_expected, k), atol=1e-3, rtol=0
), f"Trajectories differ in state result of variable {k}, but should be identical."
Expand Down
12 changes: 6 additions & 6 deletions tests/dynamical/test_log_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chirho.dynamical.handlers import InterruptionEventLoop, LogTrajectory
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.internals._utils import append
from chirho.dynamical.ops import State, Trajectory, simulate
from chirho.dynamical.ops import State, Trajectory, get_keys, simulate

from .dynamical_fixtures import bayes_sir_model, check_states_match

Expand Down Expand Up @@ -40,17 +40,17 @@ def test_logging():
assert isinstance(result1, State)
assert isinstance(dt1.trajectory, Trajectory)
assert isinstance(dt2.trajectory, Trajectory)
assert len(dt1.trajectory.keys) == 3
assert len(dt2.trajectory.keys) == 3
assert dt1.trajectory.keys == result1.keys
assert dt2.trajectory.keys == result2.keys
assert len(get_keys(dt1.trajectory)) == 3
assert len(get_keys(dt2.trajectory)) == 3
assert get_keys(dt1.trajectory) == get_keys(result1)
assert get_keys(dt2.trajectory) == get_keys(result2)
assert check_states_match(result1, result2)
assert check_states_match(result1, result3)


def test_trajectory_methods():
trajectory = Trajectory(S=torch.tensor([1.0, 2.0, 3.0]))
assert trajectory.keys == frozenset({"S"})
assert get_keys(trajectory) == frozenset({"S"})
assert str(trajectory) == "Trajectory({'S': tensor([1., 2., 3.])})"


Expand Down
16 changes: 8 additions & 8 deletions tests/dynamical/test_static_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
StaticIntervention,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops import State, get_keys, simulate
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.ops import intervene

Expand Down Expand Up @@ -241,7 +241,7 @@ def test_twinworld_point_intervention(

with cf:
cf_trajectory = dt.trajectory
for k in cf_trajectory.keys:
for k in get_keys(cf_trajectory):
# TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory.
assert cf.default_name in indices_of(getattr(cf_state, k))
assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1)
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_multiworld_point_intervention(

with cf:
cf_trajectory = dt.trajectory
for k in cf_trajectory.keys:
for k in get_keys(cf_trajectory):
# TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory.
assert cf.default_name in indices_of(getattr(cf_state, k))
assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1)
Expand All @@ -300,7 +300,7 @@ def test_split_odeint_broadcast(

with cf:
trajectory = dt.trajectory
for k in trajectory.keys:
for k in get_keys(trajectory):
assert len(indices_of(getattr(trajectory, k), event_dim=1)) > 0


Expand Down Expand Up @@ -353,15 +353,15 @@ def test_twinworld_matches_output(
assert not set(indices_of(cf_actual, event_dim=0))
assert not set(indices_of(factual_actual, event_dim=0))

assert set(cf_state.keys) == set(cf_actual.keys) == set(cf_expected.keys)
assert set(cf_state.keys) == set(factual_actual.keys) == set(factual_expected.keys)
assert get_keys(cf_state) == get_keys(cf_actual) == get_keys(cf_expected)
assert get_keys(cf_state) == get_keys(factual_actual) == get_keys(factual_expected)

for k in cf_state.keys:
for k in get_keys(cf_state):
assert torch.allclose(
getattr(cf_actual, k), getattr(cf_expected, k)
), f"States differ in state trajectory of variable {k}, but should be identical."

for k in cf_state.keys:
for k in get_keys(cf_state):
assert torch.allclose(
getattr(factual_actual, k), getattr(factual_expected, k)
), f"States differ in state trajectory of variable {k}, but should be identical."