From 9781706493b65a14902125fd71861f38527c7cba Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 16 Oct 2023 14:45:28 -0400 Subject: [PATCH 1/2] Move keys property of state into a function --- chirho/dynamical/handlers/interruption.py | 4 ++-- chirho/dynamical/internals/_utils.py | 20 +++++++++---------- .../internals/backends/torchdiffeq.py | 8 ++++---- chirho/dynamical/ops.py | 9 ++++----- docs/source/dynamical_intro.ipynb | 12 +++++------ tests/dynamical/dynamical_fixtures.py | 12 +++++------ tests/dynamical/test_dynamic_interventions.py | 14 ++++++------- tests/dynamical/test_log_trajectory.py | 12 +++++------ tests/dynamical/test_static_interventions.py | 16 +++++++-------- 9 files changed, 53 insertions(+), 54 deletions(-) diff --git a/chirho/dynamical/handlers/interruption.py b/chirho/dynamical/handlers/interruption.py index 105deac4..ec1e2f38 100644 --- a/chirho/dynamical/handlers/interruption.py +++ b/chirho/dynamical/handlers/interruption.py @@ -6,7 +6,7 @@ import torch from chirho.dynamical.handlers.trajectory import LogTrajectory -from chirho.dynamical.ops import ObservableInPlaceDynamics, State +from chirho.dynamical.ops import ObservableInPlaceDynamics, State, get_keys from chirho.indexed.ops import get_index_plates, indices_of from chirho.interventional.ops import Intervention, intervene from chirho.observational.handlers import condition @@ -172,7 +172,7 @@ def _pyro_post_simulate(self, msg) -> None: name_to_dim["__time"] = -1 len_traj = ( 0 - if not self.trajectory.keys + if not get_keys(self.trajectory) else 1 + max(indices_of(self.trajectory, name_to_dim=name_to_dim)["__time"]) ) diff --git a/chirho/dynamical/internals/_utils.py b/chirho/dynamical/internals/_utils.py index 47dd1741..b02f1209 100644 --- a/chirho/dynamical/internals/_utils.py +++ b/chirho/dynamical/internals/_utils.py @@ -3,7 +3,7 @@ import torch -from chirho.dynamical.ops import State, Trajectory +from chirho.dynamical.ops import State, Trajectory, get_keys from chirho.indexed.ops import IndexSet, gather, indices_of, union from chirho.interventional.handlers import intervene @@ -16,7 +16,7 @@ def _indices_of_state(state: State, *, event_dim: int = 0, **kwargs) -> IndexSet return union( *( indices_of(getattr(state, k), event_dim=event_dim, **kwargs) - for k in state.keys + for k in get_keys(state) ) ) @@ -28,7 +28,7 @@ def _gather_state( return type(state)( **{ k: gather(getattr(state, k), indices, event_dim=event_dim, **kwargs) - for k in state.keys + for k in get_keys(state) } ) @@ -36,7 +36,7 @@ def _gather_state( @intervene.register(State) def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]: new_state: State[T] = State() - for k in obs.keys: + for k in get_keys(obs): setattr( new_state, k, intervene(getattr(obs, k), getattr(act, k, None), **kwargs) ) @@ -50,19 +50,19 @@ def append(fst, rest: T) -> T: @append.register(Trajectory) def _append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]: - if len(traj1.keys) == 0: + if len(get_keys(traj1)) == 0: return traj2 - if len(traj2.keys) == 0: + if len(get_keys(traj2)) == 0: return traj1 - if traj1.keys != traj2.keys: + if get_keys(traj1) != get_keys(traj2): raise ValueError( - f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}." + f"Trajectories must have the same keys to be appended, but got {get_keys(traj1)} and {get_keys(traj2)}." ) result: Trajectory[T] = Trajectory() - for k in traj1.keys: + for k in get_keys(traj1): setattr(result, k, append(getattr(traj1, k), getattr(traj2, k))) return result @@ -83,4 +83,4 @@ def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]: def _trajectory_to_state(traj: Trajectory[T]) -> State[T]: - return State(**{k: getattr(traj, k).squeeze(-1) for k in traj.keys}) + return State(**{k: getattr(traj, k).squeeze(-1) for k in get_keys(traj)}) diff --git a/chirho/dynamical/internals/backends/torchdiffeq.py b/chirho/dynamical/internals/backends/torchdiffeq.py index 52f0b4be..174d4565 100644 --- a/chirho/dynamical/internals/backends/torchdiffeq.py +++ b/chirho/dynamical/internals/backends/torchdiffeq.py @@ -16,7 +16,7 @@ simulate_point, simulate_trajectory, ) -from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory +from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory, get_keys from chirho.indexed.ops import IndexSet, gather, get_index_plates S = TypeVar("S") @@ -35,7 +35,7 @@ def _deriv( for var, value in zip(var_order, state): setattr(env, var, value) - assert "t" not in env.keys, "variable name t is reserved for time" + assert "t" not in get_keys(env), "variable name t is reserved for time" env.t = time dynamics.diff(ddt, env) @@ -48,7 +48,7 @@ def _torchdiffeq_ode_simulate_inner( timespan, **odeint_kwargs, ): - var_order = _var_order(initial_state.keys) # arbitrary, but fixed + var_order = _var_order(get_keys(initial_state)) # arbitrary, but fixed solns = _batched_odeint( # torchdiffeq.odeint( functools.partial(_deriv, dynamics, var_order), @@ -152,7 +152,7 @@ def torchdiffeq_get_next_interruptions_dynamic( dynamic_interruptions: List[DynamicInterruption], **kwargs, ) -> Tuple[Tuple[Interruption, ...], torch.Tensor]: - var_order = _var_order(start_state.keys) # arbitrary, but fixed + var_order = _var_order(get_keys(start_state)) # arbitrary, but fixed # Create the event function combining all dynamic events and the terminal (next) static interruption. combined_event_f = torchdiffeq_combined_event_f( diff --git a/chirho/dynamical/ops.py b/chirho/dynamical/ops.py index d30c1896..aed6380b 100644 --- a/chirho/dynamical/ops.py +++ b/chirho/dynamical/ops.py @@ -13,15 +13,10 @@ class State(Generic[T]): def __init__(self, **values: T): - # self.class_name = self.__dict__["_values"] = {} for k, v in values.items(): setattr(self, k, v) - @property - def keys(self) -> FrozenSet[str]: - return frozenset(self.__dict__["_values"].keys()) - def __repr__(self) -> str: return f"{self.__class__.__name__}({self.__dict__['_values']})" @@ -42,6 +37,10 @@ class Trajectory(Generic[T], State[T]): pass +def get_keys(state: State[T]) -> FrozenSet[str]: + return frozenset(state.__dict__["_values"].keys()) + + @typing.runtime_checkable class InPlaceDynamics(Protocol[S]): def diff(self, __dstate: State[S], __state: State[S]) -> None: diff --git a/docs/source/dynamical_intro.ipynb b/docs/source/dynamical_intro.ipynb index c6f0ff63..220a878a 100644 --- a/docs/source/dynamical_intro.ipynb +++ b/docs/source/dynamical_intro.ipynb @@ -37,7 +37,7 @@ " NonInterruptingPointObservationArray\n", ")\n", "from chirho.dynamical.handlers.interruption import _InterventionMixin\n", - "from chirho.dynamical.ops.dynamical import State, Trajectory, simulate\n", + "from chirho.dynamical.ops.dynamical import State, Trajectory, get_keys, simulate\n", "\n", "from chirho.dynamical.ops.ODE import ODEDynamics\n", "from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq\n", @@ -413,7 +413,7 @@ " simulate(sir, init_state, logging_times[0], logging_times[-1] + 1e-3, solver=TorchDiffEq())\n", " trajectory = dt.trace\n", " # This is a small trick to make the solution variables available to pyro\n", - " [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n", + " [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n", " return trajectory\n", "\n", "\n", @@ -891,7 +891,7 @@ " \n", " trajectory = dt.trace\n", " # This is a small trick to make the solution variables available to pyro\n", - " [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n", + " [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n", " return trajectory" ] }, @@ -1251,7 +1251,7 @@ " trajectory = dt.trace\n", " \n", " # This is a small trick to make the solution variables available to pyro\n", - " [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n", + " [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n", " return trajectory" ] }, @@ -1740,7 +1740,7 @@ " cf_traj = gather(trajectory, cf_indices, event_dim=0)\n", " \n", " # This is a small trick to make the trajectory variables available to pyro \n", - " for k in trajectory.keys:\n", + " for k in get_keys(trajectory):\n", " pyro.deterministic(k + '_factual', getattr(factual_traj, k))\n", " pyro.deterministic(k + '_cf', getattr(cf_traj, k))" ] @@ -2178,7 +2178,7 @@ " trajectory = dt.trace\n", " solutions.append(trajectory)\n", " # This is a small trick to make the trajectory variables available to pyro\n", - " [pyro.deterministic(f\"{k}_{unit_ix}\", getattr(trajectory, k))for k in trajectory.keys]\n", + " [pyro.deterministic(f\"{k}_{unit_ix}\", getattr(trajectory, k))for k in get_keys(trajectory)]\n", " return solutions\n", "\n", "\n", diff --git a/tests/dynamical/dynamical_fixtures.py b/tests/dynamical/dynamical_fixtures.py index aa6f49f1..8a7a54bb 100644 --- a/tests/dynamical/dynamical_fixtures.py +++ b/tests/dynamical/dynamical_fixtures.py @@ -4,7 +4,7 @@ import torch from pyro.distributions import Normal, Uniform, constraints -from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory +from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory, get_keys T = TypeVar("T") @@ -54,14 +54,14 @@ def bayes_sir_model(): def check_keys_match( obj1: Union[Trajectory[T], State[T]], obj2: Union[Trajectory[T], State[T]] ): - assert obj1.keys == obj2.keys, "Objects have different variables." + assert get_keys(obj1) == get_keys(obj2), "Objects have different variables." return True def check_trajectory_length_match( traj1: Trajectory[torch.tensor], traj2: Trajectory[torch.tensor] ): - for k in traj1.keys: + for k in get_keys(traj1): assert len(getattr(traj2, k)) == len( getattr(traj1, k) ), f"Trajectories have different lengths for variable {k}." @@ -75,7 +75,7 @@ def check_trajectories_match( assert check_trajectory_length_match(traj1, traj2) - for k in traj1.keys: + for k in get_keys(traj1): assert torch.allclose( getattr(traj2, k), getattr(traj1, k) ), f"Trajectories differ in state trajectory of variable {k}, but should be identical." @@ -86,7 +86,7 @@ def check_trajectories_match( def check_states_match(state1: State[torch.tensor], state2: State[torch.tensor]): assert check_keys_match(state1, state2) - for k in state1.keys: + for k in get_keys(state1): assert torch.allclose( getattr(state1, k), getattr(state2, k) ), f"Trajectories differ in state trajectory of variable {k}, but should be identical." @@ -101,7 +101,7 @@ def check_trajectories_match_in_all_but_values( assert check_trajectory_length_match(traj1, traj2) - for k in traj1.keys: + for k in get_keys(traj1): assert not torch.allclose( getattr(traj2, k), getattr(traj1, k) ), f"Trajectories are identical in state trajectory of variable {k}, but should differ." diff --git a/tests/dynamical/test_dynamic_interventions.py b/tests/dynamical/test_dynamic_interventions.py index d66e3804..15dc684f 100644 --- a/tests/dynamical/test_dynamic_interventions.py +++ b/tests/dynamical/test_dynamic_interventions.py @@ -14,7 +14,7 @@ LogTrajectory, ) from chirho.dynamical.handlers.solver import TorchDiffEq -from chirho.dynamical.ops import InPlaceDynamics, State, simulate +from chirho.dynamical.ops import InPlaceDynamics, State, get_keys, simulate from chirho.indexed.ops import IndexSet, gather, indices_of, union from .dynamical_fixtures import UnifiedFixtureDynamics @@ -269,7 +269,7 @@ def test_split_twinworld_dynamic_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -317,7 +317,7 @@ def test_split_multiworld_dynamic_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -388,15 +388,15 @@ def test_split_twinworld_dynamic_matches_output( assert not set(indices_of(cf_actual, event_dim=0)) assert not set(indices_of(factual_actual, event_dim=0)) - assert set(cf_result.keys) == set(cf_actual.keys) == set(cf_expected.keys) - assert set(cf_result.keys) == set(factual_actual.keys) == set(factual_expected.keys) + assert get_keys(cf_result) == get_keys(cf_actual) == get_keys(cf_expected) + assert get_keys(cf_result) == get_keys(factual_actual) == get_keys(factual_expected) - for k in cf_result.keys: + for k in get_keys(cf_result): assert torch.allclose( getattr(cf_actual, k), getattr(cf_expected, k), atol=1e-3, rtol=0 ), f"Trajectories differ in state result of variable {k}, but should be identical." - for k in cf_result.keys: + for k in get_keys(cf_result): assert torch.allclose( getattr(factual_actual, k), getattr(factual_expected, k), atol=1e-3, rtol=0 ), f"Trajectories differ in state result of variable {k}, but should be identical." diff --git a/tests/dynamical/test_log_trajectory.py b/tests/dynamical/test_log_trajectory.py index 2f2ab1d8..09ead634 100644 --- a/tests/dynamical/test_log_trajectory.py +++ b/tests/dynamical/test_log_trajectory.py @@ -6,7 +6,7 @@ from chirho.dynamical.handlers import InterruptionEventLoop, LogTrajectory from chirho.dynamical.handlers.solver import TorchDiffEq from chirho.dynamical.internals._utils import append -from chirho.dynamical.ops import State, Trajectory, simulate +from chirho.dynamical.ops import State, Trajectory, get_keys, simulate from .dynamical_fixtures import bayes_sir_model, check_states_match @@ -40,17 +40,17 @@ def test_logging(): assert isinstance(result1, State) assert isinstance(dt1.trajectory, Trajectory) assert isinstance(dt2.trajectory, Trajectory) - assert len(dt1.trajectory.keys) == 3 - assert len(dt2.trajectory.keys) == 3 - assert dt1.trajectory.keys == result1.keys - assert dt2.trajectory.keys == result2.keys + assert len(get_keys(dt1.trajectory)) == 3 + assert len(get_keys(dt2.trajectory)) == 3 + assert get_keys(dt1.trajectory) == get_keys(result1) + assert get_keys(dt2.trajectory) == get_keys(result2) assert check_states_match(result1, result2) assert check_states_match(result1, result3) def test_trajectory_methods(): trajectory = Trajectory(S=torch.tensor([1.0, 2.0, 3.0])) - assert trajectory.keys == frozenset({"S"}) + assert get_keys(trajectory) == frozenset({"S"}) assert str(trajectory) == "Trajectory({'S': tensor([1., 2., 3.])})" diff --git a/tests/dynamical/test_static_interventions.py b/tests/dynamical/test_static_interventions.py index f859fd06..5f52dd18 100644 --- a/tests/dynamical/test_static_interventions.py +++ b/tests/dynamical/test_static_interventions.py @@ -13,7 +13,7 @@ StaticIntervention, ) from chirho.dynamical.handlers.solver import TorchDiffEq -from chirho.dynamical.ops import State, simulate +from chirho.dynamical.ops import State, get_keys, simulate from chirho.indexed.ops import IndexSet, gather, indices_of from chirho.interventional.ops import intervene @@ -241,7 +241,7 @@ def test_twinworld_point_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -276,7 +276,7 @@ def test_multiworld_point_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -300,7 +300,7 @@ def test_split_odeint_broadcast( with cf: trajectory = dt.trajectory - for k in trajectory.keys: + for k in get_keys(trajectory): assert len(indices_of(getattr(trajectory, k), event_dim=1)) > 0 @@ -353,15 +353,15 @@ def test_twinworld_matches_output( assert not set(indices_of(cf_actual, event_dim=0)) assert not set(indices_of(factual_actual, event_dim=0)) - assert set(cf_state.keys) == set(cf_actual.keys) == set(cf_expected.keys) - assert set(cf_state.keys) == set(factual_actual.keys) == set(factual_expected.keys) + assert get_keys(cf_state) == get_keys(cf_actual) == get_keys(cf_expected) + assert get_keys(cf_state) == get_keys(factual_actual) == get_keys(factual_expected) - for k in cf_state.keys: + for k in get_keys(cf_state): assert torch.allclose( getattr(cf_actual, k), getattr(cf_expected, k) ), f"States differ in state trajectory of variable {k}, but should be identical." - for k in cf_state.keys: + for k in get_keys(cf_state): assert torch.allclose( getattr(factual_actual, k), getattr(factual_expected, k) ), f"States differ in state trajectory of variable {k}, but should be identical." From 39aa58b4defacb02ae337fce5da107f966390489 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Mon, 16 Oct 2023 15:26:05 -0400 Subject: [PATCH 2/2] Fix merge --- chirho/dynamical/handlers/interruption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chirho/dynamical/handlers/interruption.py b/chirho/dynamical/handlers/interruption.py index 3cd2f00e..d67e3f92 100644 --- a/chirho/dynamical/handlers/interruption.py +++ b/chirho/dynamical/handlers/interruption.py @@ -6,7 +6,7 @@ import torch from chirho.dynamical.handlers.trajectory import LogTrajectory -from chirho.dynamical.ops import Observable, State +from chirho.dynamical.ops import Observable, State, get_keys from chirho.indexed.ops import get_index_plates, indices_of from chirho.interventional.ops import Intervention, intervene from chirho.observational.handlers import condition