diff --git a/chirho/dynamical/handlers/trace.py b/chirho/dynamical/handlers/trace.py index ece0c9c2..a31f7bb0 100644 --- a/chirho/dynamical/handlers/trace.py +++ b/chirho/dynamical/handlers/trace.py @@ -1,3 +1,4 @@ +import functools from typing import Generic, TypeVar import pyro @@ -9,7 +10,43 @@ T = TypeVar("T") +@functools.singledispatch +def append(fst, rest: T) -> T: + raise NotImplementedError(f"append not implemented for type {type(fst)}.") + + +@append.register(Trajectory) +def append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]: + if len(traj1.keys) == 0: + return traj2 + + if len(traj2.keys) == 0: + return traj1 + + if traj1.keys != traj2.keys: + raise ValueError( + f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}." + ) + + result: Trajectory[T] = Trajectory() + for k in traj1.keys: + setattr(result, k, append(getattr(traj1, k), getattr(traj2, k))) + + return result + + +@append.register(torch.Tensor) +def append_tensor(prev_v: torch.Tensor, curr_v: torch.Tensor) -> torch.Tensor: + time_dim = -1 # TODO generalize to nontrivial event_shape + batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1]) + prev_v = prev_v.expand(*batch_shape, *prev_v.shape[-1:]) + curr_v = curr_v.expand(*batch_shape, *curr_v.shape[-1:]) + return torch.cat([prev_v, curr_v], dim=time_dim) + + class DynamicTrace(Generic[T], pyro.poutine.messenger.Messenger): + trace: Trajectory[T] + def __init__(self, logging_times: torch.Tensor, epsilon: float = 1e-6): # Adding epsilon to the logging times to avoid collision issues with the logging times being exactly on the # boundaries of the simulation times. This is a hack, but it's a hack that should work for now. @@ -23,8 +60,8 @@ def __init__(self, logging_times: torch.Tensor, epsilon: float = 1e-6): super().__init__() - def _reset(self): - self.trace = Trajectory() + def _reset(self) -> None: + self.trace: Trajectory[T] = Trajectory() def _pyro_simulate(self, msg) -> None: msg["done"] = True @@ -51,7 +88,7 @@ def _pyro_post_simulate(self, msg) -> None: initial_state, timespan, ) - self.trace.append(trajectory[..., 1:-1]) + self.trace: Trajectory[T] = append(self.trace, trajectory[..., 1:-1]) if len(self.trace) > len(self.logging_times): raise ValueError( "Multiple simulates were used with a single DynamicTrace handler." diff --git a/chirho/dynamical/ops/dynamical.py b/chirho/dynamical/ops/dynamical.py index 1e23b562..ea98c0ab 100644 --- a/chirho/dynamical/ops/dynamical.py +++ b/chirho/dynamical/ops/dynamical.py @@ -120,10 +120,6 @@ def _getitem_torchmask(self, key: torch.Tensor) -> "Trajectory[T]": return self._getitem(key) - @functools.singledispatchmethod - def append(self, other: T): - raise NotImplementedError(f"append not implemented for type {type(other)}") - def to_state(self) -> State[T]: ret: State[T] = State( # TODO support event_dim > 0 @@ -132,34 +128,6 @@ def to_state(self) -> State[T]: return ret -# TODO: figure out parameteric types of Trajectory. -# This used torch methods in supposedly generic class. -@Trajectory.append.register(Trajectory) # type: ignore -def _append_trajectory(self, other: Trajectory): - # If self is empty, just copy other. - if len(self.keys) == 0: - for k in other.keys: - setattr(self, k, getattr(other, k)) - return - - if self.keys != other.keys: - raise ValueError( - f"Trajectories must have the same keys to be appended, but got {self.keys} and {other.keys}." - ) - for k in self.keys: - prev_v = getattr(self, k) - curr_v = getattr(other, k) - time_dim = -1 # TODO generalize to nontrivial event_shape - batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1]) - prev_v = prev_v.expand(*batch_shape, *prev_v.shape[-1:]) - curr_v = curr_v.expand(*batch_shape, *curr_v.shape[-1:]) - setattr( - self, - k, - torch.cat([prev_v, curr_v], dim=time_dim), - ) - - @runtime_checkable class Dynamics(Protocol[S, T]): diff: Callable[[State[S], State[S]], T] diff --git a/tests/dynamical/test_state_trajectory.py b/tests/dynamical/test_state_trajectory.py index ed8f2293..3065a3a4 100644 --- a/tests/dynamical/test_state_trajectory.py +++ b/tests/dynamical/test_state_trajectory.py @@ -2,6 +2,7 @@ import torch +from chirho.dynamical.handlers.trace import append from chirho.dynamical.ops import Trajectory logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def test_trajectory_methods(): def test_append(): trajectory1 = Trajectory(S=torch.tensor([1.0, 2.0, 3.0])) trajectory2 = Trajectory(S=torch.tensor([4.0, 5.0, 6.0])) - trajectory1.append(trajectory2) + trajectory = append(trajectory1, trajectory2) assert torch.allclose( - trajectory1.S, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - ), "Trajectory.append() failed to append a trajectory" + trajectory.S, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + ), "append() failed to append a trajectory"