diff --git a/chirho/dynamical/handlers/interruption.py b/chirho/dynamical/handlers/interruption.py index ee084125..d67e3f92 100644 --- a/chirho/dynamical/handlers/interruption.py +++ b/chirho/dynamical/handlers/interruption.py @@ -6,7 +6,7 @@ import torch from chirho.dynamical.handlers.trajectory import LogTrajectory -from chirho.dynamical.ops import Observable, State +from chirho.dynamical.ops import Observable, State, get_keys from chirho.indexed.ops import get_index_plates, indices_of from chirho.interventional.ops import Intervention, intervene from chirho.observational.handlers import condition @@ -172,7 +172,7 @@ def _pyro_post_simulate(self, msg) -> None: name_to_dim["__time"] = -1 len_traj = ( 0 - if not self.trajectory.keys + if not get_keys(self.trajectory) else 1 + max(indices_of(self.trajectory, name_to_dim=name_to_dim)["__time"]) ) diff --git a/chirho/dynamical/internals/_utils.py b/chirho/dynamical/internals/_utils.py index 285ad8e9..fdcfaee7 100644 --- a/chirho/dynamical/internals/_utils.py +++ b/chirho/dynamical/internals/_utils.py @@ -3,7 +3,7 @@ import torch -from chirho.dynamical.ops import State +from chirho.dynamical.ops import State, get_keys from chirho.indexed.ops import IndexSet, gather, indices_of, union from chirho.interventional.handlers import intervene @@ -16,7 +16,7 @@ def _indices_of_state(state: State, *, event_dim: int = 0, **kwargs) -> IndexSet return union( *( indices_of(getattr(state, k), event_dim=event_dim, **kwargs) - for k in state.keys + for k in get_keys(state) ) ) @@ -28,7 +28,7 @@ def _gather_state( return type(state)( **{ k: gather(getattr(state, k), indices, event_dim=event_dim, **kwargs) - for k in state.keys + for k in get_keys(state) } ) @@ -36,7 +36,7 @@ def _gather_state( @intervene.register(State) def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]: new_state: State[T] = State() - for k in obs.keys: + for k in get_keys(obs): setattr( new_state, k, intervene(getattr(obs, k), getattr(act, k, None), **kwargs) ) @@ -50,19 +50,19 @@ def append(fst, rest: T) -> T: @append.register(State) def _append_trajectory(traj1: State[T], traj2: State[T]) -> State[T]: - if len(traj1.keys) == 0: + if len(get_keys(traj1)) == 0: return traj2 - if len(traj2.keys) == 0: + if len(get_keys(traj2)) == 0: return traj1 - if traj1.keys != traj2.keys: + if get_keys(traj1) != get_keys(traj2): raise ValueError( - f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}." + f"Trajectories must have the same keys to be appended, but got {get_keys(traj1)} and {get_keys(traj2)}." ) result: State[T] = State() - for k in traj1.keys: + for k in get_keys(traj1): setattr(result, k, append(getattr(traj1, k), getattr(traj2, k))) return result @@ -83,4 +83,4 @@ def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]: def _squeeze_time_dim(traj: State[T]) -> State[T]: - return State(**{k: getattr(traj, k).squeeze(-1) for k in traj.keys}) + return State(**{k: getattr(traj, k).squeeze(-1) for k in get_keys(traj)}) diff --git a/chirho/dynamical/internals/backends/torchdiffeq.py b/chirho/dynamical/internals/backends/torchdiffeq.py index d3ba12ce..796333f3 100644 --- a/chirho/dynamical/internals/backends/torchdiffeq.py +++ b/chirho/dynamical/internals/backends/torchdiffeq.py @@ -16,7 +16,7 @@ simulate_point, simulate_trajectory, ) -from chirho.dynamical.ops import InPlaceDynamics, State +from chirho.dynamical.ops import InPlaceDynamics, State, get_keys from chirho.indexed.ops import IndexSet, gather, get_index_plates S = TypeVar("S") @@ -35,7 +35,7 @@ def _deriv( for var, value in zip(var_order, state): setattr(env, var, value) - assert "t" not in env.keys, "variable name t is reserved for time" + assert "t" not in get_keys(env), "variable name t is reserved for time" env.t = time dynamics.diff(ddt, env) @@ -48,7 +48,7 @@ def _torchdiffeq_ode_simulate_inner( timespan, **odeint_kwargs, ): - var_order = _var_order(initial_state.keys) # arbitrary, but fixed + var_order = _var_order(get_keys(initial_state)) # arbitrary, but fixed solns = _batched_odeint( # torchdiffeq.odeint( functools.partial(_deriv, dynamics, var_order), @@ -152,7 +152,7 @@ def torchdiffeq_get_next_interruptions_dynamic( dynamic_interruptions: List[DynamicInterruption], **kwargs, ) -> Tuple[Tuple[Interruption, ...], torch.Tensor]: - var_order = _var_order(start_state.keys) # arbitrary, but fixed + var_order = _var_order(get_keys(start_state)) # arbitrary, but fixed # Create the event function combining all dynamic events and the terminal (next) static interruption. combined_event_f = torchdiffeq_combined_event_f( diff --git a/chirho/dynamical/ops.py b/chirho/dynamical/ops.py index 906a9b3a..3022b671 100644 --- a/chirho/dynamical/ops.py +++ b/chirho/dynamical/ops.py @@ -12,15 +12,10 @@ class State(Generic[T]): def __init__(self, **values: T): - # self.class_name = self.__dict__["_values"] = {} for k, v in values.items(): setattr(self, k, v) - @property - def keys(self) -> FrozenSet[str]: - return frozenset(self.__dict__["_values"].keys()) - def __repr__(self) -> str: return f"{self.__class__.__name__}({self.__dict__['_values']})" @@ -37,6 +32,10 @@ def __getattr__(self, __name: str) -> T: raise AttributeError(f"{__name} not in {self.__dict__['_values']}") +def get_keys(state: State[T]) -> FrozenSet[str]: + return frozenset(state.__dict__["_values"].keys()) + + @typing.runtime_checkable class Observable(Protocol[S]): def observation(self, __state: State[S]) -> None: diff --git a/docs/source/dynamical_intro.ipynb b/docs/source/dynamical_intro.ipynb index c6f0ff63..220a878a 100644 --- a/docs/source/dynamical_intro.ipynb +++ b/docs/source/dynamical_intro.ipynb @@ -37,7 +37,7 @@ " NonInterruptingPointObservationArray\n", ")\n", "from chirho.dynamical.handlers.interruption import _InterventionMixin\n", - "from chirho.dynamical.ops.dynamical import State, Trajectory, simulate\n", + "from chirho.dynamical.ops.dynamical import State, Trajectory, get_keys, simulate\n", "\n", "from chirho.dynamical.ops.ODE import ODEDynamics\n", "from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq\n", @@ -413,7 +413,7 @@ " simulate(sir, init_state, logging_times[0], logging_times[-1] + 1e-3, solver=TorchDiffEq())\n", " trajectory = dt.trace\n", " # This is a small trick to make the solution variables available to pyro\n", - " [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n", + " [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n", " return trajectory\n", "\n", "\n", @@ -891,7 +891,7 @@ " \n", " trajectory = dt.trace\n", " # This is a small trick to make the solution variables available to pyro\n", - " [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n", + " [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n", " return trajectory" ] }, @@ -1251,7 +1251,7 @@ " trajectory = dt.trace\n", " \n", " # This is a small trick to make the solution variables available to pyro\n", - " [pyro.deterministic(k, getattr(trajectory, k)) for k in trajectory.keys]\n", + " [pyro.deterministic(k, getattr(trajectory, k)) for k in get_keys(trajectory)]\n", " return trajectory" ] }, @@ -1740,7 +1740,7 @@ " cf_traj = gather(trajectory, cf_indices, event_dim=0)\n", " \n", " # This is a small trick to make the trajectory variables available to pyro \n", - " for k in trajectory.keys:\n", + " for k in get_keys(trajectory):\n", " pyro.deterministic(k + '_factual', getattr(factual_traj, k))\n", " pyro.deterministic(k + '_cf', getattr(cf_traj, k))" ] @@ -2178,7 +2178,7 @@ " trajectory = dt.trace\n", " solutions.append(trajectory)\n", " # This is a small trick to make the trajectory variables available to pyro\n", - " [pyro.deterministic(f\"{k}_{unit_ix}\", getattr(trajectory, k))for k in trajectory.keys]\n", + " [pyro.deterministic(f\"{k}_{unit_ix}\", getattr(trajectory, k))for k in get_keys(trajectory)]\n", " return solutions\n", "\n", "\n", diff --git a/tests/dynamical/dynamical_fixtures.py b/tests/dynamical/dynamical_fixtures.py index 7ae9ff1f..e53ae4eb 100644 --- a/tests/dynamical/dynamical_fixtures.py +++ b/tests/dynamical/dynamical_fixtures.py @@ -4,7 +4,7 @@ import torch from pyro.distributions import Normal, Uniform, constraints -from chirho.dynamical.ops import State +from chirho.dynamical.ops import State, get_keys T = TypeVar("T") @@ -52,14 +52,14 @@ def bayes_sir_model(): def check_keys_match(obj1: State[T], obj2: State[T]): - assert obj1.keys == obj2.keys, "Objects have different variables." + assert get_keys(obj1) == get_keys(obj2), "Objects have different variables." return True def check_states_match(state1: State[torch.Tensor], state2: State[torch.Tensor]): assert check_keys_match(state1, state2) - for k in state1.keys: + for k in get_keys(state1): assert torch.allclose( getattr(state1, k), getattr(state2, k) ), f"Trajectories differ in state trajectory of variable {k}, but should be identical." @@ -72,7 +72,7 @@ def check_trajectories_match_in_all_but_values( ): assert check_keys_match(traj1, traj2) - for k in traj1.keys: + for k in get_keys(traj1): assert not torch.allclose( getattr(traj2, k), getattr(traj1, k) ), f"Trajectories are identical in state trajectory of variable {k}, but should differ." diff --git a/tests/dynamical/test_dynamic_interventions.py b/tests/dynamical/test_dynamic_interventions.py index d66e3804..15dc684f 100644 --- a/tests/dynamical/test_dynamic_interventions.py +++ b/tests/dynamical/test_dynamic_interventions.py @@ -14,7 +14,7 @@ LogTrajectory, ) from chirho.dynamical.handlers.solver import TorchDiffEq -from chirho.dynamical.ops import InPlaceDynamics, State, simulate +from chirho.dynamical.ops import InPlaceDynamics, State, get_keys, simulate from chirho.indexed.ops import IndexSet, gather, indices_of, union from .dynamical_fixtures import UnifiedFixtureDynamics @@ -269,7 +269,7 @@ def test_split_twinworld_dynamic_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -317,7 +317,7 @@ def test_split_multiworld_dynamic_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -388,15 +388,15 @@ def test_split_twinworld_dynamic_matches_output( assert not set(indices_of(cf_actual, event_dim=0)) assert not set(indices_of(factual_actual, event_dim=0)) - assert set(cf_result.keys) == set(cf_actual.keys) == set(cf_expected.keys) - assert set(cf_result.keys) == set(factual_actual.keys) == set(factual_expected.keys) + assert get_keys(cf_result) == get_keys(cf_actual) == get_keys(cf_expected) + assert get_keys(cf_result) == get_keys(factual_actual) == get_keys(factual_expected) - for k in cf_result.keys: + for k in get_keys(cf_result): assert torch.allclose( getattr(cf_actual, k), getattr(cf_expected, k), atol=1e-3, rtol=0 ), f"Trajectories differ in state result of variable {k}, but should be identical." - for k in cf_result.keys: + for k in get_keys(cf_result): assert torch.allclose( getattr(factual_actual, k), getattr(factual_expected, k), atol=1e-3, rtol=0 ), f"Trajectories differ in state result of variable {k}, but should be identical." diff --git a/tests/dynamical/test_log_trajectory.py b/tests/dynamical/test_log_trajectory.py index 6bc34d9a..83ea80f3 100644 --- a/tests/dynamical/test_log_trajectory.py +++ b/tests/dynamical/test_log_trajectory.py @@ -6,7 +6,7 @@ from chirho.dynamical.handlers import InterruptionEventLoop, LogTrajectory from chirho.dynamical.handlers.solver import TorchDiffEq from chirho.dynamical.internals._utils import append -from chirho.dynamical.ops import State, simulate +from chirho.dynamical.ops import State, get_keys, simulate from .dynamical_fixtures import bayes_sir_model, check_states_match @@ -40,17 +40,17 @@ def test_logging(): assert isinstance(result1, State) assert isinstance(dt1.trajectory, State) assert isinstance(dt2.trajectory, State) - assert len(dt1.trajectory.keys) == 3 - assert len(dt2.trajectory.keys) == 3 - assert dt1.trajectory.keys == result1.keys - assert dt2.trajectory.keys == result2.keys + assert len(get_keys(dt1.trajectory)) == 3 + assert len(get_keys(dt2.trajectory)) == 3 + assert get_keys(dt1.trajectory) == get_keys(result1) + assert get_keys(dt2.trajectory) == get_keys(result2) assert check_states_match(result1, result2) assert check_states_match(result1, result3) def test_trajectory_methods(): trajectory = State(S=torch.tensor([1.0, 2.0, 3.0])) - assert trajectory.keys == frozenset({"S"}) + assert get_keys(trajectory) == frozenset({"S"}) assert str(trajectory) == "State({'S': tensor([1., 2., 3.])})" diff --git a/tests/dynamical/test_static_interventions.py b/tests/dynamical/test_static_interventions.py index 920321a7..61cfb73c 100644 --- a/tests/dynamical/test_static_interventions.py +++ b/tests/dynamical/test_static_interventions.py @@ -13,7 +13,7 @@ StaticIntervention, ) from chirho.dynamical.handlers.solver import TorchDiffEq -from chirho.dynamical.ops import State, simulate +from chirho.dynamical.ops import State, get_keys, simulate from chirho.indexed.ops import IndexSet, gather, indices_of from chirho.interventional.ops import intervene @@ -241,7 +241,7 @@ def test_twinworld_point_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -276,7 +276,7 @@ def test_multiworld_point_intervention( with cf: cf_trajectory = dt.trajectory - for k in cf_trajectory.keys: + for k in get_keys(cf_trajectory): # TODO: Figure out why event_dim=1 is not needed with cf_state but is with cf_trajectory. assert cf.default_name in indices_of(getattr(cf_state, k)) assert cf.default_name in indices_of(getattr(cf_trajectory, k), event_dim=1) @@ -300,7 +300,7 @@ def test_split_odeint_broadcast( with cf: trajectory = dt.trajectory - for k in trajectory.keys: + for k in get_keys(trajectory): assert len(indices_of(getattr(trajectory, k), event_dim=1)) > 0 @@ -353,15 +353,15 @@ def test_twinworld_matches_output( assert not set(indices_of(cf_actual, event_dim=0)) assert not set(indices_of(factual_actual, event_dim=0)) - assert set(cf_state.keys) == set(cf_actual.keys) == set(cf_expected.keys) - assert set(cf_state.keys) == set(factual_actual.keys) == set(factual_expected.keys) + assert get_keys(cf_state) == get_keys(cf_actual) == get_keys(cf_expected) + assert get_keys(cf_state) == get_keys(factual_actual) == get_keys(factual_expected) - for k in cf_state.keys: + for k in get_keys(cf_state): assert torch.allclose( getattr(cf_actual, k), getattr(cf_expected, k) ), f"States differ in state trajectory of variable {k}, but should be identical." - for k in cf_state.keys: + for k in get_keys(cf_state): assert torch.allclose( getattr(factual_actual, k), getattr(factual_expected, k) ), f"States differ in state trajectory of variable {k}, but should be identical."