-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix time collision bug in LogTrajectory
#397
Conversation
…into sw-time-collision
…into sw-time-collision
@@ -27,6 +31,19 @@ def __enter__(self) -> "LogTrajectory[T]": | |||
self.trajectory: State[T] = State() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eb8680 , do you think it makes sense to move this line into _pyro_simulate
instead of __enter__
. This would change the behavior from "concatenate multiple simulate
calls trajectories together" to "store only the trajectory of the final simulate call". Or alternatively, we could give simulate
an optional name argument if you want to store multiple trajectories.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, re-initializing self.trajectory
in _pyro_simulate
seems more correct than the current behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test that checks the case where a StaticInterruption
's time
coincides with a logging time, and also rerun the dynamical_intro
notebook from scratch to verify that it no longer fails because of this bug?
I believe this existing test covers the test case? https://github.com/BasisResearch/chirho/blob/sw-time-collision/tests/dynamical/test_log_trajectory.py#L42 Happy to add more if that doesn't cover what you're looking for. I'll rerun the notebook in this PR. Would you like me to make the change we discussed in this comment in this PR or separately?
|
…into sw-time-collision
…causal_pyro into sw-time-collision
…plied after all solves
@eb8680 , I edited this PR to address my suggestion here: #397 (comment). It is now ready for review again. I have also strengthened the tests a bit. To ensure that handlers commute after this PR we now have the following strict order enforced regardless of the order handlers are applied: I reran the dynamical system notebook, confirming that it works with these changes (although, requiring a minor change to the name of a trace address in the final cell because of a previous PR). I'm excluding the revised run from this PR to avoid merge conflicts with #377, which I expect to be completed by EOD or early tomorrow. Now that #308 is merged, I'll add a follow up PR after #377 to add the dynamical systems notebook to CI. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making sure LogTrajectory
, Solver
and BatchObservation
commute is the right idea, but it seems like there's a simpler way to achieve that:
@@ -130,5 +130,17 @@ def __init__( | |||
self.observation = observation | |||
super().__init__(times) | |||
|
|||
def _pyro_post_simulate(self, msg: dict) -> None: | |||
self.trajectory = observe(self.trajectory, self.observation) | |||
def _pyro_simulate(self, msg: dict) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you should need to change StaticBatchObservation
at all, other than adding a single super()._pyro_post_simulate
call:
def _pyro_post_simulate(self, msg: dict) -> None:
super()._pyro_post_simulate(msg) # update self.trajectory
self.trajectory = observe(self.trajectory, self.observation)
# LogTrajectory's simulate_point will log only timepoints that are greater than the start_time of each | ||
# simulate_point call, which can occur multiple times in a single simulate call when there | ||
# are interruptions. | ||
self.trajectory = append( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this change is right, and is already enough to make LogTrajectory
and Solver
commutative. In particular, since append
is associative and Solver
doesn't change initial_state
or start_time
in its _pyro_simulate
method, as long as no outside code needs to access self.trajectory
before the simulate
call finishes, it shouldn't matter if LogTrajectory
's _pyro_simulate
runs before or after Solver
's.
chirho/dynamical/internals/solver.py
Outdated
|
||
@typing.final | ||
@staticmethod | ||
def _pyro_post_simulate(msg: dict) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This move from _pyro_simulate
to _pyro_post_simulate
seems unnecessary, per my comment on LogTrajectory
.
I agree. The key challenge here is actually in the assignment of I'll try and find a workaround to this problem that's a bit less verbose and touches fewer pieces. |
Hmm, what about adding an extra boolean state variable to def __init__(self, ...):
self._needs_reset: bool = False
self.trajectory = State()
...
def __enter__(self):
self._needs_reset = False
self.trajectory = State()
...
def _pyro_simulate(self, msg: dict) -> None:
if self._needs_reset:
self.trajectory = State()
self._needs_reset = False
...
def _pyro_post_simulate(self, msg: dict) -> None:
if not self._needs_reset:
self._needs_reset = True
... |
@eb8680 , I've made this PR a bit simpler, but in the process it exposed an error that was occurring with a test using The (now skipped) test fails because |
My solution (that I finished before seeing this) is in the spirit of this suggestion, but a bit less explicit and a bit more concise. I'm really happy either way. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for fixing this!
This PR addresses #396 by adding the initial state to the trajectory if the
start_time
ofsimulate
is equal to the first element of the (sorted)logging_times
argument of theLogTrajectory
handler. This works because theLogTrajectory
handler excludes time collisions on thestart_time
and includes time collisions on theend_time
of eachsimulate_point
call, which was necessary to not "double include" interruption times that collided with an element in thelogging_times
and thus show up as thestart_time
andend_time
of twosimulate_point
calls.In addition, this PR adds a slight modification of the test described in #396 to the test suite.