Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Trajectory type #327

Merged
merged 16 commits into from
Oct 16, 2023
Merged
8 changes: 4 additions & 4 deletions chirho/dynamical/handlers/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from chirho.dynamical.internals._utils import _trajectory_to_state, append
from chirho.dynamical.internals.solver import Solver, get_solver, simulate_trajectory
from chirho.dynamical.ops import Trajectory
from chirho.dynamical.ops import State
from chirho.indexed.ops import IndexSet, gather, get_index_plates

T = TypeVar("T")


class LogTrajectory(Generic[T], pyro.poutine.messenger.Messenger):
trajectory: Trajectory[T]
trajectory: State[T]

def __init__(self, times: torch.Tensor, *, eps: float = 1e-6):
# Adding epsilon to the logging times to avoid collision issues with the logging times being exactly on the
Expand All @@ -27,7 +27,7 @@ def __init__(self, times: torch.Tensor, *, eps: float = 1e-6):
super().__init__()

def __enter__(self) -> "LogTrajectory[T]":
self.trajectory: Trajectory[T] = Trajectory()
self.trajectory: State[T] = State()
return super().__enter__()

def _pyro_simulate(self, msg) -> None:
Expand Down Expand Up @@ -63,7 +63,7 @@ def _pyro_post_simulate(self, msg) -> None:
if len(timespan) > 2:
part_idx = IndexSet(**{idx_name: set(range(1, len(timespan) - 1))})
new_part = gather(trajectory, part_idx, name_to_dim=name_to_dim)
self.trajectory: Trajectory[T] = append(self.trajectory, new_part)
self.trajectory: State[T] = append(self.trajectory, new_part)

final_idx = IndexSet(**{idx_name: {len(timespan) - 1}})
final_state = gather(trajectory, final_idx, name_to_dim=name_to_dim)
Expand Down
10 changes: 5 additions & 5 deletions chirho/dynamical/internals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from chirho.dynamical.ops import State, Trajectory
from chirho.dynamical.ops import State
from chirho.indexed.ops import IndexSet, gather, indices_of, union
from chirho.interventional.handlers import intervene

Expand Down Expand Up @@ -48,8 +48,8 @@ def append(fst, rest: T) -> T:
raise NotImplementedError(f"append not implemented for type {type(fst)}.")


@append.register(Trajectory)
def _append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]:
SamWitty marked this conversation as resolved.
Show resolved Hide resolved
@append.register(State)
def _append_trajectory(traj1: State[T], traj2: State[T]) -> State[T]:
if len(traj1.keys) == 0:
return traj2

Expand All @@ -61,7 +61,7 @@ def _append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory
f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}."
)

result: Trajectory[T] = Trajectory()
result: State[T] = State()
for k in traj1.keys:
setattr(result, k, append(getattr(traj1, k), getattr(traj2, k)))

Expand All @@ -82,5 +82,5 @@ def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]:
return tuple(sorted(varnames))


def _trajectory_to_state(traj: Trajectory[T]) -> State[T]:
eb8680 marked this conversation as resolved.
Show resolved Hide resolved
def _trajectory_to_state(traj: State[T]) -> State[T]:
return State(**{k: getattr(traj, k).squeeze(-1) for k in traj.keys})
6 changes: 3 additions & 3 deletions chirho/dynamical/internals/backends/torchdiffeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
simulate_point,
simulate_trajectory,
)
from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory
from chirho.dynamical.ops import InPlaceDynamics, State
from chirho.indexed.ops import IndexSet, gather, get_index_plates

S = TypeVar("S")
Expand Down Expand Up @@ -57,7 +57,7 @@ def _torchdiffeq_ode_simulate_inner(
**odeint_kwargs,
)

trajectory: Trajectory[torch.Tensor] = Trajectory()
trajectory: State[torch.Tensor] = State()
for var, soln in zip(var_order, solns):
setattr(trajectory, var, soln)

Expand Down Expand Up @@ -136,7 +136,7 @@ def torchdiffeq_ode_simulate_trajectory(
dynamics: InPlaceDynamics[torch.Tensor],
initial_state: State[torch.Tensor],
timespan: torch.Tensor,
) -> Trajectory[torch.Tensor]:
) -> State[torch.Tensor]:
return _torchdiffeq_ode_simulate_inner(
dynamics, initial_state, timespan, **solver.odeint_kwargs
)
Expand Down
4 changes: 2 additions & 2 deletions chirho/dynamical/internals/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pyro
import torch

from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory, simulate
from chirho.dynamical.ops import InPlaceDynamics, State, simulate

if typing.TYPE_CHECKING:
from chirho.dynamical.handlers.interruption import (
Expand Down Expand Up @@ -63,7 +63,7 @@ def simulate_trajectory(
initial_state: State[T],
timespan: R,
**kwargs,
) -> Trajectory[T]:
) -> State[T]:
"""
Simulate a dynamical system.
"""
Expand Down
7 changes: 1 addition & 6 deletions chirho/dynamical/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
R = Union[numbers.Real, torch.Tensor]
S = TypeVar("S")
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


class State(Generic[T]):
Expand Down Expand Up @@ -38,10 +37,6 @@ def __getattr__(self, __name: str) -> T:
raise AttributeError(f"{__name} not in {self.__dict__['_values']}")


