Skip to content

Commit

Permalink
Move dynamical backend interface into one file (#312)
Browse files Browse the repository at this point in the history
* consolidate backend interface in one file

* consolidate patterns

* rename indexed to _patterns

* rename to backend

* rename pattern

* rename pattern and lint

* import

* reorder file

* remove deleted file

* fix merge

* fix error

* lint

* dead code

* move append and rename _utils

* get_solver

* fix
  • Loading branch information
eb8680 authored Oct 11, 2023
1 parent 72e2e3c commit d663a59
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 173 deletions.
2 changes: 1 addition & 1 deletion chirho/dynamical/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..internals.backend import Solver # noqa: F401
from .dynamical import SimulatorEventLoop # noqa: F401
from .interruption.array_observation import StaticBatchObservation # noqa: F401
from .interruption.interruption import ( # noqa: F401
Expand All @@ -8,5 +9,4 @@
StaticIntervention,
StaticObservation,
)
from .solver import Solver # noqa: F401
from .trace import DynamicTrace # noqa: F401
9 changes: 5 additions & 4 deletions chirho/dynamical/handlers/dynamical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pyro

from chirho.dynamical.handlers.interruption.interruption import Interruption
from chirho.dynamical.internals.interruption import (
from chirho.dynamical.internals.backend import (
apply_interruptions,
get_solver,
simulate_to_interruption,
)

Expand All @@ -17,10 +18,10 @@
class SimulatorEventLoop(Generic[T], pyro.poutine.messenger.Messenger):
def _pyro_simulate(self, msg) -> None:
dynamics, state, start_time, end_time = msg["args"]
if "solver" in msg["kwargs"]:
if msg["kwargs"].get("solver", None) is not None:
solver = msg["kwargs"]["solver"]
else: # Early return to trigger `simulate` ValueError for not having a solver.
return
else:
solver = get_solver()

# Simulate through the timespan, stopping at each interruption. This gives e.g. intervention handlers
# a chance to modify the state and/or dynamics before the next span is simulated.
Expand Down
1 change: 0 additions & 1 deletion chirho/dynamical/handlers/interruption/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pyro
import torch

import chirho.dynamical.internals.interventional # noqa: F401
from chirho.dynamical.ops.dynamical import ObservableInPlaceDynamics, State
from chirho.interventional.ops import Intervention, intervene
from chirho.observational.handlers import condition
Expand Down
8 changes: 1 addition & 7 deletions chirho/dynamical/handlers/solver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import pyro


class Solver(pyro.poutine.messenger.Messenger):
def _pyro_simulate(self, msg) -> None:
# Overwrite the solver in the message with the enclosing solver when used as a context manager.
msg["kwargs"]["solver"] = self
from chirho.dynamical.internals.backend import Solver


class TorchDiffEq(Solver):
Expand Down
38 changes: 2 additions & 36 deletions chirho/dynamical/handlers/trace.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,15 @@
import functools
from typing import Generic, TypeVar

import pyro
import torch

from chirho.dynamical.internals.dynamical import simulate_trajectory
from chirho.dynamical.internals._utils import append
from chirho.dynamical.internals.backend import simulate_trajectory
from chirho.dynamical.ops import Trajectory

T = TypeVar("T")


@functools.singledispatch
def append(fst, rest: T) -> T:
raise NotImplementedError(f"append not implemented for type {type(fst)}.")


@append.register(Trajectory)
def append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]:
if len(traj1.keys) == 0:
return traj2

if len(traj2.keys) == 0:
return traj1

if traj1.keys != traj2.keys:
raise ValueError(
f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}."
)

result: Trajectory[T] = Trajectory()
for k in traj1.keys:
setattr(result, k, append(getattr(traj1, k), getattr(traj2, k)))

return result


@append.register(torch.Tensor)
def append_tensor(prev_v: torch.Tensor, curr_v: torch.Tensor) -> torch.Tensor:
time_dim = -1 # TODO generalize to nontrivial event_shape
batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1])
prev_v = prev_v.expand(*batch_shape, *prev_v.shape[-1:])
curr_v = curr_v.expand(*batch_shape, *curr_v.shape[-1:])
return torch.cat([prev_v, curr_v], dim=time_dim)


