Skip to content

Commit

Permalink
Consolidate Backend and SolverHandler into a single Solver effe…
Browse files Browse the repository at this point in the history
…ct handler / dispatch mechanism (#292)

* replaced ODEBackend with ODESolver

* lingering renaming of Backend to Solver and removal of unnecessary Solver handler

* typo in notebook

* remove dupliace Solver

* dummy commit to trigger linting?

* fixed lint error

---------

Co-authored-by: Raj Agrawal <[email protected]>
  • Loading branch information
SamWitty and agrawalraj authored Sep 27, 2023
1 parent e3f4ce0 commit c015a12
Show file tree
Hide file tree
Showing 26 changed files with 48 additions and 53 deletions.
2 changes: 2 additions & 0 deletions chirho/dynamical/handlers/ODE/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import solvers # noqa F401
from .ode import ODESolver # noqa F401
5 changes: 5 additions & 0 deletions chirho/dynamical/handlers/ODE/ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from chirho.dynamical.handlers.solver import Solver


class ODESolver(Solver):
pass
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from chirho.dynamical.ops.ODE import ODEBackend
from chirho.dynamical.handlers.ODE.ode import ODESolver


class TorchDiffEq(ODEBackend):
class TorchDiffEq(ODESolver):
def __init__(self, rtol=1e-7, atol=1e-9, method=None, options=None):
self.rtol = rtol
self.atol = atol
Expand All @@ -13,3 +13,4 @@ def __init__(self, rtol=1e-7, atol=1e-9, method=None, options=None):
"method": method,
"options": options,
}
super().__init__()
3 changes: 2 additions & 1 deletion chirho/dynamical/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import ODE # noqa: F401
from .dynamical import SimulatorEventLoop # noqa: F401
from .interruption import ( # noqa: F401
DynamicInterruption,
Expand All @@ -8,4 +9,4 @@
PointIntervention,
PointObservation,
)
from .solver import SolverHandler # noqa: F401
from .solver import Solver # noqa: F401
7 changes: 2 additions & 5 deletions chirho/dynamical/handlers/dynamical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
concatenate,
simulate_to_interruption,
)
from chirho.dynamical.ops import Trajectory
from chirho.dynamical.ops.dynamical import Trajectory

S = TypeVar("S")
T = TypeVar("T")


class SimulatorEventLoop(Generic[T], pyro.poutine.messenger.Messenger):
def __enter__(self):
return super().__enter__()

# noinspection PyMethodMayBeStatic
def _pyro_simulate(self, msg) -> None:
dynamics, initial_state, full_timespan = msg["args"]
Expand Down Expand Up @@ -100,7 +97,7 @@ def _pyro_simulate(self, msg) -> None:

last = default_terminal_interruption in terminal_interruptions

# Update the full trajectory.
# Update the full trajectory
if first:
full_trajs.append(span_traj)
else:
Expand Down
2 changes: 1 addition & 1 deletion chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from chirho.dynamical.internals.interventional import intervene
from chirho.dynamical.ops import State
from chirho.dynamical.ops.dynamical import State
from chirho.observational.handlers import condition

S = TypeVar("S")
Expand Down
12 changes: 3 additions & 9 deletions chirho/dynamical/handlers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@

import pyro

from chirho.dynamical.ops import Solver


class SolverHandler(pyro.poutine.messenger.Messenger):
def __init__(self, solver: Solver):
self.solver = solver
super().__init__()

class Solver(pyro.poutine.messenger.Messenger):
def _pyro_simulate(self, msg) -> None:
# Overwrite the solver in the message with the one we're handling.
msg["kwargs"]["solver"] = self.solver
# Overwrite the solver in the message with the enclosing solver when used as a context manager.
msg["kwargs"]["solver"] = self
4 changes: 2 additions & 2 deletions chirho/dynamical/internals/ODE/backends/torchdiffeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import torch
import torchdiffeq