class Trajectory(Generic[T], State[T]):
pass


@typing.runtime_checkable
class InPlaceDynamics(Protocol[S]):
def diff(self, __dstate: State[S], __state: State[S]) -> None:
Expand All @@ -53,7 +48,7 @@ class ObservableInPlaceDynamics(InPlaceDynamics[S], Protocol[S]):
def diff(self, __dstate: State[S], __state: State[S]) -> None:
...

def observation(self, __state: Union[State[S], Trajectory[S]]) -> None:
def observation(self, __state: State[S]) -> None:
...


Expand Down
20 changes: 8 additions & 12 deletions tests/dynamical/dynamical_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import TypeVar, Union
from typing import TypeVar

import pyro
import torch
from pyro.distributions import Normal, Uniform, constraints

from chirho.dynamical.ops import InPlaceDynamics, State, Trajectory
from chirho.dynamical.ops import InPlaceDynamics, State

T = TypeVar("T")

Expand All @@ -30,7 +30,7 @@ def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
dX.I = beta * X.S * X.I - self.gamma * X.I # noqa
dX.R = self.gamma * X.I

def _unit_measurement_error(self, name: str, x: torch.tensor):
def _unit_measurement_error(self, name: str, x: torch.Tensor):
if x.ndim == 0:
return pyro.sample(name, Normal(x, 1))
else:
Expand All @@ -51,15 +51,13 @@ def bayes_sir_model():
return sir


def check_keys_match(
obj1: Union[Trajectory[T], State[T]], obj2: Union[Trajectory[T], State[T]]
):
def check_keys_match(obj1: State[T], obj2: State[T]):
assert obj1.keys == obj2.keys, "Objects have different variables."
return True


def check_trajectory_length_match(
traj1: Trajectory[torch.tensor], traj2: Trajectory[torch.tensor]
traj1: State[torch.Tensor], traj2: State[torch.Tensor]
):
for k in traj1.keys:
assert len(getattr(traj2, k)) == len(
Expand All @@ -68,9 +66,7 @@ def check_trajectory_length_match(
return True


def check_trajectories_match(
traj1: Trajectory[torch.tensor], traj2: Trajectory[torch.tensor]
):
def check_trajectories_match(traj1: State[torch.Tensor], traj2: State[torch.Tensor]):
eb8680 marked this conversation as resolved.
Show resolved Hide resolved
assert check_keys_match(traj1, traj2)

assert check_trajectory_length_match(traj1, traj2)
Expand All @@ -83,7 +79,7 @@ def check_trajectories_match(
return True


def check_states_match(state1: State[torch.tensor], state2: State[torch.tensor]):
def check_states_match(state1: State[torch.Tensor], state2: State[torch.Tensor]):
assert check_keys_match(state1, state2)

for k in state1.keys:
Expand All @@ -95,7 +91,7 @@ def check_states_match(state1: State[torch.tensor], state2: State[torch.tensor])


def check_trajectories_match_in_all_but_values(
traj1: Trajectory[torch.tensor], traj2: Trajectory[torch.tensor]
traj1: State[torch.Tensor], traj2: State[torch.Tensor]
):
assert check_keys_match(traj1, traj2)

Expand Down
14 changes: 7 additions & 7 deletions tests/dynamical/test_log_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chirho.dynamical.handlers import InterruptionEventLoop, LogTrajectory
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.internals._utils import append
from chirho.dynamical.ops import State, Trajectory, simulate
from chirho.dynamical.ops import State, simulate

from .dynamical_fixtures import bayes_sir_model, check_states_match

Expand Down Expand Up @@ -38,8 +38,8 @@ def test_logging():
result3 = simulate(sir, init_state, start_time, end_time, solver=TorchDiffEq())

assert isinstance(result1, State)
assert isinstance(dt1.trajectory, Trajectory)
assert isinstance(dt2.trajectory, Trajectory)
assert isinstance(dt1.trajectory, State)
assert isinstance(dt2.trajectory, State)
assert len(dt1.trajectory.keys) == 3
assert len(dt2.trajectory.keys) == 3
assert dt1.trajectory.keys == result1.keys
Expand All @@ -49,14 +49,14 @@ def test_logging():


def test_trajectory_methods():
trajectory = Trajectory(S=torch.tensor([1.0, 2.0, 3.0]))
trajectory = State(S=torch.tensor([1.0, 2.0, 3.0]))
assert trajectory.keys == frozenset({"S"})
assert str(trajectory) == "Trajectory({'S': tensor([1., 2., 3.])})"
assert str(trajectory) == "State({'S': tensor([1., 2., 3.])})"


def test_append():
trajectory1 = Trajectory(S=torch.tensor([1.0, 2.0, 3.0]))
trajectory2 = Trajectory(S=torch.tensor([4.0, 5.0, 6.0]))
trajectory1 = State(S=torch.tensor([1.0, 2.0, 3.0]))
trajectory2 = State(S=torch.tensor([4.0, 5.0, 6.0]))
trajectory = append(trajectory1, trajectory2)
assert torch.allclose(
trajectory.S, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
Expand Down