Skip to content

Commit

Permalink
private reversible
Browse files Browse the repository at this point in the history
  • Loading branch information
sammccallum committed Nov 26, 2024
1 parent 8a7448e commit 24d1935
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
117 changes: 113 additions & 4 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import functools as ft
import warnings
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union
from typing import Any, cast, Optional, TypeAlias, TypeVar, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -13,14 +13,19 @@
import lineax as lx
import optimistix.internal as optxi
from equinox.internal import ω
from jaxtyping import PyTree

from ._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from ._heuristics import is_sde, is_unsafe_sde
from ._saveat import save_y, SaveAt, SubSaveAt
from ._solution import RESULTS, update_result
from ._solver import (
AbstractAdaptiveSolver,
AbstractItoSolver,
AbstractRungeKutta,
AbstractSolver,
AbstractStratonovichSolver,
Reversible,
AbstractWrappedSolver,
)
from ._term import AbstractTerm, AdjointTerm

Expand Down Expand Up @@ -1043,7 +1048,7 @@ def cond_fun(state):

class ReversibleAdjoint(AbstractAdjoint):
"""
Backpropagate through [`diffrax.diffeqsolve`][] by using the reversible solver
Backpropagate through [`diffrax.diffeqsolve`][] using the reversible solver
method.
This method wraps the passed solver to create an algebraically reversible version
Expand Down Expand Up @@ -1121,7 +1126,7 @@ def loop(
"`diffrax.ReversibleAdjoint` is not compatible with events."
)

solver = Reversible(solver, self.l)
solver = _Reversible(solver, self.l)
tprev = init_state.tprev
tnext = init_state.tnext
y = init_state.y
Expand Down Expand Up @@ -1164,3 +1169,107 @@ def loop(
In most cases the default value is sufficient. However, if you find yourself needing
greater control over stability it can be passed as an argument.
"""

_BaseSolverState = TypeVar("_BaseSolverState")
_SolverState: TypeAlias = tuple[_BaseSolverState, Y]


def _add_maybe_none(x, y):
if x is None:
return None
else:
return x + y


class _Reversible(
AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState]
):
"""
Reversible solver method.
Allows any solver ([`diffrax.AbstractSolver`][]) to be made algebraically
reversible. The convergence order of the reversible solver is inherited from the
wrapped solver.
Gradient calculation through the reversible solver is exact (up to floating
point errors) and backpropagation becomes a linear in time $O(t)$ and constant in
memory $O(1)$ algorithm.
This is implemented in [`diffrax.ReversibleAdjoint`][] and passed to
[`diffrax.diffeqsolve`][] as `adjoint=diffrax.ReversibleAdjoint()`.
"""

solver: AbstractSolver
l: float = 0.999

@property
def term_structure(self):
return self.solver.term_structure

@property
def interpolation_cls(self): # pyright: ignore
return self.solver.interpolation_cls

@property
def term_compatible_contr_kwargs(self):
return self.solver.term_compatible_contr_kwargs

@property
def root_finder(self):
return self.solver.root_finder # pyright: ignore

@property
def root_find_max_steps(self):
return self.solver.root_find_max_steps # pyright: ignore

def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]:
return self.solver.order(terms)

def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]:
return self.solver.strong_order(terms)

def init(
self,
terms: PyTree[AbstractTerm],
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
) -> _SolverState:
if isinstance(self.solver, AbstractRungeKutta):
object.__setattr__(self.solver.tableau, "fsal", False)
object.__setattr__(self.solver.tableau, "ssal", False)
original_solver_init = self.solver.init(terms, t0, t1, y0, args)
return (original_solver_init, y0)

def step(
self,
terms: PyTree[AbstractTerm],
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
original_solver_state, z0 = solver_state

step_z0, z_error, dense_info, original_solver_state, result1 = self.solver.step(
terms, t0, t1, z0, args, original_solver_state, made_jump
)
y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω

step_y1, y_error, _, _, result2 = self.solver.step(
terms, t1, t0, y1, args, original_solver_state, made_jump
)
z1 = (ω(y1) + ω(z0) - ω(step_y1)).ω

solver_state = (original_solver_state, z1)
result = update_result(result1, result2)

return y1, _add_maybe_none(z_error, y_error), dense_info, solver_state, result

def func(
self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args
) -> VF:
return self.solver.func(terms, t0, y0, args)
1 change: 0 additions & 1 deletion diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
StratonovichMilstein as StratonovichMilstein,
)
from .ralston import Ralston as Ralston
from .reversible import Reversible as Reversible
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .runge_kutta import (
AbstractDIRK as AbstractDIRK,
Expand Down

0 comments on commit 24d1935

Please sign in to comment.