Clean up Trajectory.append #303

Merged: 5 commits, Oct 7, 2023
Changes from all commits
chirho/dynamical/handlers/trace.py (43 changes: 40 additions & 3 deletions)
@@ -1,3 +1,4 @@
+import functools
 from typing import Generic, TypeVar
 
 import pyro
@@ -9,7 +10,43 @@
 T = TypeVar("T")
 
 
+@functools.singledispatch
+def append(fst, rest: T) -> T:
+    raise NotImplementedError(f"append not implemented for type {type(fst)}.")

Review comment (Collaborator): Should this be append(fst: T, rest: T) -> T:?

Reply (Contributor Author): No, that doesn't work with singledispatch, which expects each registered type to be a subtype of the type annotated on the generic implementation's first argument. It's an annoying aspect of the way singledispatch interacts with mypy.
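A standalone sketch of the constraint being discussed (the concat function and its list overload are hypothetical, not part of this PR): functools.singledispatch dispatches on the runtime type of the first argument, and mypy requires each registered overload's first parameter to be compatible with whatever the generic implementation annotates there, which is why fst is left unannotated.

    import functools
    from typing import TypeVar

    T = TypeVar("T")

    @functools.singledispatch
    def concat(fst, rest: T) -> T:
        # Fallback when no overload is registered for type(fst).
        raise NotImplementedError(f"concat not implemented for {type(fst)}.")

    @concat.register(list)
    def _concat_list(fst: list, rest: list) -> list:
        # singledispatch selects this overload because type(fst) is list.
        # Annotating the generic fst as T would make mypy demand that list
        # be a subtype of T, hence the unannotated fst above.
        return fst + rest

    concat([1, 2], [3])    # -> [1, 2, 3]
    # concat((1,), (2,))   # would raise NotImplementedError (no tuple overload)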


+@append.register(Trajectory)
+def append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]:
+    if len(traj1.keys) == 0:
+        return traj2
+
+    if len(traj2.keys) == 0:
+        return traj1
+
+    if traj1.keys != traj2.keys:
+        raise ValueError(
+            f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}."
+        )
+
+    result: Trajectory[T] = Trajectory()
+    for k in traj1.keys:
+        setattr(result, k, append(getattr(traj1, k), getattr(traj2, k)))
+
+    return result
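A small composition sketch (Trajectory construction mirrors the updated test below): the Trajectory overload recurses into append on each key, which then dispatches to the tensor overload.

    t1 = Trajectory(S=torch.tensor([1.0, 2.0]))
    t2 = Trajectory(S=torch.tensor([3.0]))
    t3 = append(t1, t2)  # Trajectory overload, then tensor overload per key
    # t3.S == tensor([1., 2., 3.]); neither t1 nor t2 is mutated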


+@append.register(torch.Tensor)
+def append_tensor(prev_v: torch.Tensor, curr_v: torch.Tensor) -> torch.Tensor:
+    time_dim = -1  # TODO generalize to nontrivial event_shape
+    batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1])
+    prev_v = prev_v.expand(*batch_shape, *prev_v.shape[-1:])
+    curr_v = curr_v.expand(*batch_shape, *curr_v.shape[-1:])
+    return torch.cat([prev_v, curr_v], dim=time_dim)
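A shape-only sketch of the broadcasting above (shapes are illustrative): batch dimensions are broadcast first, then the two pieces are concatenated along the trailing time dimension.

    prev_v = torch.randn(3, 4)     # batch (3,), 4 time points
    curr_v = torch.randn(2, 3, 5)  # batch (2, 3), 5 time points
    batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1])  # (2, 3)
    out = torch.cat(
        [prev_v.expand(*batch_shape, 4), curr_v.expand(*batch_shape, 5)], dim=-1
    )
    # out.shape == (2, 3, 9)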


 class DynamicTrace(Generic[T], pyro.poutine.messenger.Messenger):
     trace: Trajectory[T]
 
     def __init__(self, logging_times: torch.Tensor, epsilon: float = 1e-6):
+        # Add epsilon to the logging times to avoid collisions when a logging time falls exactly on a
+        # boundary of the simulation times. This is a hack, but one that should work for now.
@@ -23,8 +60,8 @@ def __init__(self, logging_times: torch.Tensor, epsilon: float = 1e-6):

         super().__init__()
 
-    def _reset(self):
-        self.trace = Trajectory()
+    def _reset(self) -> None:
+        self.trace: Trajectory[T] = Trajectory()

     def _pyro_simulate(self, msg) -> None:
         msg["done"] = True
@@ -51,7 +88,7 @@ def _pyro_post_simulate(self, msg) -> None:
             initial_state,
             timespan,
         )
-        self.trace.append(trajectory[..., 1:-1])
+        self.trace: Trajectory[T] = append(self.trace, trajectory[..., 1:-1])
         if len(self.trace) > len(self.logging_times):
             raise ValueError(
                 "Multiple simulates were used with a single DynamicTrace handler."
chirho/dynamical/ops/dynamical.py (32 changes: 0 additions & 32 deletions)
@@ -120,10 +120,6 @@ def _getitem_torchmask(self, key: torch.Tensor) -> "Trajectory[T]":

         return self._getitem(key)
 
-    @functools.singledispatchmethod
-    def append(self, other: T):
-        raise NotImplementedError(f"append not implemented for type {type(other)}")
 
     def to_state(self) -> State[T]:
         ret: State[T] = State(
             # TODO support event_dim > 0
@@ -132,34 +128,6 @@ def to_state(self) -> State[T]:
         return ret


-# TODO: figure out parametric types of Trajectory.
-# This used torch methods in a supposedly generic class.
-@Trajectory.append.register(Trajectory)  # type: ignore
-def _append_trajectory(self, other: Trajectory):
-    # If self is empty, just copy other.
-    if len(self.keys) == 0:
-        for k in other.keys:
-            setattr(self, k, getattr(other, k))
-        return
-
-    if self.keys != other.keys:
-        raise ValueError(
-            f"Trajectories must have the same keys to be appended, but got {self.keys} and {other.keys}."
-        )
-    for k in self.keys:
-        prev_v = getattr(self, k)
-        curr_v = getattr(other, k)
-        time_dim = -1  # TODO generalize to nontrivial event_shape
-        batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1])
-        prev_v = prev_v.expand(*batch_shape, *prev_v.shape[-1:])
-        curr_v = curr_v.expand(*batch_shape, *curr_v.shape[-1:])
-        setattr(
-            self,
-            k,
-            torch.cat([prev_v, curr_v], dim=time_dim),
-        )


 @runtime_checkable
 class Dynamics(Protocol[S, T]):
     diff: Callable[[State[S], State[S]], T]
tests/dynamical/test_state_trajectory.py (7 changes: 4 additions & 3 deletions)
@@ -2,6 +2,7 @@

 import torch
 
+from chirho.dynamical.handlers.trace import append
 from chirho.dynamical.ops import Trajectory
 
 logger = logging.getLogger(__name__)
@@ -16,7 +17,7 @@ def test_trajectory_methods():
 def test_append():
     trajectory1 = Trajectory(S=torch.tensor([1.0, 2.0, 3.0]))
     trajectory2 = Trajectory(S=torch.tensor([4.0, 5.0, 6.0]))
-    trajectory1.append(trajectory2)
+    trajectory = append(trajectory1, trajectory2)
     assert torch.allclose(
-        trajectory1.S, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-    ), "Trajectory.append() failed to append a trajectory"
+        trajectory.S, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    ), "append() failed to append a trajectory"