Fix time collision bug in LogTrajectory #397

Merged
merged 24 commits into from
Dec 7, 2023

Conversation

Collaborator

@SamWitty SamWitty commented Nov 28, 2023

This PR addresses #396 by adding the initial state to the trajectory when the start_time of simulate equals the first element of the (sorted) logging_times argument of the LogTrajectory handler. This works because the LogTrajectory handler excludes time collisions on the start_time and includes time collisions on the end_time of each simulate_point call. That convention was necessary to avoid "double including" interruption times that collide with an element of logging_times and therefore show up as both the end_time of one simulate_point call and the start_time of the next.

In addition, this PR adds a slight modification of the test described in #396 to the test suite.
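
As a rough illustration of the collision rule described above, here is a minimal sketch that uses plain floats instead of chirho State objects. The helper `logged_times` and its signature are hypothetical and are not part of chirho; the sketch only assumes that each simulate_point segment logs times in the half-open interval (start_time, end_time].

```python
from typing import List, Sequence, Tuple


def logged_times(
    logging_times: Sequence[float],
    segments: Sequence[Tuple[float, float]],
) -> List[float]:
    """Return the logging times recorded over consecutive (start, end) segments.

    Hypothetical helper: each segment logs times t with start < t <= end,
    so a time shared by two adjacent segments is recorded exactly once.
    """
    times = sorted(logging_times)
    recorded: List[float] = []

    # The overall start time is excluded by the half-open rule, so add it
    # explicitly when it collides with the first logging time.
    if times and segments and segments[0][0] == times[0]:
        recorded.append(times[0])

    for start, end in segments:
        recorded.extend(t for t in times if start < t <= end)
    return recorded


# An interruption at t=1.0 splits one simulate call into two segments.
result = logged_times([0.0, 1.0, 2.0], [(0.0, 1.0), (1.0, 2.0)])
print(result)  # [0.0, 1.0, 2.0]
```

In this toy version the interruption time 1.0 is logged once, by the first segment, and the start time 0.0 is logged only because of the explicit special case.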

@SamWitty SamWitty added bug Something isn't working status:WIP Work-in-progress not yet ready for review module:dynamical labels Nov 28, 2023
@SamWitty SamWitty self-assigned this Nov 28, 2023
@SamWitty SamWitty added status:awaiting review Awaiting response from reviewer and removed status:WIP Work-in-progress not yet ready for review labels Nov 30, 2023
@SamWitty SamWitty requested a review from eb8680 November 30, 2023 19:15
@SamWitty SamWitty linked an issue Nov 30, 2023 that may be closed by this pull request
@@ -27,6 +31,19 @@ def __enter__(self) -> "LogTrajectory[T]":
self.trajectory: State[T] = State()
Collaborator Author

@SamWitty SamWitty Dec 1, 2023

@eb8680 , do you think it makes sense to move this line into _pyro_simulate instead of __enter__? This would change the behavior from "concatenate the trajectories of multiple simulate calls" to "store only the trajectory of the final simulate call". Alternatively, we could give simulate an optional name argument if you want to store multiple trajectories.

Contributor

Yes, re-initializing self.trajectory in _pyro_simulate seems more correct than the current behavior.
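
The behavioral difference under discussion can be sketched with two toy classes. Both classes and their `simulate` method are hypothetical stand-ins that store plain lists; they are not the chirho handlers and do not use Pyro's effect-handler machinery.

```python
from typing import List


class ConcatenatingLog:
    """Toy logger that resets only in __enter__ (the behavior before the change)."""

    def __enter__(self) -> "ConcatenatingLog":
        self.trajectory: List[float] = []
        return self

    def __exit__(self, *exc) -> None:
        return None

    def simulate(self, values: List[float]) -> None:
        self.trajectory = self.trajectory + values


class LastCallLog(ConcatenatingLog):
    """Toy logger that resets at the start of every simulate call."""

    def simulate(self, values: List[float]) -> None:
        self.trajectory = []  # discard results of earlier calls
        super().simulate(values)


with ConcatenatingLog() as old, LastCallLog() as new:
    for log in (old, new):
        log.simulate([1.0, 2.0])
        log.simulate([3.0])

print(old.trajectory)  # [1.0, 2.0, 3.0]
print(new.trajectory)  # [3.0]
```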

Contributor

@eb8680 eb8680 left a comment

Can you add a test that checks the case where a StaticInterruption's time coincides with a logging time, and also rerun the dynamical_intro notebook from scratch to verify that it no longer fails because of this bug?

@eb8680 eb8680 added status:awaiting response Awaiting response from creator and removed status:awaiting review Awaiting response from reviewer labels Dec 1, 2023
Collaborator Author

SamWitty commented Dec 1, 2023

Can you add a test that checks the case where a StaticInterruption's time coincides with a logging time, and also rerun the dynamical_intro notebook from scratch to verify that it no longer fails because of this bug?

I believe this existing test covers the test case? https://github.com/BasisResearch/chirho/blob/sw-time-collision/tests/dynamical/test_log_trajectory.py#L42 Happy to add more if that doesn't cover what you're looking for.

I'll rerun the notebook in this PR.

Would you like me to make the change we discussed in this comment in this PR or separately?

@eb8680 , do you think it makes sense to move this line into _pyro_simulate instead of __enter__? This would change the behavior from "concatenate the trajectories of multiple simulate calls" to "store only the trajectory of the final simulate call". Alternatively, we could give simulate an optional name argument if you want to store multiple trajectories.

@SamWitty SamWitty requested a review from eb8680 December 4, 2023 17:24
Collaborator Author

SamWitty commented Dec 4, 2023

@eb8680 , I edited this PR to address my suggestion here: #397 (comment). It is now ready for review again. I have also strengthened the tests a bit.

To ensure that handlers commute after this PR, the following strict order is now enforced regardless of the order in which the handlers are applied:
1. LogTrajectory: handles simulate during preprocessing to address potential time collisions.
2. Solver: handles simulate during postprocessing to do basically everything.
3. BatchObservation: handles simulate to condition on data using a continuation, which occurs after all postprocessing.

I reran the dynamical systems notebook and confirmed that it works with these changes, although it required a minor change to the name of a trace address in the final cell because of a previous PR. I'm excluding the revised run from this PR to avoid merge conflicts with #377, which I expect to be completed by EOD or early tomorrow. Now that #308 is merged, I'll add a follow-up PR after #377 to add the dynamical systems notebook to CI.

@SamWitty SamWitty added status:awaiting review Awaiting response from reviewer and removed status:awaiting response Awaiting response from creator labels Dec 4, 2023
Contributor

@eb8680 eb8680 left a comment

Making sure LogTrajectory, Solver and BatchObservation commute is the right idea, but it seems like there's a simpler way to achieve that:

@@ -130,5 +130,17 @@ def __init__(
        self.observation = observation
        super().__init__(times)

    def _pyro_post_simulate(self, msg: dict) -> None:
        self.trajectory = observe(self.trajectory, self.observation)
    def _pyro_simulate(self, msg: dict) -> None:
Contributor

I don't think you should need to change StaticBatchObservation at all, other than adding a single super()._pyro_post_simulate call:

def _pyro_post_simulate(self, msg: dict) -> None:
    super()._pyro_post_simulate(msg)  # update self.trajectory
    self.trajectory = observe(self.trajectory, self.observation)

# LogTrajectory's simulate_point will log only timepoints that are greater than the start_time of each
# simulate_point call, which can occur multiple times in a single simulate call when there
# are interruptions.
self.trajectory = append(
Contributor

I think this change is right, and is already enough to make LogTrajectory and Solver commutative. In particular, since append is associative and Solver doesn't change initial_state or start_time in its _pyro_simulate method, as long as no outside code needs to access self.trajectory before the simulate call finishes, it shouldn't matter if LogTrajectory's _pyro_simulate runs before or after Solver's.
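
The associativity argument can be illustrated with a toy `append` over dictionaries of lists. This is only a sketch under the assumption that a trajectory behaves like a per-variable sequence; the real chirho `append` operates on State objects, and the function below is a hypothetical stand-in for it.

```python
from typing import Dict, List

Trajectory = Dict[str, List[float]]


def append(a: Trajectory, b: Trajectory) -> Trajectory:
    """Concatenate two trajectories variable by variable (toy stand-in)."""
    keys = list(a) + [k for k in b if k not in a]
    return {k: a.get(k, []) + b.get(k, []) for k in keys}


initial: Trajectory = {"x": [0.0]}
first: Trajectory = {"x": [1.0]}
second: Trajectory = {"x": [2.0]}

# Grouping the appends differently gives the same trajectory, so it does
# not matter whether the initial state is attached before or after the
# later segments have been combined.
left = append(append(initial, first), second)
right = append(initial, append(first, second))
print(left == right)  # True
```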


    @typing.final
    @staticmethod
    def _pyro_post_simulate(msg: dict) -> None:
Contributor

This move from _pyro_simulate to _pyro_post_simulate seems unnecessary, per my comment on LogTrajectory.

@eb8680 eb8680 added status:awaiting response Awaiting response from creator and removed status:awaiting review Awaiting response from reviewer labels Dec 6, 2023
Collaborator Author

SamWitty commented Dec 6, 2023

Making sure LogTrajectory, Solver and BatchObservation commute is the right idea, but it seems like there's a simpler way to achieve that:

I agree. The key challenge here is actually in the assignment of self.trajectory: State[T] = State() in _pyro_simulate of LogTrajectory. Even though the append is associative, resetting self.trajectory must happen before the Solver's _pyro_simulate is called, or else it'll be entirely cleared in that line.

I'll try and find a workaround to this problem that's a bit less verbose and touches fewer pieces.

Contributor

eb8680 commented Dec 6, 2023

The key challenge here is actually in the assignment of self.trajectory: State[T] = State() in _pyro_simulate of LogTrajectory. Even though the append is associative, resetting self.trajectory must happen before the Solver's _pyro_simulate is called, or else it'll be entirely cleared in that line.

Hmm, what about adding an extra boolean state variable to LogTrajectory that gets set in _pyro_post_simulate to track whether simulate has run? It's kind of clunky but at least it's isolated to LogTrajectory.

def __init__(self, ...):
    self._needs_reset: bool = False
    self.trajectory = State()
    ...

def __enter__(self):
    self._needs_reset = False
    self.trajectory = State()
    ...

def _pyro_simulate(self, msg: dict) -> None:
    if self._needs_reset:
        self.trajectory = State()
        self._needs_reset = False
    ...

def _pyro_post_simulate(self, msg: dict) -> None:
    if not self._needs_reset:
        self._needs_reset = True
    ...
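
A self-contained way to try out the flag pattern sketched above might look like the following. `ToyLogTrajectory` and its `log` method are hypothetical; the class stores a plain list, and the `_pyro_*` methods are called by hand here rather than dispatched by Pyro.

```python
from typing import List


class ToyLogTrajectory:
    """Toy stand-in for LogTrajectory; not the chirho implementation."""

    def __init__(self) -> None:
        self._needs_reset: bool = False
        self.trajectory: List[float] = []

    def __enter__(self) -> "ToyLogTrajectory":
        self._needs_reset = False
        self.trajectory = []
        return self

    def __exit__(self, *exc) -> None:
        return None

    def _pyro_simulate(self, msg: dict) -> None:
        # Clear results left over from a previous, completed simulate call.
        if self._needs_reset:
            self.trajectory = []
            self._needs_reset = False

    def log(self, values: List[float]) -> None:
        # Hypothetical hook standing in for logging inside simulate_point.
        self.trajectory = self.trajectory + values

    def _pyro_post_simulate(self, msg: dict) -> None:
        # Mark the trajectory as stale for the next simulate call.
        self._needs_reset = True


with ToyLogTrajectory() as lt:
    # First simulate call: nothing to reset yet.
    lt._pyro_simulate({})
    lt.log([0.0, 1.0])
    lt._pyro_post_simulate({})
    first_call = list(lt.trajectory)

    # Second simulate call: the stale trajectory is cleared first.
    lt._pyro_simulate({})
    lt.log([2.0])
    lt._pyro_post_simulate({})

print(first_call)     # [0.0, 1.0]
print(lt.trajectory)  # [2.0]
```

In this toy version the trajectory of a finished call stays readable until the next simulate call begins, which is when the flag triggers the reset.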

Collaborator Author

SamWitty commented Dec 6, 2023

@eb8680 , I've made this PR a bit simpler, but in the process it exposed an error that was occurring with a test using MultiWorldCounterfactual. I've decided to skip that test here rather than address it in this PR, as I believe doing so will require a solution to #379 . That's honestly a bit speculative though.

The (now skipped) test fails because dt.trajectory is never instantiated, whereas it was previously instantiated on __enter__ but never modified. As the test enumerated through the empty set of keys, it previously passed even though it shouldn't have.
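
The vacuous pass described above is easy to reproduce in isolation. The dictionary and both check functions below are hypothetical and are not taken from the chirho test suite; they only show why a loop over an empty set of keys asserts nothing.

```python
import math
from typing import Dict, List

# A trajectory that was instantiated but never populated.
trajectory: Dict[str, List[float]] = {}


def all_values_finite(traj: Dict[str, List[float]]) -> bool:
    """Check every logged value; trivially True when there are no keys."""
    return all(math.isfinite(v) for values in traj.values() for v in values)


def nonempty_and_finite(traj: Dict[str, List[float]]) -> bool:
    """Same check, but an empty trajectory counts as a failure."""
    return len(traj) > 0 and all_values_finite(traj)


vacuous = all_values_finite(trajectory)
guarded = nonempty_and_finite(trajectory)
print(vacuous)  # True
print(guarded)  # False
```

Asserting that the expected keys are present before iterating over them is one way to keep such a test from passing by accident.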

Collaborator Author

SamWitty commented Dec 6, 2023

The key challenge here is actually in the assignment of self.trajectory: State[T] = State() in _pyro_simulate of LogTrajectory. Even though the append is associative, resetting self.trajectory must happen before the Solver's _pyro_simulate is called, or else it'll be entirely cleared in that line.

Hmm, what about adding an extra boolean state variable to LogTrajectory that gets set in _pyro_post_simulate to track whether simulate has run? It's kind of clunky but at least it's isolated to LogTrajectory.

def __init__(self, ...):
    self._needs_reset: bool = False
    self.trajectory = State()
    ...

def __enter__(self):
    self._needs_reset = False
    self.trajectory = State()
    ...

def _pyro_simulate(self, msg: dict) -> None:
    if self._needs_reset:
        self.trajectory = State()
        self._needs_reset = False
    ...

def _pyro_post_simulate(self, msg: dict) -> None:
    if not self._needs_reset:
        self._needs_reset = True
    ...

My solution (that I finished before seeing this) is in the spirit of this suggestion, but a bit less explicit and a bit more concise. I'm really happy either way.

@SamWitty SamWitty added status:awaiting review Awaiting response from reviewer and removed status:awaiting response Awaiting response from creator labels Dec 6, 2023
@SamWitty SamWitty requested a review from eb8680 December 6, 2023 21:41
Contributor

@eb8680 eb8680 left a comment

LGTM, thanks for fixing this!

@eb8680 eb8680 merged commit ac87280 into master Dec 7, 2023
7 checks passed
@eb8680 eb8680 deleted the sw-time-collision branch December 7, 2023 02:17
Labels
bug Something isn't working module:dynamical status:awaiting review Awaiting response from reviewer
Development

Successfully merging this pull request may close these issues.

LogTrajectory does not handle edge cases correctly
2 participants