class DynamicTrace(Generic[T], pyro.poutine.messenger.Messenger):
trace: Trajectory[T]

Expand Down
3 changes: 1 addition & 2 deletions chirho/dynamical/internals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Include only imports that are needed for registering dispatches.

from . import dynamical # noqa: F401
from . import indexed # noqa: F401
from . import _utils # noqa: F401
from . import solver # noqa: F401
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
from typing import TypeVar

import torch

from chirho.dynamical.ops.dynamical import State, Trajectory
from chirho.indexed.ops import IndexSet, gather, indices_of, union
from chirho.interventional.handlers import intervene

S = TypeVar("S")
T = TypeVar("T")
Expand Down Expand Up @@ -78,3 +80,47 @@ def _index_last_dim_with_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tens
mask.reshape((1,) * (x.ndim - 1) + mask.shape)
# masked_select flattens tensors, so we need to reshape back to the original shape w/ the mask applied.
).reshape(x.shape[:-1] + (int(mask.sum()),))


@intervene.register(State)
def _state_intervene(obs: State[T], act: State[T], **kwargs) -> State[T]:
new_state: State[T] = State()
for k in obs.keys:
setattr(
new_state, k, intervene(getattr(obs, k), getattr(act, k, None), **kwargs)
)
return new_state


@functools.singledispatch
def append(fst, rest: T) -> T:
raise NotImplementedError(f"append not implemented for type {type(fst)}.")


@append.register(Trajectory)
def _append_trajectory(traj1: Trajectory[T], traj2: Trajectory[T]) -> Trajectory[T]:
if len(traj1.keys) == 0:
return traj2

if len(traj2.keys) == 0:
return traj1

if traj1.keys != traj2.keys:
raise ValueError(
f"Trajectories must have the same keys to be appended, but got {traj1.keys} and {traj2.keys}."
)

result: Trajectory[T] = Trajectory()
for k in traj1.keys:
setattr(result, k, append(getattr(traj1, k), getattr(traj2, k)))

return result


@append.register(torch.Tensor)
def _append_tensor(prev_v: torch.Tensor, curr_v: torch.Tensor) -> torch.Tensor:
time_dim = -1 # TODO generalize to nontrivial event_shape
batch_shape = torch.broadcast_shapes(prev_v.shape[:-1], curr_v.shape[:-1])
prev_v = prev_v.expand(*batch_shape, *prev_v.shape[-1:])
curr_v = curr_v.expand(*batch_shape, *curr_v.shape[-1:])
return torch.cat([prev_v, curr_v], dim=time_dim)
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,80 @@

import functools
import numbers
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union
import typing
from typing import List, Optional, Tuple, TypeVar, Union

import pyro
import torch

from chirho.dynamical.handlers.interruption.interruption import (
DynamicInterruption,
Interruption,
StaticInterruption,
)
from chirho.dynamical.ops.dynamical import InPlaceDynamics, State, simulate
from chirho.dynamical.ops.dynamical import InPlaceDynamics, State, Trajectory, simulate

if TYPE_CHECKING:
from chirho.dynamical.handlers.solver import Solver
if typing.TYPE_CHECKING:
from chirho.dynamical.handlers.interruption.interruption import (
DynamicInterruption,
Interruption,
StaticInterruption,
)


R = Union[numbers.Real, torch.Tensor]
S = TypeVar("S")
T = TypeVar("T")


class Solver(pyro.poutine.messenger.Messenger):
def _pyro_get_solver(self, msg) -> None:
# Overwrite the solver in the message with the enclosing solver when used as a context manager.
msg["value"] = self
msg["done"] = True
msg["stop"] = True


@pyro.poutine.runtime.effectful(type="get_solver")
def get_solver() -> Solver:
"""
Get the current solver from the context.
"""
raise ValueError("Solver not found in context.")


@functools.singledispatch
def simulate_point(
solver: Solver,
dynamics: InPlaceDynamics[T],
initial_state: State[T],
start_time: R,
end_time: R,
**kwargs,
) -> State[T]:
"""
Simulate a dynamical system.
"""
raise NotImplementedError(
f"simulate not implemented for solver of type {type(solver)}"
)


