diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ed8edc71..6efa999d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.7 + rev: v0.2.2 hooks: - id: ruff # linter types_or: [ python, pyi, jupyter ] @@ -8,7 +8,7 @@ repos: - id: ruff-format # formatter types_or: [ python, pyi, jupyter ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.316 + rev: v1.1.350 hooks: - id: pyright additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions] diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 688e2af8..7d1547c6 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -13,7 +13,11 @@ UnsafeBrownianPath as UnsafeBrownianPath, VirtualBrownianTree as VirtualBrownianTree, ) -from ._custom_types import LevyVal as LevyVal +from ._custom_types import ( + AbstractBrownianReturn as AbstractBrownianReturn, + BrownianIncrement as BrownianIncrement, + SpaceTimeLevyArea as SpaceTimeLevyArea, +) from ._event import ( AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent, DiscreteTerminatingEvent as DiscreteTerminatingEvent, diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index eeb2d5c9..4535a047 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1,8 +1,8 @@ import abc import functools as ft import warnings -from collections.abc import Iterable -from typing import Any, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any, cast, Optional, Union import equinox as eqx import equinox.internal as eqxi @@ -20,6 +20,9 @@ from ._term import AbstractTerm, AdjointTerm +ω = cast(Callable, ω) + + def _is_none(x): return x is None diff --git a/diffrax/_brownian/base.py b/diffrax/_brownian/base.py index 96a07253..53b1ddfc 100644 --- a/diffrax/_brownian/base.py +++ b/diffrax/_brownian/base.py @@ -1,17 +1,20 @@ import abc -from typing import Optional, Union +from typing import Optional, TypeVar, Union from equinox.internal import AbstractVar from jaxtyping import Array, PyTree -from .._custom_types import LevyArea, LevyVal, RealScalarLike +from .._custom_types import AbstractBrownianReturn, RealScalarLike from .._path import AbstractPath -class AbstractBrownianPath(AbstractPath): +_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianReturn]) + + +class AbstractBrownianPath(AbstractPath[_Control]): """Abstract base class for all Brownian paths.""" - levy_area: AbstractVar[LevyArea] + levy_area: AbstractVar[type[AbstractBrownianReturn]] @abc.abstractmethod def evaluate( @@ -20,7 +23,7 @@ def evaluate( t1: Optional[RealScalarLike] = None, left: bool = True, use_levy: bool = False, - ) -> Union[PyTree[Array], LevyVal]: + ) -> _Control: r"""Samples a Brownian increment $w(t_1) - w(t_0)$. Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$. diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 33b7df78..683e481d 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -10,7 +10,12 @@ import lineax.internal as lxi from jaxtyping import Array, PRNGKeyArray, PyTree -from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike +from .._custom_types import ( + BrownianIncrement, + levy_tree_transpose, + RealScalarLike, + SpaceTimeLevyArea, +) from .._misc import ( force_bitcast_convert_type, is_tuple_of_ints, @@ -42,14 +47,16 @@ class UnsafeBrownianPath(AbstractBrownianPath): """ shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: LevyArea = eqx.field(static=True) + levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]] = eqx.field( + static=True + ) key: PRNGKeyArray def __init__( self, shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: PRNGKeyArray, - levy_area: LevyArea = "", + levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]], ): self.shape = ( jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) @@ -57,11 +64,7 @@ def __init__( else shape ) self.key = key - if levy_area not in ["", "space-time"]: - raise ValueError( - f"levy_area must be one of '', 'space-time', but got {levy_area}." - ) - self.levy_area = levy_area + self.levy_area = levy_area # pyright: ignore[reportIncompatibleVariableOverride] if any( not jnp.issubdtype(x.dtype, jnp.inexact) @@ -70,11 +73,11 @@ def __init__( raise ValueError("UnsafeBrownianPath dtypes all have to be floating-point.") @property - def t0(self): + def t0(self): # pyright: ignore[reportIncompatibleVariableOverride] return -jnp.inf @property - def t1(self): + def t1(self): # pyright: ignore[reportIncompatibleVariableOverride] return jnp.inf @eqx.filter_jit @@ -84,7 +87,7 @@ def evaluate( t1: Optional[RealScalarLike] = None, left: bool = True, use_levy: bool = False, - ) -> Union[PyTree[Array], LevyVal]: + ) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]: del left if t1 is None: dtype = jnp.result_type(t0) @@ -112,7 +115,7 @@ def evaluate( ) if use_levy: out = levy_tree_transpose(self.shape, out) - assert isinstance(out, LevyVal) + assert isinstance(out, (BrownianIncrement, SpaceTimeLevyArea)) return out @staticmethod @@ -121,25 +124,26 @@ def _evaluate_leaf( t1: RealScalarLike, key, shape: jax.ShapeDtypeStruct, - levy_area: str, + levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]], use_levy: bool, ): w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) + w = jr.normal(key, shape.shape, shape.dtype) * w_std + dt = t1 - t0 - if levy_area == "space-time": + if levy_area is SpaceTimeLevyArea: key, key_hh = jr.split(key, 2) hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std - elif levy_area == "": - hh = None + levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None) + elif levy_area is BrownianIncrement: + levy_val = BrownianIncrement(dt=dt, W=w) else: assert False - w = jr.normal(key, shape.shape, shape.dtype) * w_std if use_levy: - return LevyVal(dt=t1 - t0, W=w, H=hh, K=None) - else: - return w + return levy_val + return w UnsafeBrownianPath.__init__.__doc__ = """ diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index b901675b..78c1f995 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -1,5 +1,5 @@ import math -from typing import Literal, Optional, Union +from typing import Literal, Optional, TypeVar, Union from typing_extensions import TypeAlias import equinox as eqx @@ -13,12 +13,13 @@ from jaxtyping import Array, Float, PRNGKeyArray, PyTree from .._custom_types import ( + AbstractBrownianReturn, BoolScalarLike, + BrownianIncrement, IntScalarLike, levy_tree_transpose, - LevyArea, - LevyVal, RealScalarLike, + SpaceTimeLevyArea, ) from .._misc import ( is_tuple_of_ints, @@ -58,6 +59,7 @@ Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"] ] _Spline: TypeAlias = Literal["sqrt", "quad", "zero"] +_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianReturn) class _State(eqx.Module): @@ -69,7 +71,7 @@ class _State(eqx.Module): bkk_s_u_su: Optional[FloatTriple] # \bar{K}_s, _u, _{s,u} -def _levy_diff(_, x0: tuple, x1: tuple) -> LevyVal: +def _levy_diff(_, x0: tuple, x1: tuple) -> Union[BrownianIncrement, SpaceTimeLevyArea]: r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and $(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$. @@ -89,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> LevyVal: dt0, w0 = x0 dt1, w1 = x1 su = jnp.asarray(dt1 - dt0, dtype=w0.dtype) - return LevyVal(dt=su, W=w1 - w0, H=None, K=None) + return BrownianIncrement(dt=su, W=w1 - w0) elif len(x0) == 4: # space-time levy area case assert len(x1) == 4 @@ -103,18 +105,18 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> LevyVal: u_bb_s = dt1 * w0 - dt0 * w1 bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s) hh_su = inverse_su * bhh_su - return LevyVal(dt=su, W=w_su, H=hh_su, K=None) + return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su, K=None) else: assert False -def _make_levy_val(_, x: tuple) -> LevyVal: +def _make_levy_val(_, x: tuple) -> Union[BrownianIncrement, SpaceTimeLevyArea]: if len(x) == 2: dt, w = x - return LevyVal(dt=dt, W=w, H=None, K=None) + return BrownianIncrement(dt=dt, W=w) elif len(x) == 4: dt, w, hh, bhh = x - return LevyVal(dt=dt, W=w, H=hh, K=None) + return SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None) else: assert False @@ -171,7 +173,9 @@ class VirtualBrownianTree(AbstractBrownianPath): t1: RealScalarLike tol: RealScalarLike shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: LevyArea = eqx.field(static=True) + levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]] = eqx.field( + static=True + ) key: PyTree[PRNGKeyArray] _spline: _Spline = eqx.field(static=True) @@ -183,21 +187,18 @@ def __init__( tol: RealScalarLike, shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: PRNGKeyArray, - levy_area: LevyArea = "", + levy_area: type[ + Union[BrownianIncrement, SpaceTimeLevyArea] + ] = BrownianIncrement, _spline: _Spline = "sqrt", ): (t0, t1) = eqx.error_if((t0, t1), t0 >= t1, "t0 must be strictly less than t1") - self.t0 = t0 - self.t1 = t1 + self.t0 = t0 # pyright: ignore[reportIncompatibleVariableOverride] + self.t1 = t1 # pyright: ignore[reportIncompatibleVariableOverride] # Since we rescale the interval to [0,1], # we need to rescale the tolerance too. self.tol = tol / (self.t1 - self.t0) - - if levy_area not in ["", "space-time"]: - raise ValueError( - f"levy_area must be one of '', 'space-time', but got {levy_area}." - ) - self.levy_area = levy_area + self.levy_area = levy_area # pyright: ignore[reportIncompatibleVariableOverride] self._spline = _spline self.shape = ( jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) @@ -213,7 +214,7 @@ def __init__( ) self.key = split_by_tree(key, self.shape) - def _denormalise_bm_inc(self, x: LevyVal) -> LevyVal: + def _denormalise_bm_inc(self, x: _BrownianReturn) -> _BrownianReturn: # Rescaling back from [0, 1] to the original interval [t0, t1]. interval_len = self.t1 - self.t0 # can be any dtype sqrt_len = jnp.sqrt(interval_len) @@ -227,12 +228,13 @@ def sqrt_mult(z): dtype = jnp.result_type(z) return jnp.astype(sqrt_len, dtype) * z - return LevyVal( - dt=jtu.tree_map(mult, x.dt), - W=jtu.tree_map(sqrt_mult, x.W), - H=jtu.tree_map(sqrt_mult, x.H), - K=jtu.tree_map(sqrt_mult, x.K), - ) + def is_dt(z): + return z is x.dt + + dt, other = eqx.partition(x, is_dt) + dt_normalized = jtu.tree_map(mult, dt) + other_normalized = jtu.tree_map(sqrt_mult, other) + return eqx.combine(dt_normalized, other_normalized) @eqx.filter_jit def evaluate( @@ -241,7 +243,7 @@ def evaluate( t1: Optional[RealScalarLike] = None, left: bool = True, use_levy: bool = False, - ) -> Union[PyTree[Array], LevyVal]: + ) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]: t0 = eqxi.nondifferentiable(t0, name="t0") # map the interval [self.t0, self.t1] onto [0,1] t0 = linear_rescale(self.t0, t0, self.t1) @@ -249,7 +251,6 @@ def evaluate( if t1 is None: levy_out = levy_0 levy_out = jtu.tree_map(_make_levy_val, self.shape, levy_out) - else: t1 = eqxi.nondifferentiable(t1, name="t1") # map the interval [self.t0, self.t1] onto [0,1] @@ -260,7 +261,7 @@ def evaluate( levy_out = levy_tree_transpose(self.shape, levy_out) # now map [0,1] back onto [self.t0, self.t1] levy_out = self._denormalise_bm_inc(levy_out) - assert isinstance(levy_out, LevyVal) + assert isinstance(levy_out, (BrownianIncrement, SpaceTimeLevyArea)) return levy_out if use_levy else levy_out.W def _evaluate(self, r: RealScalarLike) -> PyTree: @@ -286,7 +287,7 @@ def _evaluate_leaf( t0 = jnp.zeros((), dtype) r = jnp.asarray(r, dtype) - if self.levy_area == "space-time": + if self.levy_area is SpaceTimeLevyArea: state_key, init_key_w, init_key_la = jr.split(key, 3) bhh_1 = jr.normal(init_key_la, shape, dtype) / math.sqrt(12) bhh_0 = jnp.zeros_like(bhh_1) @@ -332,13 +333,12 @@ def _body_fun(_state: _State): _key = jnp.where(_cond, _key_st, _key_tu) _w = _split_interval(_cond, _w_stu, _w_inc) - if not self.levy_area == "": + _bkk = None + if self.levy_area is not BrownianIncrement: assert _bhh_stu is not None and _bhh_st_tu is not None _bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu) - _bkk = None else: _bhh = None - _bkk = None return _State( level=_level, @@ -360,8 +360,7 @@ def _body_fun(_state: _State): w_s, w_u, w_su = final_state.w_s_u_su - # BM only case - if self.levy_area == "": + if self.levy_area is BrownianIncrement: w_mean = w_s + sr / su * w_su if self._spline == "sqrt": z = jr.normal(final_state.key, shape, dtype) @@ -376,7 +375,7 @@ def _body_fun(_state: _State): w_r = w_mean + bb return r, w_r - elif self.levy_area == "space-time": + elif self.levy_area is SpaceTimeLevyArea: # This is based on Theorem 6.1.4 of Foster's thesis (see above). assert final_state.bhh_s_u_su is not None @@ -396,7 +395,7 @@ def _body_fun(_state: _State): x2 = jnp.zeros(shape, dtype) else: raise ValueError( - f"When levy_area='space-time', only 'sqrt' and" + f"When levy_area='SpaceTimeLevyArea', only 'sqrt' and" f" 'zero' splines are permitted, got {self._spline}." ) @@ -469,7 +468,7 @@ def _brownian_arch( w_s, w_u, w_su = _state.w_s_u_su - if self.levy_area == "space-time": + if self.levy_area is SpaceTimeLevyArea: assert _state.bhh_s_u_su is not None assert _state.bkk_s_u_su is None bhh_s, bhh_u, bhh_su = _state.bhh_s_u_su diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index e6a09569..0d45fe60 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -1,6 +1,5 @@ import typing -from typing import Any, Literal, Optional, TYPE_CHECKING, Union -from typing_extensions import TypeAlias +from typing import Any, Optional, TYPE_CHECKING, Union import equinox as eqx import equinox.internal as eqxi @@ -53,32 +52,42 @@ BufferDenseInfos = dict[str, PyTree[eqxi.MaybeBuffer[Shaped[Array, "times ..."]]]] sentinel: Any = eqxi.doc_repr(object(), "sentinel") -LevyArea: TypeAlias = Literal["", "space-time"] +class AbstractBrownianReturn(eqx.Module): + dt: eqx.AbstractVar[PyTree] + W: eqx.AbstractVar[PyTree] -class LevyVal(eqx.Module): + +class BrownianIncrement(AbstractBrownianReturn): + dt: PyTree + W: PyTree + + +class SpaceTimeLevyArea(AbstractBrownianReturn): dt: PyTree W: PyTree H: Optional[PyTree] K: Optional[PyTree] -def levy_tree_transpose(tree_shape, tree: PyTree): - """Helper that takes a PyTree of LevyVals and transposes - into a LevyVal of PyTrees. +def levy_tree_transpose( + tree_shape, tree: PyTree[AbstractBrownianReturn] +) -> AbstractBrownianReturn: + """Helper that takes a PyTree of AbstractBrownianReturn and transposes + into an AbstractBrownianReturn of PyTrees. **Arguments:** - `tree_shape`: Corresponds to `outer_treedef` in `jax.tree_transpose`. - - `levy_area`: can be `""` or `"space-time"`, which indicates - which fields of the LevyVal will have values. - - `tree`: the PyTree of LevyVals to transpose. + - `tree`: the PyTree of AbstractBrownianReturn to transpose. **Returns:** - A `LevyVal` of PyTrees. + An `AbstractBrownianReturn` of PyTrees. """ - inner_tree = jtu.tree_leaves(tree, is_leaf=lambda x: isinstance(x, LevyVal))[0] + inner_tree = jtu.tree_leaves( + tree, is_leaf=lambda x: isinstance(x, AbstractBrownianReturn) + )[0] inner_tree_shape = jtu.tree_structure(inner_tree) return jtu.tree_transpose( outer_treedef=jtu.tree_structure(tree_shape), diff --git a/diffrax/_global_interpolation.py b/diffrax/_global_interpolation.py index f10a21fd..31d8b038 100644 --- a/diffrax/_global_interpolation.py +++ b/diffrax/_global_interpolation.py @@ -1,6 +1,6 @@ import functools as ft from collections.abc import Callable -from typing import Optional, TYPE_CHECKING +from typing import cast, Optional, TYPE_CHECKING import equinox as eqx import equinox.internal as eqxi @@ -24,6 +24,9 @@ from ._path import AbstractPath +ω = cast(Callable, ω) + + class AbstractGlobalInterpolation(AbstractPath): ts: AbstractVar[Real[Array, " times"]] ts_size: AbstractVar[IntScalarLike] diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index d0696402..34e8c030 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -2,7 +2,14 @@ import typing import warnings from collections.abc import Callable -from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING +from typing import ( + Any, + get_args, + get_origin, + Optional, + Tuple, + TYPE_CHECKING, +) import equinox as eqx import equinox.internal as eqxi @@ -49,6 +56,7 @@ StepTo, ) from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm +from ._typing import better_isinstance, get_args_of, get_origin_no_specials class SaveState(eqx.Module): @@ -78,26 +86,69 @@ class State(eqx.Module): progress_meter_state: PyTree[Array] -def _is_none(x): +def _is_none(x: Any) -> bool: return x is None -def _term_compatible(terms, term_structure): - def _check(term_cls, term): - if get_origin(term_cls) is MultiTerm: +def _term_compatible( + y: PyTree[ArrayLike], + args: PyTree[Any], + terms: PyTree[AbstractTerm], + term_structure: PyTree, +) -> bool: + error_msg = "term_structure" + + def _check(term_cls, term, yi): + if get_origin_no_specials(term_cls, error_msg) is MultiTerm: if isinstance(term, MultiTerm): [_tmp] = get_args(term_cls) assert get_origin(_tmp) in (tuple, Tuple), "Malformed term_structure" - if not _term_compatible(term.terms, get_args(_tmp)): - raise ValueError + assert len(term.terms) == len(get_args(_tmp)) + for term, arg in zip(term.terms, get_args(_tmp)): + if not _term_compatible(yi, args, term, arg): + raise ValueError else: raise ValueError else: - if not isinstance(term, term_cls): + # Check that `term` is an instance of `term_cls` (ignoring any generic + # parameterization). + origin_cls = get_origin_no_specials(term_cls, error_msg) + if origin_cls is None: + origin_cls = term_cls + if not isinstance(term, origin_cls): raise ValueError + # Now check the generic parametrization of `term_cls`; can be one of: + # ----------------------------------------- + # `term_cls` | `term_args` + # --------------------------|-------------- + # AbstractTerm | () + # AbstractTerm[VF, Control] | (VF, Control) + # ----------------------------------------- + term_args = get_args_of(AbstractTerm, term_cls, error_msg) + n_term_args = len(term_args) + if n_term_args == 0: + pass + elif n_term_args == 2: + vf_type_expected, control_type_expected = term_args + vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args) + vf_type_compatible = eqx.filter_eval_shape( + better_isinstance, vf_type, vf_type_expected + ) + if not vf_type_compatible: + raise ValueError + control_type = jax.eval_shape(term.contr, 0.0, 0.0) + control_type_compatible = eqx.filter_eval_shape( + better_isinstance, control_type, control_type_expected + ) + if not control_type_compatible: + raise ValueError + else: + assert False, "Malformed term structure" + # If we've got to this point then the term is compatible + try: - jtu.tree_map(_check, term_structure, terms) + jtu.tree_map(_check, term_structure, terms, y) except ValueError: # ValueError may also arise from mismatched tree structures return False @@ -661,47 +712,6 @@ def diffeqsolve( stacklevel=2, ) - # Backward compatibility - if isinstance( - solver, (EulerHeun, ItoMilstein, StratonovichMilstein) - ) and _term_compatible(terms, (ODETerm, AbstractTerm)): - warnings.warn( - "Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to " - f"{solver.__class__.__name__} is deprecated in favour of " - "`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that " - "the same terms can now be passed used for both general and SDE-specific " - "solvers!", - stacklevel=2, - ) - terms = MultiTerm(*terms) - - # Error checking - if not _term_compatible(terms, solver.term_structure): - raise ValueError( - "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with " - f"structure {solver.term_structure}" - ) - - if is_sde(terms): - if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)): - warnings.warn( - f"`{type(solver).__name__}` is not marked as converging to either the " - "Itô or the Stratonovich solution.", - stacklevel=2, - ) - if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): - # Specific check to not work even if using HalfSolver(Euler()) - if isinstance(solver, Euler): - raise ValueError( - "An SDE should not be solved with adaptive step sizes with Euler's " - "method, as it may not converge to the correct solution." - ) - if is_unsafe_sde(terms): - if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): - raise ValueError( - "`UnsafeBrownianPath` cannot be used with adaptive step sizes." - ) - # Allow setting e.g. t0 as an int with dt0 as a float. timelikes = [t0, t1, dt0] + [ s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) @@ -745,6 +755,47 @@ def _promote(yi): y0 = jtu.tree_map(_promote, y0) del timelikes + # Backward compatibility + if isinstance( + solver, (EulerHeun, ItoMilstein, StratonovichMilstein) + ) and _term_compatible(y0, args, terms, (ODETerm, AbstractTerm)): + warnings.warn( + "Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to " + f"{solver.__class__.__name__} is deprecated in favour of " + "`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that " + "the same terms can now be passed used for both general and SDE-specific " + "solvers!", + stacklevel=2, + ) + terms = MultiTerm(*terms) + + # Error checking + if not _term_compatible(y0, args, terms, solver.term_structure): + raise ValueError( + "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with " + f"structure {solver.term_structure}" + ) + + if is_sde(terms): + if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)): + warnings.warn( + f"`{type(solver).__name__}` is not marked as converging to either the " + "Itô or the Stratonovich solution.", + stacklevel=2, + ) + if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): + # Specific check to not work even if using HalfSolver(Euler()) + if isinstance(solver, Euler): + raise ValueError( + "An SDE should not be solved with adaptive step sizes with Euler's " + "method, as it may not converge to the correct solution." + ) + if is_unsafe_sde(terms): + if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): + raise ValueError( + "`UnsafeBrownianPath` cannot be used with adaptive step sizes." + ) + # Normalises time: if t0 > t1 then flip things around. direction = jnp.where(t0 < t1, 1, -1) t0 = t0 * direction diff --git a/diffrax/_local_interpolation.py b/diffrax/_local_interpolation.py index 1969cd87..97ccdb15 100644 --- a/diffrax/_local_interpolation.py +++ b/diffrax/_local_interpolation.py @@ -1,4 +1,5 @@ -from typing import Optional, TYPE_CHECKING +from collections.abc import Callable +from typing import cast, Optional, TYPE_CHECKING import jax.numpy as jnp import jax.tree_util as jtu @@ -17,6 +18,9 @@ from ._path import AbstractPath +ω = cast(Callable, ω) + + class AbstractLocalInterpolation(AbstractPath): pass diff --git a/diffrax/_path.py b/diffrax/_path.py index c7b90a3b..e78b8d8b 100644 --- a/diffrax/_path.py +++ b/diffrax/_path.py @@ -1,5 +1,5 @@ import abc -from typing import Optional, TYPE_CHECKING +from typing import Generic, Optional, TYPE_CHECKING, TypeVar import equinox as eqx import jax @@ -10,12 +10,14 @@ from typing import ClassVar as AbstractVar else: from equinox import AbstractVar -from jaxtyping import Array, PyTree -from ._custom_types import RealScalarLike +from ._custom_types import Control, RealScalarLike -class AbstractPath(eqx.Module): +_Control = TypeVar("_Control", bound=Control) + + +class AbstractPath(eqx.Module, Generic[_Control]): """Abstract base class for all paths. Every path has a start point `t0` and an end point `t1`. In between these values @@ -48,7 +50,7 @@ def evaluate(self, t0, t1=None, left=True): @abc.abstractmethod def evaluate( self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True - ) -> PyTree[Array]: + ) -> _Control: r"""Evaluate the path at any point in the interval $[t_0, t_1]$. **Arguments:** @@ -77,7 +79,7 @@ def evaluate( The increment of the path between `t0` and `t1`. """ - def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]: + def derivative(self, t: RealScalarLike, left: bool = True) -> _Control: r"""Evaluate the derivative of the path. Essentially equivalent to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))` (and indeed this is its default implementation if no other is specified). diff --git a/diffrax/_root_finder/_verychord.py b/diffrax/_root_finder/_verychord.py index 4bce16a5..09c1f386 100644 --- a/diffrax/_root_finder/_verychord.py +++ b/diffrax/_root_finder/_verychord.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any +from typing import Any, cast import equinox as eqx import jax @@ -15,6 +15,9 @@ from .._custom_types import Y +ω = cast(Callable, ω) + + def _small(diffsize: Scalar) -> Bool[Array, ""]: # TODO(kidger): make a more careful choice here -- the existence of this # function is pretty ad-hoc. diff --git a/diffrax/_solver/milstein.py b/diffrax/_solver/milstein.py index 5eae5e5b..8bc1ea5b 100644 --- a/diffrax/_solver/milstein.py +++ b/diffrax/_solver/milstein.py @@ -48,7 +48,7 @@ class StratonovichMilstein(AbstractStratonovichSolver): ] = LocalLinearInterpolation def order(self, terms): - raise ValueError("`StratonovichMilstein` should not used to solve ODEs.") + raise ValueError("`StratonovichMilstein` should not be used to solve ODEs.") def strong_order(self, terms): return 1 # assuming commutative noise @@ -122,7 +122,7 @@ class ItoMilstein(AbstractItoSolver): ] = LocalLinearInterpolation def order(self, terms): - raise ValueError("`ItoMilstein` should not used to solve ODEs.") + raise ValueError("`ItoMilstein` should not be used to solve ODEs.") def strong_order(self, terms): return 1 # assuming commutative noise diff --git a/diffrax/_solver/sil3.py b/diffrax/_solver/sil3.py index 7bec9d4b..7d24f5e9 100644 --- a/diffrax/_solver/sil3.py +++ b/diffrax/_solver/sil3.py @@ -1,4 +1,5 @@ -from typing import ClassVar +from collections.abc import Callable +from typing import cast, ClassVar import numpy as np import optimistix as optx @@ -15,6 +16,9 @@ ) +ω = cast(Callable, ω) + + # See # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ # for the construction of the a_predictor tableau, which is new here. diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 5e497fc1..3ee544e9 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -33,6 +33,9 @@ from .base import AbstractStepSizeController +ω = cast(Callable, ω) + + def _select_initial_step( terms: PyTree[AbstractTerm], t0: RealScalarLike, diff --git a/diffrax/_step_size_controller/constant.py b/diffrax/_step_size_controller/constant.py index ce8ad425..bd00023b 100644 --- a/diffrax/_step_size_controller/constant.py +++ b/diffrax/_step_size_controller/constant.py @@ -47,7 +47,7 @@ def adapt_step_size( y1_candidate: Y, args: Args, y_error: Optional[Y], - error_order: RealScalarLike, + error_order: Optional[RealScalarLike], controller_state: RealScalarLike, ) -> tuple[bool, RealScalarLike, RealScalarLike, bool, RealScalarLike, RESULTS]: del t0, y0, y1_candidate, args, y_error, error_order diff --git a/diffrax/_term.py b/diffrax/_term.py index b2dd7673..e38b8c07 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -9,14 +9,18 @@ import jax.tree_util as jtu import numpy as np from equinox.internal import ω -from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef +from jaxtyping import ArrayLike, PyTree, PyTreeDef from ._custom_types import Args, Control, IntScalarLike, RealScalarLike, VF, Y from ._misc import upcast_or_raise from ._path import AbstractPath -class AbstractTerm(eqx.Module): +_VF = TypeVar("_VF", bound=VF) +_Control = TypeVar("_Control", bound=Control) + + +class AbstractTerm(eqx.Module, Generic[_VF, _Control]): r"""Abstract base class for all terms. Let $y$ solve some differential equation with vector field $f$ and control $x$. @@ -28,7 +32,7 @@ class AbstractTerm(eqx.Module): """ @abc.abstractmethod - def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: + def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF: """The vector field. Represents a function $f(t, y(t), args)$. @@ -46,7 +50,7 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: pass @abc.abstractmethod - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control: + def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: r"""The control. Represents the $\mathrm{d}t$ in an ODE, or the $\mathrm{d}w(t)$ in an SDE, etc. @@ -71,7 +75,7 @@ def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control: pass @abc.abstractmethod - def prod(self, vf: VF, control: Control) -> Y: + def prod(self, vf: _VF, control: _Control) -> Y: r"""Determines the interaction between vector field and control. With a solution $y$ to a differential equation with vector field $f$ and @@ -94,7 +98,7 @@ def prod(self, vf: VF, control: Control) -> Y: """ pass - def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: Control) -> Y: + def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y: r"""The composition of [`diffrax.AbstractTerm.vf`][] and [`diffrax.AbstractTerm.prod`][]. @@ -155,7 +159,7 @@ def is_vf_expensive( return False -class ODETerm(AbstractTerm): +class ODETerm(AbstractTerm[_VF, RealScalarLike]): r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term appearing on the right hand side of an ODE, in which the control is time. @@ -172,9 +176,9 @@ class ODETerm(AbstractTerm): ``` """ - vector_field: Callable[[RealScalarLike, Y, Args], VF] + vector_field: Callable[[RealScalarLike, Y, Args], _VF] - def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: + def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF: out = self.vector_field(t, y, args) if jtu.tree_structure(out) != jtu.tree_structure(y): raise ValueError( @@ -197,7 +201,7 @@ def _broadcast_and_upcast(oi, yi): def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: return t1 - t0 - def prod(self, vf: VF, control: RealScalarLike) -> Y: + def prod(self, vf: _VF, control: RealScalarLike) -> Y: def _mul(v): c = upcast_or_raise( control, @@ -219,24 +223,28 @@ def _mul(v): """ -class _CallableToPath(AbstractPath): +class _CallableToPath(AbstractPath[_Control]): fn: Callable @property - def t0(self): + def t0(self): # pyright: ignore[reportIncompatibleVariableOverride] return -jnp.inf @property - def t1(self): + def t1(self): # pyright: ignore[reportIncompatibleVariableOverride] return jnp.inf def evaluate( self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True - ) -> PyTree[Array]: + ) -> _Control: return self.fn(t0, t1) -def _callable_to_path(x): +def _callable_to_path( + x: Union[ + AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control] + ], +) -> AbstractPath[_Control]: if isinstance(x, AbstractPath): return x else: @@ -250,15 +258,17 @@ def _prod(vf, control): return jnp.tensordot(vf, control, axes=jnp.ndim(control)) -class _ControlTerm(AbstractTerm): - vector_field: Callable[[RealScalarLike, Y, Args], VF] - control: Union[AbstractPath, Callable] = eqx.field(converter=_callable_to_path) +class _AbstractControlTerm(AbstractTerm[_VF, _Control]): + vector_field: Callable[[RealScalarLike, Y, Args], _VF] + control: Union[ + AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control] + ] = eqx.field(converter=_callable_to_path) # pyright: ignore def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: return self.vector_field(t, y, args) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control: - return self.control.evaluate(t0, t1) + def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: + return self.control.evaluate(t0, t1) # pyright: ignore def to_ode(self) -> ODETerm: r"""If the control is differentiable then $f(t, y(t), args) \mathrm{d}x(t)$ @@ -273,7 +283,7 @@ def to_ode(self) -> ODETerm: return ODETerm(vector_field=vector_field) -_ControlTerm.__init__.__doc__ = """**Arguments:** +_AbstractControlTerm.__init__.__doc__ = """**Arguments:** - `vector_field`: A callable representing the vector field. This callable takes three arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is @@ -286,7 +296,7 @@ def to_ode(self) -> ODETerm: """ -class ControlTerm(_ControlTerm): +class ControlTerm(_AbstractControlTerm[_VF, _Control]): r"""A term representing the general case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field - control interaction is a matrix-vector product. @@ -324,11 +334,11 @@ class ControlTerm(_ControlTerm): ``` """ - def prod(self, vf: VF, control: Control) -> Y: + def prod(self, vf: _VF, control: _Control) -> Y: return jtu.tree_map(_prod, vf, control) -class WeaklyDiagonalControlTerm(_ControlTerm): +class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]): r"""A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field - control interaction is a matrix-vector product, and the matrix is square and diagonal. In this case we may represent the matrix as a vector @@ -350,15 +360,15 @@ class WeaklyDiagonalControlTerm(_ControlTerm): without the "weak". (This stronger property is useful in some SDE solvers.) """ - def prod(self, vf: VF, control: Control) -> Y: + def prod(self, vf: _VF, control: _Control) -> Y: return jtu.tree_map(operator.mul, vf, control) class _ControlToODE(eqx.Module): - control_term: _ControlTerm + control_term: _AbstractControlTerm def __call__(self, t: RealScalarLike, y: Y, args: Args) -> Y: - control = self.control_term.control.derivative(t) + control = self.control_term.control.derivative(t) # pyright: ignore return self.control_term.vf_prod(t, y, args, control) @@ -437,23 +447,23 @@ def is_vf_expensive( return any(term.is_vf_expensive(t0, t1, y, args) for term in self.terms) -class WrapTerm(AbstractTerm): - term: AbstractTerm +class WrapTerm(AbstractTerm[_VF, _Control]): + term: AbstractTerm[_VF, _Control] direction: IntScalarLike - def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: + def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF: t = t * self.direction return self.term.vf(t, y, args) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control: + def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: _t0 = jnp.where(self.direction == 1, t0, -t1) _t1 = jnp.where(self.direction == 1, t1, -t0) return (self.direction * self.term.contr(_t0, _t1) ** ω).ω - def prod(self, vf: VF, control: Control) -> Y: + def prod(self, vf: _VF, control: _Control) -> Y: return self.term.prod(vf, control) - def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: Control) -> Y: + def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y: t = t * self.direction return self.term.vf_prod(t, y, args, control) @@ -469,8 +479,8 @@ def is_vf_expensive( return self.term.is_vf_expensive(_t0, _t1, y, args) -class AdjointTerm(AbstractTerm): - term: AbstractTerm +class AdjointTerm(AbstractTerm[_VF, _Control]): + term: AbstractTerm[_VF, _Control] def is_vf_expensive( self, @@ -548,11 +558,11 @@ def _fn(_control): ) return jtu.tree_transpose(vf_prod_tree, control_tree, jac) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> PyTree[ArrayLike]: + def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: return self.term.contr(t0, t1) def prod( - self, vf: PyTree[ArrayLike], control: Control + self, vf: PyTree[ArrayLike], control: _Control ) -> tuple[ PyTree[ArrayLike], PyTree[ArrayLike], PyTree[ArrayLike], PyTree[ArrayLike] ]: @@ -595,7 +605,7 @@ def vf_prod( PyTree[ArrayLike], PyTree[ArrayLike], PyTree[ArrayLike], PyTree[ArrayLike] ], args: Args, - control: Control, + control: _Control, ) -> tuple[ PyTree[ArrayLike], PyTree[ArrayLike], PyTree[ArrayLike], PyTree[ArrayLike] ]: diff --git a/diffrax/_typing.py b/diffrax/_typing.py new file mode 100644 index 00000000..e0bfff6c --- /dev/null +++ b/diffrax/_typing.py @@ -0,0 +1,184 @@ +import inspect +import sys +import types +from typing import ( + Annotated, + Any, + Generic, + get_args, + get_origin, + Optional, + Protocol, + TypeVar, + Union, +) +from typing_extensions import TypeAlias + +import typeguard + + +# We don't actually care what people have subscripted with. +# In practice this should be thought of as TypeLike = Union[type, types.UnionType]. Plus +# maybe type(Literal) and so on? +TypeLike: TypeAlias = Any + + +def better_isinstance(x, annotation) -> bool: + """As `isinstance`, but supports general type hints.""" + + @typeguard.typechecked + def f(y: annotation): + pass + + try: + f(x) + except TypeError: + return False + else: + return True + + +_union_types: list = [Union] +if sys.version_info >= (3, 10): + _union_types.append(types.UnionType) + + +def get_origin_no_specials(x, error_msg: str) -> Optional[type]: + """As `typing.get_origin`, but ignores `Annotated` and throws a + `NotImplementedError` if passed any other non-class: `Union`, `Literal`, etc. Serves + as a guard against the full weirdness of the Python type system. + + **Arguments:** + + - `x`: the type to apply `get_origin` to. + - `error_msg`: if a disallowed type is used, then this will appear in the error + message. + + **Returns:** + + As `get_origin`, specifically either `None` or a class. + """ + origin = get_origin(x) + if origin in _union_types: + raise NotImplementedError(f"Cannot use unions in `{error_msg}`.") + elif origin is Annotated: + # We do allow Annotated, just because it's easy to handle. + return get_origin_no_specials(get_args(x)[0], error_msg) + elif origin is None or inspect.isclass(origin): + return origin + else: + raise NotImplementedError(f"Cannot use {x} in `{error_msg}`.") + + +def get_args_of(base_cls: type, cls, error_msg: str) -> tuple[TypeLike, ...]: + """Equivalent to `get_args(cls)`, except that it tracks through the type hierarchy + finding the way in which `cls` subclasses `base_cls`, and returns the arguments that + subscript that instead. + + For example, + ```python + class Foo(Generic[T]): + pass + + class Bar(Generic[S]): + pass + + class Qux(Foo[T], Bar[S]): + pass + + get_args_of(Foo, Qux[int, str], "...") # int + ``` + + In addition, any unfilled type variables are returned as `Any`. + + **Arguments:** + + - `base_cls`: the class to get parameters with respect to. + - `cls`: the class or subscripted generic to get arguments with respect to. + - `error_msg`: if anything goes wrong, mention this in the error message. + + **Returns:** + + A tuple of types indicating the arguments. In addition, any unfilled type variables + are returned as `Any`. + """ + + if not inspect.isclass(base_cls): + raise TypeError(f"{base_cls} should be a class") + if not hasattr(base_cls, "__parameters__"): + raise TypeError(f"{base_cls} should be an unsubscripted generic") + + origin = get_origin_no_specials(cls, error_msg) + if inspect.isclass(cls): + # Unsubscripted + assert origin is None + origin = cls + params = [Any for _ in getattr(cls, "__parameters__", ())] + else: + # Subscripted + assert origin is not None + params: list[TypeLike] = [] + for param in get_args(cls): + if isinstance(param, TypeVar): + params.append(Any) + else: + params.append(param) + if issubclass(origin, base_cls): + out = _get_args_of_impl(base_cls, origin, tuple(params), error_msg) + if out is None: + # Dependency is purely inheritance without subscripting + return tuple(Any for _ in base_cls.__parameters__) + else: + return out + else: + raise TypeError(f"{cls} is not a subclass of {base_cls}") + + +def _get_args_of_impl( + base_cls: type, cls: type, params: tuple[TypeLike, ...], error_msg +) -> Optional[tuple[TypeLike, ...]]: + if cls is base_cls: + return params + assert len(cls.__parameters__) == len(params) + param_lookup = {k: v for k, v in zip(cls.__parameters__, params)} + base_params: set[tuple[TypeLike, ...]] = set() + # If we've gotten this far then `cls` is known to have been subscripted, so it + # should have an `__orig_bases__` attribute. (Where as e.g. `class Foo: pass` does + # not have one) + for x in cls.__orig_bases__: + x_origin = get_origin_no_specials(x, error_msg) + if x_origin in (Generic, Protocol): + continue + if inspect.isclass(x): + # Unsubscripted, ignore. + assert x_origin is None + else: + # Subscripted, should pass in parameters + assert x_origin is not None + assert len(get_args(x)) > 0 + x_params = [ + param_lookup.get(arg, Any) if isinstance(arg, TypeVar) else arg + for arg in get_args(x) + ] + if issubclass(x_origin, base_cls): + base_params_i = _get_args_of_impl( + base_cls, x_origin, tuple(x_params), error_msg + ) + if base_params_i is not None: + base_params.add(base_params_i) + # Else ignore, we won't be able to recurse down to `base_cls` this way. + if len(base_params) == 0: + # `base_cls` is a superclass of `cls`, as we have earlier guards against this. + assert issubclass(cls, base_cls) + # However that dependency is purely normal inheritance, no subscripting. + return None + elif len(base_params) == 1: + return base_params.pop() + else: + if len(params) == 0: + error_cls = cls + else: + error_cls = cls[params] + raise TypeError( + f"{error_cls} inherits from {base_cls} in multiple incompatible ways." + ) diff --git a/pyproject.toml b/pyproject.toml index d6b69ad9..5a0f924c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"] +dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"] [build-system] requires = ["hatchling"] @@ -38,13 +38,15 @@ markers = ["slow"] [tool.ruff] extend-include = ["*.ipynb"] +src = [] + +[tool.ruff.lint] fixable = ["I001", "F401"] ignore = ["E402", "E721", "E731", "E741", "F722"] ignore-init-module-imports = true select = ["E", "F", "I001"] -src = [] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true extra-standard-library = ["typing_extensions"] lines-after-imports = 2 @@ -52,4 +54,5 @@ order-by-type = false [tool.pyright] reportIncompatibleMethodOverride = true +reportIncompatibleVariableOverride = false # Incompatible with eqx.AbstractVar include = ["diffrax", "tests"] diff --git a/test/helpers.py b/test/helpers.py index 5b4c2dc3..87b94b50 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,4 +1,4 @@ -from typing import Callable, Literal +from typing import Callable import diffrax import equinox as eqx @@ -107,7 +107,7 @@ def sum_square_diff(y1, y2): def _batch_sde_solve( key: PRNGKeyArray, get_terms: Callable[[diffrax.AbstractBrownianPath], diffrax.AbstractTerm], - levy_area: Literal["", "space-time"], + levy_area: type[diffrax.AbstractBrownianReturn], solver: diffrax.AbstractSolver, w_shape: tuple[int, ...], t0: float, @@ -125,7 +125,7 @@ def _batch_sde_solve( shape=struct, tol=2**-14, key=key, - levy_area=levy_area, + levy_area=levy_area, # pyright: ignore ) terms = get_terms(bm) sol = diffrax.diffeqsolve( @@ -156,7 +156,8 @@ def sde_solver_strong_order( key: PRNGKeyArray, ): dtype = jnp.result_type(*jtu.tree_leaves(y0)) - levy_area = "" # TODO: add a check whether the solver needs levy area + # TODO: add a check whether the solver needs levy area + levy_area = diffrax.BrownianIncrement keys = jr.split(key, num_samples) # deliberately reused across all solves correct_sols = _batch_sde_solve( diff --git a/test/test_brownian.py b/test/test_brownian.py index 4fe293d7..73f56240 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -31,7 +31,9 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", ["", "space-time"]) +@pytest.mark.parametrize( + "levy_area", [diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea] +) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): t0 = 0.0 @@ -101,12 +103,10 @@ def is_tuple_of_ints(obj): with context: out = path.evaluate(t0, t1, use_levy=use_levy) if use_levy: - assert isinstance(out, diffrax.LevyVal) + assert isinstance(out, diffrax.AbstractBrownianReturn) w = out.W - h = out.H - if levy_area == "": - assert h is None - else: + if isinstance(out, diffrax.SpaceTimeLevyArea): + h = out.H assert eqx.filter_eval_shape(lambda: h) == shape else: w = out @@ -116,7 +116,9 @@ def is_tuple_of_ints(obj): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", ["", "space-time"]) +@pytest.mark.parametrize( + "levy_area", [diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea] +) @pytest.mark.parametrize("use_levy", (False, True)) def test_statistics(ctr, levy_area, use_levy): # Deterministic key for this test; not using getkey() @@ -134,12 +136,11 @@ def _eval(key): values = jax.vmap(_eval)(keys) if use_levy: - assert isinstance(values, diffrax.LevyVal) + assert isinstance(values, diffrax.AbstractBrownianReturn) w = values.W - h = values.H - if levy_area == "": - assert h is None - else: + + if isinstance(values, diffrax.SpaceTimeLevyArea): + h = values.H assert h is not None assert h.shape == (10000,) ref_dist = stats.norm(loc=0, scale=math.sqrt(5 / 12)) @@ -204,9 +205,14 @@ def conditional_statistics( w_s = bm_s.W w_r = bm_r.W w_u = bm_u.W - h_s = bm_s.H - h_r = bm_r.H - h_u = bm_u.H + if levy_area is diffrax.SpaceTimeLevyArea: + h_s = bm_s.H + h_r = bm_r.H + h_u = bm_u.H + else: + h_s = None + h_r = None + h_u = None else: w_s = bm_s w_r = bm_r @@ -225,7 +231,7 @@ def conditional_statistics( # multiple-testing correction. pvals_w1.append(pval_w1) - if levy_area == "space-time" and use_levy: + if levy_area is diffrax.SpaceTimeLevyArea and use_levy: assert h_s is not None assert h_r is not None assert h_u is not None @@ -263,7 +269,9 @@ def conditional_statistics( return jnp.array(pvals_w1), jnp.array(pvals_w2), jnp.array(pvals_h) -@pytest.mark.parametrize("levy_area", ["", "space-time"]) +@pytest.mark.parametrize( + "levy_area", [diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea] +) @pytest.mark.parametrize("use_levy", (False, True)) def test_conditional_statistics(levy_area, use_levy): pvals_w1, pvals_w2, pvals_h = conditional_statistics( @@ -275,7 +283,7 @@ def test_conditional_statistics(levy_area, use_levy): min_num_points=90, ) assert jnp.all(pvals_w1 > 0.1 / pvals_w1.shape[0]) - if levy_area == "space-time" and use_levy: + if levy_area is diffrax.SpaceTimeLevyArea and use_levy: assert jnp.all(pvals_w2 > 0.1 / pvals_w2.shape[0]) assert jnp.all(pvals_h > 0.1 / pvals_h.shape[0]) else: @@ -284,9 +292,9 @@ def test_conditional_statistics(levy_area, use_levy): def _levy_area_spline(): - for levy_area in ("", "space-time"): + for levy_area in (diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea): for spline in ("quad", "sqrt", "zero"): - if levy_area == "space-time" and spline == "quad": + if levy_area is diffrax.SpaceTimeLevyArea and spline == "quad": continue yield levy_area, spline @@ -318,17 +326,17 @@ def pred(pvals): def pred(pvals): return jnp.min(pvals) < 0.001 / pvals.shape[0] and jnp.mean(pvals) < 0.03 - if levy_area == "space-time" and use_levy: + if levy_area is diffrax.SpaceTimeLevyArea and use_levy: assert pred(pvals_w2) assert pred(pvals_h) else: assert len(pvals_w2) == 0 assert len(pvals_h) == 0 - if levy_area == "": + if levy_area is diffrax.BrownianIncrement: assert pred(pvals_w1) - elif spline == "sqrt": # levy_area == "space-time" + elif spline == "sqrt": # levy_area == SpaceTimeLevyArea assert pred(pvals_w1) - else: # levy_area == "space-time" and spline == "zero" + else: # levy_area == SpaceTimeLevyArea and spline == "zero" # We need a milder upper bound on jnp.mean(pvals_w1) because # the presence of space-time Levy area gives W_r (i.e. the output # of the Brownian path) a variance very close to the correct one, @@ -342,7 +350,7 @@ def test_levy_area_reverse_time(): key = jr.PRNGKey(5678) bm_key, sample_key = jr.split(key, 2) bm = diffrax.VirtualBrownianTree( - t0=0, t1=5, tol=2**-5, shape=(), key=bm_key, levy_area="space-time" + t0=0, t1=5, tol=2**-5, shape=(), key=bm_key, levy_area=diffrax.SpaceTimeLevyArea ) ts = jr.uniform(sample_key, shape=(100,), minval=0, maxval=5) diff --git a/test/test_integrate.py b/test/test_integrate.py index 3bef5ce6..23781058 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -1,7 +1,7 @@ import contextlib import math import operator -from typing import cast +from typing import Any, cast import diffrax import equinox as eqx @@ -13,7 +13,7 @@ import scipy.stats from diffrax import ControlTerm, MultiTerm, ODETerm from equinox.internal import ω -from jaxtyping import Array +from jaxtyping import Array, ArrayLike, Float from .helpers import ( all_ode_solvers, @@ -520,3 +520,162 @@ def test_implicit_tol_error(): 0.01, 1.0, ) + + +def test_term_compatibility(): + class TestControl(eqx.Module): + dt: Float[ArrayLike, ""] + + def __rmul__(self, other): + return other.__mul__(self.dt) + + def __mul__(self, other): + return self.dt * other + + class TestSolver(diffrax.Euler): + term_structure = diffrax.AbstractTerm[ + tuple[Float[Array, "n 3"]], tuple[TestControl] + ] + + solver = TestSolver() + incompatible_vf = lambda t, y, args: jnp.ones((2, 1)) + compatible_vf = lambda t, y, args: (jnp.ones((2, 3)),) + incompatible_control = lambda t0, t1: t1 - t0 + compatible_control = lambda t0, t1: (TestControl(t1 - t0),) + + incompatible_terms = [ + diffrax.WeaklyDiagonalControlTerm(incompatible_vf, incompatible_control), + diffrax.WeaklyDiagonalControlTerm(incompatible_vf, compatible_control), + diffrax.WeaklyDiagonalControlTerm(compatible_vf, incompatible_control), + ] + compatible_term = diffrax.WeaklyDiagonalControlTerm( + compatible_vf, compatible_control + ) + for term in incompatible_terms: + with pytest.raises(ValueError, match=r"`terms` must be a PyTree of"): + diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, (jnp.zeros((2, 1)),)) + + diffrax.diffeqsolve( + compatible_term, solver, 0.1, 1.1, 0.1, (jnp.zeros((2, 3)),), args=["str"] + ) + + +def test_term_compatibility_pytree(): + class TestSolver(diffrax.AbstractSolver): + term_structure = { + "a": diffrax.ODETerm, + "b": diffrax.ODETerm[Any], + "c": diffrax.ODETerm[Float[Array, " 3"]], + "d": diffrax.AbstractTerm[Float[Array, " 4"], Any], + "e": diffrax.MultiTerm[ + tuple[diffrax.ODETerm, diffrax.AbstractTerm[Any, Float[Array, " 5"]]] + ], + } + interpolation_cls = diffrax.LocalLinearInterpolation + + def init(self, terms, t0, t1, y0, args): + return None + + def step(self, terms, t0, t1, y0, args, solver_state, made_jump): + def _step(_term, _y): + control = _term.contr(t0, t1) + return _y + _term.vf_prod(t0, _y, args, control) + + _is_term = lambda x: isinstance(x, diffrax.AbstractTerm) + y1 = jtu.tree_map(_step, terms, y0, is_leaf=_is_term) + dense_info = dict(y0=y0, y1=y1) + return y1, None, dense_info, None, diffrax.RESULTS.successful + + def func(self, terms, t0, y0, args): + assert False + + ode_term = diffrax.ODETerm(lambda t, y, args: -y) + solver = TestSolver() + compatible_term = { + "a": ode_term, + "b": ode_term, + "c": ode_term, + "d": ode_term, + "e": diffrax.MultiTerm( + ode_term, + diffrax.WeaklyDiagonalControlTerm( + lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(5) + ), + ), + } + compatible_y0 = { + "a": jnp.array(1.0), + "b": jnp.array(2.0), + "c": jnp.arange(3.0), + "d": jnp.arange(4.0), + "e": jnp.arange(5.0), + } + diffrax.diffeqsolve(compatible_term, solver, 0.0, 1.0, 0.1, compatible_y0) + + incompatible_term1 = { + "a": ode_term, + "b": ode_term, + "c": ode_term, + "d": ode_term, + "e": diffrax.MultiTerm( + ode_term, + diffrax.WeaklyDiagonalControlTerm( + lambda t, y, args: -y, + lambda t0, t1: t1 - t0, # wrong control shape + ), + ), + } + incompatible_term2 = { + "a": ode_term, + "b": ode_term, + "c": ode_term, + # Missing "d" piece + "e": diffrax.MultiTerm( + ode_term, + diffrax.WeaklyDiagonalControlTerm( + lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3) + ), + ), + } + incompatible_term3 = { + "a": ode_term, + "b": ode_term, + "c": ode_term, + "d": ode_term, + # No MultiTerm for "e" + "e": diffrax.WeaklyDiagonalControlTerm( + lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3) + ), + } + + incompatible_y0_1 = { + "a": jnp.array(1.0), + "b": jnp.array(2.0), + "c": jnp.arange(4.0), # of length 4, not 3 + "d": jnp.arange(4.0), + "e": jnp.arange(5.0), + } + incompatible_y0_2 = { + "a": jnp.array(1.0), + "b": jnp.array(2.0), + "c": jnp.arange(3.0), + # Missing "d" piece + "e": jnp.arange(5.0), + } + incompatible_y0_3 = jnp.array(4.0) # Completely the wrong structure! + for term in ( + compatible_term, + incompatible_term1, + incompatible_term2, + incompatible_term3, + ): + for y0 in ( + compatible_y0, + incompatible_y0_1, + incompatible_y0_2, + incompatible_y0_3, + ): + if term is compatible_term and y0 is compatible_y0: + continue + with pytest.raises(ValueError, match=r"`terms` must be a PyTree of"): + diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, y0) diff --git a/test/test_term.py b/test/test_term.py index ec9e8b74..bc0aee34 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -3,24 +3,32 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +from jaxtyping import Array, PyTree, Shaped from .helpers import tree_allclose def test_ode_term(): - vector_field = lambda t, y, args: -y + def vector_field(t, y, args) -> Array: + return -y + term = diffrax.ODETerm(vector_field) dt = term.contr(0, 1) vf = term.vf(0, 1, None) vf_prod = term.vf_prod(0, 1, None, dt) assert tree_allclose(vf_prod, term.prod(vf, dt)) + # `# type: ignore` is used for contrapositive static type checking as per: + # https://github.com/microsoft/pyright/discussions/2411#discussioncomment-2028001 + _: diffrax.ODETerm[Array] = term + __: diffrax.ODETerm[bool] = term # type: ignore + def test_control_term(getkey): vector_field = lambda t, y, args: jr.normal(args, (3, 2)) derivkey = getkey() - class Control(diffrax.AbstractPath): + class Control(diffrax.AbstractPath[Shaped[Array, "2"]]): t0 = 0 t1 = 1 @@ -30,7 +38,7 @@ def evaluate(self, t0, t1=None, left=True): def derivative(self, t, left=True): return jr.normal(derivkey, (2,)) - control = Control() # pyright: ignore + control = Control() term = diffrax.ControlTerm(vector_field, control) args = getkey() dx = term.contr(0, 1) @@ -42,6 +50,27 @@ def derivative(self, t, left=True): assert vf_prod.shape == (3,) assert tree_allclose(vf_prod, term.prod(vf, dx)) + # `# type: ignore` is used for contrapositive static type checking as per: + # https://github.com/microsoft/pyright/discussions/2411#discussioncomment-2028001 + _: diffrax.ControlTerm[PyTree[Array], Array] = term + __: diffrax.ControlTerm[PyTree[Array], diffrax.AbstractBrownianReturn] = term # type: ignore + + # Enable the following runtime checks once beartype supports Generic[TypeVar]. + # https://github.com/beartype/beartype/issues/238 + # import pytest + # from beartype.door import die_if_unbearable + # from beartype.roar import BeartypeCallHintViolation + + # control = Control() + # term = diffrax.ControlTerm(vector_field, control) + + # die_if_unbearable(term, diffrax.ControlTerm[Shaped[Array, "2"]]) + # with pytest.raises(BeartypeCallHintViolation): + # die_if_unbearable(term, diffrax.ControlTerm[Shaped[Array, "3"]]) + + # with pytest.raises(BeartypeCallHintViolation): + # die_if_unbearable(term, diffrax.ControlTerm[diffrax.LevyVal]) + term = term.to_ode() dt = term.contr(0, 1) vf = term.vf(0, y, args) @@ -65,7 +94,7 @@ def evaluate(self, t0, t1=None, left=True): def derivative(self, t, left=True): return jr.normal(derivkey, (3,)) - control = Control() # pyright: ignore + control = Control() term = diffrax.WeaklyDiagonalControlTerm(vector_field, control) args = getkey() dx = term.contr(0, 1) diff --git a/test/test_typing.py b/test/test_typing.py new file mode 100644 index 00000000..746f7578 --- /dev/null +++ b/test/test_typing.py @@ -0,0 +1,296 @@ +from typing import Annotated, Any, Generic, Literal, TypeVar, Union + +import diffrax as dfx +import pytest +from diffrax._custom_types import RealScalarLike +from diffrax._typing import get_args_of, get_origin_no_specials + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") + + +class Foo(Generic[T]): + pass + + +class Bar(Generic[S]): + pass + + +class Baz: + pass + + +def test_get_origin_no_specials(): + assert get_origin_no_specials(int, "") is None + assert get_origin_no_specials(tuple[int, ...], "") is tuple + assert get_origin_no_specials(Foo[int], "") is Foo + assert get_origin_no_specials(Annotated[tuple[int, ...], 1337], "") is tuple + # Weird, but legal + assert get_origin_no_specials(Generic[T], "") is Generic # pyright: ignore + + with pytest.raises(NotImplementedError, match="qwerty"): + get_origin_no_specials(Union[int, str], "qwerty") + with pytest.raises(NotImplementedError, match="qwerty"): + get_origin_no_specials(Literal[4], "qwerty") # pyright: ignore + + +def test_get_args_of_not_generic(): + with pytest.raises(TypeError, match="unsubscripted generic"): + get_args_of(Baz, Foo, "") + with pytest.raises(TypeError, match="unsubscripted generic"): + get_args_of(Baz, Foo[int], "") + with pytest.raises(TypeError, match="unsubscripted generic"): + get_args_of(int, Foo, "") + + +def test_get_args_of_not_subclass(): + with pytest.raises(TypeError, match="is not a subclass"): + get_args_of(Foo, Bar, "") + with pytest.raises(TypeError, match="is not a subclass"): + get_args_of(Foo, Baz, "") + with pytest.raises(TypeError, match="is not a subclass"): + get_args_of(Foo, int, "") + + +def test_get_args_of_single_inheritance(): + class Qux1(Foo): + pass + + class Qux2(Foo[int]): + pass + + class Qux3(Foo[T]): + pass + + assert get_args_of(Foo, Qux1, "") == (Any,) + assert get_args_of(Foo, Qux2, "") == (int,) + assert get_args_of(Foo, Qux3, "") == (Any,) + assert get_args_of(Foo, Qux3[str], "") == (str,) + + +def test_get_args_irrelevant_inheritance(): + class Qux1(Foo, str): + pass + + class Qux2(Foo[int], str): + pass + + class Qux3(Foo[T], str): + pass + + assert get_args_of(Foo, Qux1, "") == (Any,) + assert get_args_of(Foo, Qux2, "") == (int,) + assert get_args_of(Foo, Qux3, "") == (Any,) + assert get_args_of(Foo, Qux3[str], "") == (str,) + + +def test_get_args_double_inheritance(): + class Qux1(Foo, Bar): + pass + + class Qux2(Foo[int], Bar): + pass + + class Qux3(Foo[T], Bar): + pass + + assert get_args_of(Foo, Qux1, "") == (Any,) + assert get_args_of(Foo, Qux2, "") == (int,) + assert get_args_of(Foo, Qux3, "") == (Any,) + assert get_args_of(Foo, Qux3[bool], "") == (bool,) + + class Qux4(Foo, Bar[str]): + pass + + class Qux5(Foo[int], Bar[str]): + pass + + class Qux6(Foo[T], Bar[str]): + pass + + assert get_args_of(Foo, Qux4, "") == (Any,) + assert get_args_of(Foo, Qux5, "") == (int,) + assert get_args_of(Foo, Qux6, "") == (Any,) + assert get_args_of(Foo, Qux6[bool], "") == (bool,) + + class Qux7(Foo, Bar[S]): + pass + + class Qux8(Foo[int], Bar[S]): + pass + + class Qux9(Foo[T], Bar[S]): + pass + + assert get_args_of(Foo, Qux7, "") == (Any,) + assert get_args_of(Foo, Qux7[bool], "") == (Any,) + assert get_args_of(Foo, Qux8, "") == (int,) + assert get_args_of(Foo, Qux8[bool], "") == (int,) + assert get_args_of(Foo, Qux9, "") == (Any,) + assert get_args_of(Foo, Qux9[bool, str], "") == (bool,) + + class Qux10(Foo, Bar[T]): + pass + + class Qux11(Foo[int], Bar[T]): + pass + + class Qux12(Foo[T], Bar[T]): + pass + + assert get_args_of(Foo, Qux10, "") == (Any,) + assert get_args_of(Foo, Qux11, "") == (int,) + assert get_args_of(Foo, Qux12, "") == (Any,) + assert get_args_of(Foo, Qux12[bool], "") == (bool,) + + +def test_get_args_double_inheritance_reverse(): + class Qux1(Foo, Bar): + pass + + class Qux2(Foo[int], Bar): + pass + + class Qux3(Foo[T], Bar): + pass + + assert get_args_of(Bar, Qux1, "") == (Any,) + assert get_args_of(Bar, Qux2, "") == (Any,) + assert get_args_of(Bar, Qux3, "") == (Any,) + assert get_args_of(Bar, Qux3[bool], "") == (Any,) + + class Qux4(Foo, Bar[str]): + pass + + class Qux5(Foo[int], Bar[str]): + pass + + class Qux6(Foo[T], Bar[str]): + pass + + assert get_args_of(Bar, Qux4, "") == (str,) + assert get_args_of(Bar, Qux5, "") == (str,) + assert get_args_of(Bar, Qux6, "") == (str,) + assert get_args_of(Bar, Qux6[bool], "") == (str,) + + class Qux7(Foo, Bar[S]): + pass + + class Qux8(Foo[int], Bar[S]): + pass + + class Qux9(Foo[T], Bar[S]): + pass + + assert get_args_of(Bar, Qux7, "") == (Any,) + assert get_args_of(Bar, Qux7[bool], "") == (bool,) + assert get_args_of(Bar, Qux8, "") == (Any,) + assert get_args_of(Bar, Qux8[bool], "") == (bool,) + assert get_args_of(Bar, Qux9, "") == (Any,) + assert get_args_of(Bar, Qux9[bool, str], "") == (str,) + + class Qux10(Foo, Bar[T]): + pass + + class Qux11(Foo[int], Bar[T]): + pass + + class Qux12(Foo[T], Bar[T]): + pass + + assert get_args_of(Bar, Qux10, "") == (Any,) + assert get_args_of(Bar, Qux11, "") == (Any,) + assert get_args_of(Bar, Qux12, "") == (Any,) + assert get_args_of(Bar, Qux12[bool], "") == (bool,) + + +def test_get_args_of_complicated(): + class X1(Generic[T, S]): + pass + + class X2(X1[T, T], Generic[T, S]): + pass + + class X3(X2): + pass + + class X4(X2[int, T]): + pass + + class X5(str, X1[str, str]): + pass + + class X6(X1[S, T], Generic[T, U, S]): + pass + + # This one is invalid at static type-checking time. + class X7(X6[int, str, bool], X2[int, str]): # pyright: ignore + pass + + class X8(X6[bool, T, bool], X2[bool, int]): + pass + + class X9(X3, X2[int, str]): + pass + + # Some of these are invalid at static type-checking time. + assert get_args_of(X1, X1, "") == (Any, Any) + assert get_args_of(X1, X1[int, S], "") == (int, Any) # pyright: ignore + assert get_args_of(X1, X1[int, str], "") == (int, str) + + assert get_args_of(X1, X2, "") == (Any, Any) + assert get_args_of(X1, X2[T, str], "") == (Any, Any) # pyright: ignore + assert get_args_of(X1, X2[str, T], "") == (str, str) # pyright: ignore + assert get_args_of(X1, X2[int, str], "") == (int, int) + + assert get_args_of(X2, X3, "") == (Any, Any) + + assert get_args_of(X2, X4, "") == (int, Any) + assert get_args_of(X2, X4[str], "") == (int, str) + + assert get_args_of(X1, X5, "") == (str, str) + + assert get_args_of(X1, X6, "") == (Any, Any) + assert get_args_of(X1, X6[int, str, bool], "") == (bool, int) + + with pytest.raises(TypeError, match="multiple incompatible ways"): + assert get_args_of(X1, X7, "") == (Any, Any) + assert get_args_of(X6, X7, "") == (int, str, bool) + assert get_args_of(X2, X7, "") == (int, str) + + assert get_args_of(X1, X8, "") == (bool, bool) + assert get_args_of(X1, X8[float], "") == (bool, bool) + assert get_args_of(X6, X8, "") == (bool, Any, bool) + assert get_args_of(X6, X8[float], "") == (bool, float, bool) + assert get_args_of(X2, X8, "") == (bool, int) + assert get_args_of(X2, X8[float], "") == (bool, int) + + assert get_args_of(X3, X9, "") == () + assert get_args_of(X2, X9, "") == (int, str) + assert get_args_of(X1, X9, "") == (int, int) + + +_abstract_args = lambda cls: get_args_of(dfx.AbstractTerm, cls, "") + + +def test_abstract_term(): + assert _abstract_args(dfx.AbstractTerm) == (Any, Any) + assert _abstract_args(dfx.AbstractTerm[int, str]) == (int, str) + + +def test_ode_term(): + assert _abstract_args(dfx.ODETerm) == (Any, RealScalarLike) + assert _abstract_args(dfx.ODETerm[int]) == (int, RealScalarLike) + + +def test_control_term(): + assert _abstract_args(dfx.ControlTerm) == (Any, Any) + assert _abstract_args(dfx.ControlTerm[int, str]) == (int, str) + + +def test_weakly_diagonal_control_term(): + assert _abstract_args(dfx.WeaklyDiagonalControlTerm) == (Any, Any) + assert _abstract_args(dfx.WeaklyDiagonalControlTerm[int, str]) == (int, str)