Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Replace multiple applications of single dispatch with multimethod multiple dispatch #295

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions chirho/dynamical/internals/ODE/__init__.py

This file was deleted.

121 changes: 0 additions & 121 deletions chirho/dynamical/internals/ODE/ode_simulate.py

This file was deleted.

2 changes: 1 addition & 1 deletion chirho/dynamical/internals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import ODE # noqa: F401
from . import backends # noqa: F401
from . import dynamical # noqa: F401
from . import indexed # noqa: F401
from . import interruption # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
import torchdiffeq

from chirho.dynamical.handlers.ODE.solvers import TorchDiffEq
from chirho.dynamical.internals.ODE.ode_simulate import (
_ode_get_next_interruptions_dynamic,
_ode_simulate_trajectory,
ode_simulate,
)
from chirho.dynamical.ops.dynamical import State, Trajectory
from chirho.dynamical.internals.dynamical import _simulate_trajectory
from chirho.dynamical.internals.interruption import _get_next_interruptions_dynamic
from chirho.dynamical.ops.dynamical import State, Trajectory, _simulate
from chirho.dynamical.ops.ODE import ODEDynamics

if TYPE_CHECKING:
Expand Down Expand Up @@ -102,10 +99,10 @@ def _batched_odeint(
return yt if event_fn is None else (event_t, yt)


@ode_simulate.register(TorchDiffEq)
@_simulate.register(ODEDynamics, TorchDiffEq)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multimethod dispatches on all of the arguments to _simulate, so what does this registration mean?

Copy link
Collaborator Author

@SamWitty SamWitty Oct 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that this registration means that we dispatch only on the first two arguments of _simulate, and not on the remaining ones. This isn't documented in multimethod from what I could find, but when I remove the arguments I get the default NotImplementedError for _simulate or one of the other dispatches.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if that's precisely true (I would guess it implicitly provides a type for the remaining arguments, maybe object or whatever is in the signature?), but I don't think it's a good idea to rely on undocumented behavior in upstream dependencies since it can change without warning in ways that are difficult or impossible to work around.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll try to debug why the default dispatch on the full collection of arguments isn't working as expected.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eb8680, after a bit of digging around I've found two options that rely less on undocumented behavior and pass tests. I have not been able to get multimethod to work with fully specified parametric types.

Option 1: Pass unparameterized types to the *args of register and leave fully parametric types explicit in the function signature. E.g.

@_simulate.register(ODEDynamics, TorchDiffEq, State, torch.Tensor, torch.Tensor)
def torchdiffeq_ode_simulate(
    dynamics: ODEDynamics[torch.Tensor, torch.Tensor],
    solver: TorchDiffEq,
    initial_state: State[torch.Tensor],
    start_time: torch.Tensor,
    end_time: torch.Tensor,
) -> State[torch.Tensor]:
    timespan = torch.stack((start_time, end_time))
    trajectory = _torchdiffeq_ode_simulate_inner(
        dynamics, initial_state, timespan, **solver.odeint_kwargs
    )
    return trajectory[..., -1].to_state()

Option 2: Remove type parameters from function signature and use default register without any *args. E.g.

@_simulate.register
def torchdiffeq_ode_simulate(
    dynamics: ODEDynamics,
    solver: TorchDiffEq,
    initial_state: State,
    start_time: torch.Tensor,
    end_time: torch.Tensor,
) -> State:
    timespan = torch.stack((start_time, end_time))
    trajectory = _torchdiffeq_ode_simulate_inner(
        dynamics, initial_state, timespan, **solver.odeint_kwargs
    )
    return trajectory[..., -1].to_state()

Which of these two would you prefer? Alternatively, would you like something else?

Unrelated: I think there is some misuse of types a bit scattered throughout the module, so a separate type refactoring PR (ideally pair programmed) might be nice before merging into master.

def torchdiffeq_ode_simulate(
solver: TorchDiffEq,
dynamics: ODEDynamics,
solver: TorchDiffEq,
initial_state: State[torch.Tensor],
start_time: torch.Tensor,
end_time: torch.Tensor,
Expand All @@ -117,10 +114,10 @@ def torchdiffeq_ode_simulate(
return trajectory[..., -1].to_state()


@_ode_simulate_trajectory.register(TorchDiffEq)
@_simulate_trajectory.register(ODEDynamics, TorchDiffEq)
def torchdiffeq_ode_simulate_trajectory(
solver: TorchDiffEq,
dynamics: ODEDynamics,
solver: TorchDiffEq,
initial_state: State[torch.Tensor],
timespan: torch.Tensor,
) -> State[torch.Tensor]:
Expand All @@ -129,10 +126,10 @@ def torchdiffeq_ode_simulate_trajectory(
)


@_ode_get_next_interruptions_dynamic.register(TorchDiffEq)
@_get_next_interruptions_dynamic.register(ODEDynamics, TorchDiffEq)
def torchdiffeq_get_next_interruptions_dynamic(
solver: TorchDiffEq,
dynamics: ODEDynamics[torch.Tensor, torch.Tensor],
solver: TorchDiffEq,
start_state: State[torch.Tensor],
start_time: torch.Tensor,
next_static_interruption: StaticInterruption,
Expand Down
20 changes: 18 additions & 2 deletions chirho/dynamical/internals/dynamical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, TypeVar

import torch
from multimethod import multimethod

from chirho.dynamical.handlers.solver import Solver
from chirho.dynamical.ops.dynamical import Dynamics, State, Trajectory
Expand All @@ -10,7 +11,6 @@
T = TypeVar("T")


@functools.singledispatch
def simulate_trajectory(
dynamics: Dynamics[S, T],
initial_state: State[T],
Expand All @@ -30,8 +30,24 @@ def simulate_trajectory(
"\n \n `with TorchDiffEq():` \n"
"\t `simulate_trajectory(dynamics, initial_state, start_time, end_time)`"
)

return _simulate_trajectory(dynamics, solver, initial_state, timespan, **kwargs)


@multimethod
def _simulate_trajectory(
dynamics: Dynamics[S, T],
solver: Solver,
initial_state: State[T],
timespan: T,
**kwargs,
) -> Trajectory[T]:
"""
Simulate a dynamical system.
"""

raise NotImplementedError(
f"simulate_trajectory not implemented for type {type(dynamics)}"
f"simulate_trajectory not implemented for dynamics of type {type(dynamics)} and solver of type {type(solver)}"
)


Expand Down
32 changes: 30 additions & 2 deletions chirho/dynamical/internals/interruption.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
from typing import List, Optional, Tuple, TypeVar

import pyro
from multimethod import multimethod

from chirho.dynamical.handlers.interruption import (
DynamicInterruption,
Expand Down Expand Up @@ -97,7 +97,6 @@ def get_next_interruptions(


# noinspection PyUnusedLocal
@functools.singledispatch
def get_next_interruptions_dynamic(
dynamics: Dynamics[S, T],
start_state: State[T],
Expand All @@ -107,6 +106,35 @@ def get_next_interruptions_dynamic(
*,
solver: Optional[Solver] = None,
**kwargs,
) -> Tuple[Tuple[Interruption, ...], T]:
if solver is None:
raise ValueError(
"`get_next_interruptions_dynamic` requires a solver. To specify a solver, use the keyword argument "
"`solver` in the call to `simulate` or use with a solver effect handler as a context manager. "
"For example,"
"\n \n `with TorchDiffEq():` \n"
"\t `simulate(dynamics, initial_state, start_time, end_time)`"
)
return _get_next_interruptions_dynamic(
dynamics,
solver,
start_state,
start_time,
next_static_interruption,
dynamic_interruptions,
**kwargs,
)


@multimethod
def _get_next_interruptions_dynamic(
dynamics: Dynamics[S, T],
solver: Solver,
start_state: State[T],
start_time: T,
next_static_interruption: StaticInterruption,
dynamic_interruptions: List[DynamicInterruption],
**kwargs,
) -> Tuple[Tuple[Interruption, ...], T]:
raise NotImplementedError(
f"get_next_interruptions_dynamic not implemented for type {type(dynamics)}"
Expand Down
39 changes: 15 additions & 24 deletions chirho/dynamical/ops/dynamical.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import functools
from typing import (
Callable,
FrozenSet,
Generic,
Optional,
Protocol,
TypeVar,
runtime_checkable,
)
from typing import FrozenSet, Generic, Optional, TypeVar

import pyro
import torch
from multimethod import multimethod

from chirho.dynamical.handlers.solver import Solver

Expand All @@ -20,7 +13,6 @@

class StateOrTrajectory(Generic[T]):
def __init__(self, **values: T):
# self.class_name =
self.__dict__["_values"] = {}
for k, v in values.items():
setattr(self, k, v)
Expand Down Expand Up @@ -158,9 +150,12 @@ def _append_trajectory(self, other: Trajectory):
)


@runtime_checkable
class Dynamics(Protocol[S, T]):
diff: Callable[[State[S], State[S]], T]
class Dynamics(Generic[S, T]):
def diff(self, dX: State[S], X: State[S]) -> T:
raise NotImplementedError

def observation(self, X: State[S]):
raise NotImplementedError


@pyro.poutine.runtime.effectful(type="simulate")
Expand All @@ -184,30 +179,26 @@ def simulate(
"\t `with TorchDiffEq():` \n"
"\t \t `simulate(dynamics, initial_state, start_time, end_time)`"
)
return _simulate(
dynamics, initial_state, start_time, end_time, solver=solver, **kwargs
)
return _simulate(dynamics, solver, initial_state, start_time, end_time, **kwargs)


# This redirection distinguishes between the effectful operation, and the
# type-directed dispatch on Dynamics
@functools.singledispatch
# type-directed dispatch on Dynamics and Solver
@multimethod
def _simulate(
dynamics: Dynamics[S, T],
solver: Solver,
initial_state: State[T],
start_time: T,
end_time: T,
*,
solver: Optional[Solver] = None,
**kwargs,
) -> State[T]:
"""
Simulate a dynamical system.
"""
raise NotImplementedError(f"simulate not implemented for type {type(dynamics)}")


simulate.register = _simulate.register
raise NotImplementedError(
f"simulate not implemented for dynamics of type {type(dynamics)} and solver of type {type(solver)}"
)


def _index_last_dim_with_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
Expand Down
Loading