@functools.singledispatch
def simulate_trajectory(
solver: Solver,
dynamics: InPlaceDynamics[T],
initial_state: State[T],
timespan: R,
**kwargs,
) -> Trajectory[T]:
"""
Simulate a dynamical system.
"""
raise NotImplementedError(
f"simulate_trajectory not implemented for solver of type {type(solver)}"
)


# Separating out the effectful operation from the non-effectful dispatch on the default implementation
@pyro.poutine.runtime.effectful(type="simulate_to_interruption")
def simulate_to_interruption(
solver: "Solver", # Quoted type necessary w/ TYPE_CHECKING to avoid circular import error
solver: Solver,
dynamics: InPlaceDynamics[T],
start_state: State[T],
start_time: R,
Expand Down Expand Up @@ -63,8 +112,19 @@ def simulate_to_interruption(
return event_state, interruptions, interruption_time


@pyro.poutine.runtime.effectful(type="apply_interruptions")
def apply_interruptions(
dynamics: InPlaceDynamics[T], start_state: State[T]
) -> Tuple[InPlaceDynamics[T], State[T]]:
"""
Apply the effects of an interruption to a dynamical system.
"""
# Default is to do nothing.
return dynamics, start_state


def get_next_interruptions(
solver: "Solver", # Quoted type necessary w/ TYPE_CHECKING to avoid circular import error
solver: Solver,
dynamics: InPlaceDynamics[T],
start_state: State[T],
start_time: R,
Expand All @@ -74,40 +134,33 @@ def get_next_interruptions(
dynamic_interruptions: List[DynamicInterruption] = [],
**kwargs,
) -> Tuple[Tuple[Interruption, ...], R]:
nodyn = len(dynamic_interruptions) == 0
nostat = next_static_interruption is None
from chirho.dynamical.handlers.interruption.interruption import StaticInterruption

if nostat or next_static_interruption.time > end_time: # type: ignore
if isinstance(next_static_interruption, type(None)):
# If there's no static interruption or the next static interruption is after the end time,
# we'll just simulate until the end time.
next_static_interruption = StaticInterruption(time=end_time)

assert isinstance(
next_static_interruption, StaticInterruption
) # Linter needs a hint

if nodyn:
assert isinstance(next_static_interruption, StaticInterruption)
if len(dynamic_interruptions) == 0:
# If there's no dynamic intervention, we'll simulate until either the end_time,
# or the `next_static_interruption` whichever comes first.
return (next_static_interruption,), next_static_interruption.time # type: ignore
return (next_static_interruption,), next_static_interruption.time
else:
return get_next_interruptions_dynamic( # type: ignore
solver, # type: ignore
dynamics, # type: ignore
start_state, # type: ignore
start_time, # type: ignore
return get_next_interruptions_dynamic(
solver,
dynamics,
start_state,
start_time,
next_static_interruption=next_static_interruption,
dynamic_interruptions=dynamic_interruptions,
**kwargs,
)

raise ValueError("Unreachable code reached.")


# noinspection PyUnusedLocal
@functools.singledispatch
def get_next_interruptions_dynamic(
solver: "Solver", # Quoted type necessary w/ TYPE_CHECKING to avoid circular import error
solver: Solver,
dynamics: InPlaceDynamics[T],
start_state: State[T],
start_time: R,
Expand All @@ -117,14 +170,3 @@ def get_next_interruptions_dynamic(
raise NotImplementedError(
f"get_next_interruptions_dynamic not implemented for type {type(dynamics)}"
)


@pyro.poutine.runtime.effectful(type="apply_interruptions")
def apply_interruptions(
dynamics: InPlaceDynamics[T], start_state: State[T]
) -> Tuple[InPlaceDynamics[T], State[T]]:
"""
Apply the effects of an interruption to a dynamical system.
"""
# Default is to do nothing.
return dynamics, start_state
29 changes: 0 additions & 29 deletions chirho/dynamical/internals/dynamical.py

This file was deleted.

Loading

0 comments on commit d663a59

Please sign in to comment.