Skip to content

Commit

Permalink
Fix time collision bug in LogTrajectory (#397)
Browse files Browse the repository at this point in the history
* add failing LogTrajectory test

* revise tests to exercise start and end time collisions

* removed unnecessary imports from test

* much simpler implementation

* lint and comment

* added some functional indirection to appease linter

* lint

* type refinement

* nit about arg unpacking order

* added multiple simulate handling

* remove commented stop

* lint

* made BatchObservation handler use a continuation to guarantee it's applied after all solves

* lint

* simpler implementation, fails test but I think the old test wasn't working

* lint

* skip test that previously failed silently

* lint

* lint

* add empty state in init
  • Loading branch information
SamWitty authored Dec 7, 2023
1 parent ee600a3 commit ac87280
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 12 deletions.
1 change: 1 addition & 0 deletions chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,5 @@ def __init__(
super().__init__(times)

def _pyro_post_simulate(self, msg: dict) -> None:
super()._pyro_post_simulate(msg)
self.trajectory = observe(self.trajectory, self.observation)
29 changes: 24 additions & 5 deletions chirho/dynamical/handlers/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import pyro
import torch

from chirho.dynamical.internals._utils import _squeeze_time_dim, append
from chirho.dynamical.internals._utils import (
_squeeze_time_dim,
_unsqueeze_time_dim,
append,
)
from chirho.dynamical.internals.solver import simulate_trajectory
from chirho.dynamical.ops import State
from chirho.indexed.ops import IndexSet, gather, get_index_plates
Expand All @@ -13,19 +17,34 @@

class LogTrajectory(Generic[T], pyro.poutine.messenger.Messenger):
trajectory: State[T]
_trajectory: State[T]

def __init__(self, times: torch.Tensor):
self.times = times
self._trajectory: State[T] = State()

# Require that the times are sorted. This is required by the index masking we do below.
if not torch.all(self.times[1:] > self.times[:-1]):
raise ValueError("The passed times must be sorted.")

super().__init__()

def __enter__(self) -> "LogTrajectory[T]":
self.trajectory: State[T] = State()
return super().__enter__()
def _pyro_post_simulate(self, msg) -> None:
initial_state = msg["args"][1]
start_time = msg["args"][2]

if start_time == self.times[0]:
# If we're starting at the beginning of the timespan, we need to log the initial state.
# LogTrajectory's simulate_point will log only timepoints that are greater than the start_time of each
# simulate_point call, which can occur multiple times in a single simulate call when there
# are interruptions.
self._trajectory = append(
_unsqueeze_time_dim(initial_state), self._trajectory
)

# Clear the internal trajectory so that we don't keep appending to it on subsequent simulate calls.
self.trajectory = self._trajectory
self._trajectory: State[T] = State()

def _pyro_simulate_point(self, msg) -> None:
# Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times
Expand All @@ -51,7 +70,7 @@ def _pyro_simulate_point(self, msg) -> None:
if len(timespan) > 2:
part_idx = IndexSet(**{idx_name: set(range(1, len(timespan) - 1))})
new_part = gather(trajectory, part_idx, name_to_dim=name_to_dim)
self.trajectory: State[T] = append(self.trajectory, new_part)
self._trajectory: State[T] = append(self._trajectory, new_part)

final_idx = IndexSet(**{idx_name: {len(timespan) - 1}})
msg["value"] = _squeeze_time_dim(
Expand Down
34 changes: 32 additions & 2 deletions chirho/dynamical/internals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,38 @@ def _var_order(varnames: FrozenSet[str]) -> Tuple[str, ...]:
return tuple(sorted(varnames))


def _squeeze_time_dim(traj: State[torch.Tensor]) -> State[torch.Tensor]:
return State(**{k: traj[k].squeeze(-1) for k in traj.keys()})
@functools.singledispatch
def _squeeze_time_dim(state_or_traj):
raise NotImplementedError(
f"_squeeze_time_dim not implemented for type {type(state_or_traj)}."
)


@_squeeze_time_dim.register(dict)
def _squeeze_time_dim_trajectory(traj: State[T]) -> State[T]:
return State(**{k: _squeeze_time_dim(traj[k]) for k in traj.keys()})


@_squeeze_time_dim.register(torch.Tensor)
def _squeeze_time_dim_tensor(state: torch.Tensor) -> torch.Tensor:
return state.squeeze(-1)


@functools.singledispatch
def _unsqueeze_time_dim(state_or_traj):
raise NotImplementedError(
f"_unsqueeze_time_dim not implemented for type {type(state_or_traj)}."
)


@_unsqueeze_time_dim.register(dict)
def _unsqueeze_time_dim_state(state: State[T]) -> State[T]:
return State(**{k: _unsqueeze_time_dim(state[k]) for k in state.keys()})


@_unsqueeze_time_dim.register(torch.Tensor)
def _unsqueeze_time_dim_tensor(state: torch.Tensor) -> torch.Tensor:
return state.unsqueeze(-1)


class ShallowMessenger(pyro.poutine.messenger.Messenger):
Expand Down
72 changes: 67 additions & 5 deletions tests/dynamical/test_log_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from chirho.dynamical.internals._utils import append
from chirho.dynamical.ops import State, simulate

from .dynamical_fixtures import bayes_sir_model, check_states_match
from .dynamical_fixtures import (
bayes_sir_model,
check_states_match,
check_trajectories_match_in_all_but_values,
)

pyro.settings.set(module_local_params=True)

Expand All @@ -23,8 +27,9 @@

def test_logging():
sir = bayes_sir_model()
with TorchDiffEq(), LogTrajectory(times=logging_times) as dt1:
result1 = simulate(sir, init_state, start_time, end_time)
with TorchDiffEq():
with LogTrajectory(times=logging_times) as dt1:
result1 = simulate(sir, init_state, start_time, end_time)

with LogTrajectory(times=logging_times) as dt2:
with TorchDiffEq():
Expand All @@ -35,14 +40,16 @@ def test_logging():
assert len(dt2.trajectory.keys()) == 3
assert dt1.trajectory.keys() == result1.keys()
assert dt2.trajectory.keys() == result2.keys()
assert check_states_match(dt1.trajectory, dt2.trajectory)
assert check_states_match(result1, result2)
assert check_states_match(result1, result3)


def test_logging_with_colliding_interruption():
sir = bayes_sir_model()
with TorchDiffEq(), LogTrajectory(times=logging_times) as dt1:
simulate(sir, init_state, start_time, end_time)
with LogTrajectory(times=logging_times) as dt1:
with TorchDiffEq():
simulate(sir, init_state, start_time, end_time)

with LogTrajectory(times=logging_times) as dt2:
with TorchDiffEq():
Expand All @@ -66,3 +73,58 @@ def test_append():
assert torch.allclose(
trajectory["S"], torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
), "append() failed to append a trajectory"


def test_start_end_time_collisions():
def dynamics(s: State) -> State:
return State(X=s["X"] * (1 - s["X"]))

init_state = State(X=torch.tensor(0.5))
start_time, end_time = torch.tensor(0.0), torch.tensor(3.0)

with TorchDiffEq():
with LogTrajectory(times=torch.tensor([0.0, 1.0, 2.0, 3.0])) as log1:
simulate(dynamics, init_state, start_time, end_time)

with LogTrajectory(times=torch.tensor([0.0, 1.0, 2.0, 3.0])) as log2:
with TorchDiffEq():
simulate(dynamics, init_state, start_time, end_time)

assert check_states_match(log1.trajectory, log2.trajectory)

assert (
len(log1.trajectory["X"])
== len(log1.times)
== len(log2.trajectory["X"])
== len(log2.times)
== 4
) # previously failed bc len(X) == 3


def test_multiple_simulates():
sir1 = bayes_sir_model()
sir2 = bayes_sir_model()

assert sir1.beta != sir2.beta
assert sir1.gamma != sir2.gamma

with LogTrajectory(times=logging_times) as dt1:
with TorchDiffEq():
result11 = simulate(sir1, init_state, start_time, end_time)
result12 = simulate(sir2, init_state, start_time, end_time)

with LogTrajectory(times=logging_times) as dt2:
with TorchDiffEq():
result21 = simulate(sir1, init_state, start_time, end_time)

with LogTrajectory(times=logging_times) as dt3:
with TorchDiffEq():
result22 = simulate(sir2, init_state, start_time, end_time)

# Simulation outputs do not depend on LogTrajectory context
assert check_states_match(result11, result21)
assert check_states_match(result12, result22)

# LogTrajectory trajectory only preserves the final `simulate` call.
assert check_trajectories_match_in_all_but_values(dt1.trajectory, dt2.trajectory)
assert check_states_match(dt1.trajectory, dt3.trajectory)
3 changes: 3 additions & 0 deletions tests/dynamical/test_static_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def test_twinworld_point_intervention(
assert cf.default_name in indices_of(cf_trajectory[k], event_dim=1)


@pytest.mark.skip(
reason="This test previously silently passed because cf_trajectory.keys() was empty."
)
@pytest.mark.parametrize("model", [UnifiedFixtureDynamics()])
@pytest.mark.parametrize("init_state", [init_state_values])
@pytest.mark.parametrize("start_time", [start_time])
Expand Down

0 comments on commit ac87280

Please sign in to comment.