Skip to content

Commit

Permalink
Add error handling and warning when interruption is outside of `simul…
Browse files Browse the repository at this point in the history
…ate` start and end time (#359)

* revise tests so that interventions occur before simulate start_time to demonstrate regression in error handling

* add value error and warning to StaticInterruption
  • Loading branch information
SamWitty authored Oct 25, 2023
1 parent a6fb4e2 commit 7f35c24
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
21 changes: 15 additions & 6 deletions chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ def __init__(self, time: R):
self.time = torch.as_tensor(time) # TODO enforce this where it is needed
super().__init__()

def _pyro_simulate(self, msg) -> None:
_, _, start_time, end_time = msg["args"]

if self.time < start_time:
raise ValueError(
f"{StaticInterruption.__name__} time {self.time} occurred before the start of the "
f"timespan {start_time}. This interruption will have no effect."
)
elif self.time >= end_time:
warnings.warn(
f"{StaticInterruption.__name__} time {self.time} occurred after the end of the timespan "
f"{end_time}. This interruption will have no effect.",
UserWarning,
)

def _pyro_get_next_interruptions(self, msg) -> None:
_, _, _, start_time, end_time = msg["args"]

Expand All @@ -47,12 +62,6 @@ def _pyro_get_next_interruptions(self, msg) -> None:
or self.time < next_static_interruption.time
):
msg["kwargs"]["next_static_interruption"] = self
elif self.time >= end_time:
warnings.warn(
f"{StaticInterruption.__name__} time {self.time} occurred after the end of the timespan "
f"{end_time}. This interruption will have no effect.",
UserWarning,
)


class DynamicInterruption(Generic[T], Interruption):
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/test_static_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# Points at which to measure the state of the system.
start_time = torch.tensor(0.0)
end_time = torch.tensor(10.0)
logging_times = torch.linspace(start_time + 1, end_time - 2, 5)
logging_times = torch.linspace(start_time + 0.01, end_time - 2, 5)

# Initial state of the system.
init_state_values = State(
Expand Down

0 comments on commit 7f35c24

Please sign in to comment.