from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.internals.ODE.ode_simulate import (
ode_simulate,
ode_simulate_to_interruption,
)
from chirho.dynamical.ops import State, Trajectory
from chirho.dynamical.ops.dynamical import State, Trajectory
from chirho.dynamical.ops.ODE import ODEDynamics
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq

if TYPE_CHECKING:
from chirho.dynamical.internals.interruption import (
Expand Down
13 changes: 7 additions & 6 deletions chirho/dynamical/internals/ODE/ode_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import functools
from typing import TypeVar

from chirho.dynamical.handlers.ODE import ODESolver
from chirho.dynamical.internals.interruption import simulate_to_interruption
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE import ODEBackend, ODEDynamics
from chirho.dynamical.ops.dynamical import State, simulate
from chirho.dynamical.ops.ODE import ODEDynamics

S = TypeVar("S")
T = TypeVar("T")
Expand All @@ -17,15 +18,15 @@ def ode_simulate(
initial_state: State[T],
timespan,
*,
solver: ODEBackend,
solver: ODESolver,
**kwargs,
):
return _ode_simulate(solver, dynamics, initial_state, timespan, **kwargs)


@functools.singledispatch
def _ode_simulate(
solver: ODEBackend,
solver: ODESolver,
dynamics: ODEDynamics,
initial_state: State[T],
timespan,
Expand All @@ -48,7 +49,7 @@ def ode_simulate_to_interruption(
initial_state: State[T],
timespan,
*,
solver: ODEBackend,
solver: ODESolver,
**kwargs,
):
return _ode_simulate_to_interruption(
Expand All @@ -58,7 +59,7 @@ def ode_simulate_to_interruption(

@functools.singledispatch
def _ode_simulate_to_interruption(
solver: ODEBackend,
solver: ODESolver,
dynamics: ODEDynamics,
initial_state: State[T],
timespan,
Expand Down
2 changes: 1 addition & 1 deletion chirho/dynamical/internals/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TypeVar

from chirho.dynamical.ops import State, Trajectory
from chirho.dynamical.ops.dynamical import State, Trajectory
from chirho.indexed.ops import IndexSet, gather, indices_of, union

S = TypeVar("S")
Expand Down
3 changes: 2 additions & 1 deletion chirho/dynamical/internals/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import pyro
import torch

from chirho.dynamical.ops import Dynamics, Solver, State, Trajectory
from chirho.dynamical.handlers.solver import Solver
from chirho.dynamical.ops.dynamical import Dynamics, State, Trajectory

S = TypeVar("S")
T = TypeVar("T")
Expand Down
2 changes: 1 addition & 1 deletion chirho/dynamical/internals/interventional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TypeVar

from chirho.dynamical.ops import State
from chirho.dynamical.ops.dynamical import State
from chirho.interventional.handlers import intervene

T = TypeVar("T")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@

from typing import TypeVar

from chirho.dynamical.ops import Dynamics, Solver, State, simulate
from chirho.dynamical.ops.dynamical import Dynamics, State, simulate

S = TypeVar("S")
T = TypeVar("T")


class ODEBackend(Solver):
pass


# noinspection PyPep8Naming
class ODEDynamics(Dynamics[S, T]):
def diff(self, dX: State[S], X: State[S]) -> T:
Expand Down
1 change: 0 additions & 1 deletion chirho/dynamical/ops/ODE/__init__.py

This file was deleted.

3 changes: 2 additions & 1 deletion chirho/dynamical/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .dynamical import Dynamics, Solver, State, Trajectory, simulate # noqa: F401
from . import ODE # noqa: F401
from .dynamical import Dynamics, State, Trajectory, simulate # noqa: F401
7 changes: 2 additions & 5 deletions chirho/dynamical/ops/dynamical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pyro
import torch

from chirho.dynamical.handlers.solver import Solver
from chirho.dynamical.internals.dynamical import _index_last_dim_with_mask

S = TypeVar("S")
Expand Down Expand Up @@ -123,10 +124,6 @@ class Dynamics(Protocol[S, T]):
diff: Callable[[State[S], State[S]], T]


class Solver:
pass


@pyro.poutine.runtime.effectful(type="simulate")
def simulate(
dynamics: Dynamics[S, T],
Expand All @@ -144,7 +141,7 @@ def simulate(
"SimulatorEventLoop requires a solver. To specify a solver, use the keyword argument `solver` in"
" the call to `simulate` or use with a solver effect handler as a context manager. For example, \n \n"
"`with SimulatorEventLoop():` \n"
"\t `with SimulatorBackend(TorchDiffEq()):` \n"
"\t `with TorchDiffEq():` \n"
"\t \t `simulate(dynamics, initial_state, timespan)`"
)
return _simulate(dynamics, initial_state, timespan, solver=solver, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/dynamical_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
" NonInterruptingPointObservationArray\n",
")\n",
"from chirho.dynamical.handlers.interruption import _InterventionMixin\n",
"from chirho.dynamical.ops import State, Trajectory, simulate\n",
"from chirho.dynamical.ops.dynamical import State, Trajectory, simulate\n",
"\n",
"from chirho.dynamical.ops.ODE import ODEDynamics\n",
"from chirho.dynamical.ops.ODE.solvers import TorchDiffEq\n",
"from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq\n",
"\n",
"from chirho.observational.handlers.soft_conditioning import (\n",
" AutoSoftConditioning\n",
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/dynamical_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from pyro.distributions import Normal, Uniform, constraints

from chirho.dynamical.ops import State, Trajectory
from chirho.dynamical.ops.dynamical import State, Trajectory
from chirho.dynamical.ops.ODE import ODEDynamics


Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/obs_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
NonInterruptingPointObservationArray,
SimulatorEventLoop,
)
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE import ODEDynamics
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq


class SimpleSIRDynamicsBayes(ODEDynamics):
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/test_dynamic_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
TwinWorldCounterfactual,
)
from chirho.dynamical.handlers import DynamicIntervention, SimulatorEventLoop
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE import ODEDynamics
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq
from chirho.indexed.ops import IndexSet, gather, indices_of, union

from .dynamical_fixtures import UnifiedFixtureDynamics
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/test_handler_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
PointIntervention,
SimulatorEventLoop,
)
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq
from chirho.observational.handlers.soft_conditioning import AutoSoftConditioning
from tests.dynamical.dynamical_fixtures import (
UnifiedFixtureDynamics,
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/test_noop_interruptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
PointIntervention,
SimulatorEventLoop,
)
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq

from .dynamical_fixtures import UnifiedFixtureDynamics, check_trajectories_match

Expand Down
6 changes: 3 additions & 3 deletions tests/dynamical/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pytest
import torch

from chirho.dynamical.handlers import SimulatorEventLoop, SolverHandler
from chirho.dynamical.handlers import SimulatorEventLoop
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq

from .dynamical_fixtures import bayes_sir_model, check_trajectories_match

Expand Down Expand Up @@ -42,7 +42,7 @@ def test_backend_arg():
def test_backend_handler():
sir = bayes_sir_model()
with SimulatorEventLoop():
with SolverHandler(TorchDiffEq()):
with TorchDiffEq():
result_handler = simulate(sir, init_state, tspan)

result_arg = simulate(sir, init_state, tspan, solver=TorchDiffEq())
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/test_static_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
TwinWorldCounterfactual,
)
from chirho.dynamical.handlers import PointIntervention, SimulatorEventLoop
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.ops import intervene

Expand Down
2 changes: 1 addition & 1 deletion tests/dynamical/test_static_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
PointObservation,
SimulatorEventLoop,
)
from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.dynamical.ops.ODE.solvers import TorchDiffEq

from .dynamical_fixtures import (
UnifiedFixtureDynamics,
Expand Down

0 comments on commit c015a12

Please sign in to comment.