Skip to content

Commit

Permalink
Parametric control types (#364)
Browse files Browse the repository at this point in the history
* Added parametric control types

* Demo of AbstractTerm change

* Correct initial parametric control type implementation.

* Parametric AbstractTerm initial implementation.

* Update tests and fix hinting

* Implement review comments.

* Add parametric control check to integrator.

* Update and test parametric control check

* Introduce new LevyArea types

* Updated Brownian path LevyArea types

* Replace Union types in isinstance checks

* Remove rogue comment

* Revert _brownian_arch to single assignment

* Revert _evaluate_leaf key splitting

* Rename variables in test_term

* Update isinstance and issubclass checks

* Safer handling in _denormalise_bm_inc

* Fix style in integrate control type check

* Add draft vector_field typing

* Add draft vector_field typing

* Fix term test

* Revert extemporaneous modifications in _tree

* Rename TimeLevyArea to BrownianIncrement and simplify diff

* Rename AbstractLevyReturn to AbstractBrownianReturn

* Rename _LevyArea to _BrownianReturn

* Enhance _term_compatiblity checks

* Fix merge issues

* Bump pre-commit and fix type hints

* Clean up from self-review

* Explicitly add typeguard to deps

* Bump ruff config to new syntax

* Parameterised terms: fixed term compatibility + spurious pyright errors

Phew, this ended up being a complicated one!

Let's start with the easy stuff:
- Disabled spurious pyright errors due to incompatible between pyright and `eqx.AbstractVar`.
- Now using ruff.lint and pinned exact typeguard version.

Now on to the hard stuff:
- Fixed term compatibibility missing some edge cases.

Edge cases? What edge cases? Well, what we had before was basically predicated around doing
```python
vf, contr = get_args(term_cls)
```
recalling that we may have e.g. `term_cls = AbstractTerm[SomeVectorField, SomeControl]`. So far so simple: get the arguments of a subscripted generic, no big deal.

What this failed to account for is that we may also have subclasses of this generic, e.g. `term_cls = ODETerm[SomeVectorField]`, such that some of the type variables have already been filled in when defining it:
```python
class ODETerm(AbstractTerm[_VF, RealScaleLike]): ...
```
so in this case, `get_args(term_cls)` simply returns a 1-tuple of `(SomeVectorField,)`. Oh no! Somehow we have to traverse both the filled-in type variables (to find that one of our type variables is `SomeVectorField` due to subscripting) *and* the type hierarchy (to figure out that the other type variable was filled in during the definition).

Once again, for clarity: given a subscriptable base class `AbstractTerm[_VF, _Control]` and some arbitrary possible-subscripted subclass, we need to find the values of `_VF` and `_Control`, regardless of whehther they have been passed in via subscripting the final class (and are `get_args`-able) or have been filled in during subclassing (and require traversing pseudo-type-hierarchies of `__orig_bases__`).

Any sane implementation would simply... not bother. There is no way that the hassle of figuring this out was going to be worth the small amount of type safety this brings...

So anyway, after a few hours working on this *far* past the point I should be going to sleep, this problem this is now solved. This PR introduces a new `get_args_of` function, called as `get_args_of(superclass, subclass, error_msg_if_necessary)`. This acts analogous to `get_args`, but instead of looking up both parameters (the type variables we want filled in) and the arguments (the values those type variables have been filled in with) on the same class, it looks up the parameters on the superclass, and their filled-in-values on the subclass. Pure madness.

(I'm also tagging @leycec here because this is exactly the kind of insane typing hackery that he seems to really enjoy.)

Does anyone else remember the days when this was a package primarily concerned about solving differential equations?

---------

Co-authored-by: Patrick Kidger <[email protected]>
  • Loading branch information
tttc3 and patrick-kidger committed Apr 20, 2024
1 parent 6c93faa commit 34cbe5c
Show file tree
Hide file tree
Showing 24 changed files with 1,001 additions and 219 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.2.2
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.316
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]
6 changes: 5 additions & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import LevyVal as LevyVal
from ._custom_types import (
AbstractBrownianReturn as AbstractBrownianReturn,
BrownianIncrement as BrownianIncrement,
SpaceTimeLevyArea as SpaceTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
Expand Down
7 changes: 5 additions & 2 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
import functools as ft
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -20,6 +20,9 @@
from ._term import AbstractTerm, AdjointTerm


ω = cast(Callable, ω)


def _is_none(x):
return x is None

Expand Down
13 changes: 8 additions & 5 deletions diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import abc
from typing import Optional, Union
from typing import Optional, TypeVar, Union

from equinox.internal import AbstractVar
from jaxtyping import Array, PyTree

from .._custom_types import LevyArea, LevyVal, RealScalarLike
from .._custom_types import AbstractBrownianReturn, RealScalarLike
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianReturn])


class AbstractBrownianPath(AbstractPath[_Control]):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
levy_area: AbstractVar[type[AbstractBrownianReturn]]

@abc.abstractmethod
def evaluate(
Expand All @@ -20,7 +23,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
) -> _Control:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.
Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
Expand Down
44 changes: 24 additions & 20 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
import lineax.internal as lxi
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike
from .._custom_types import (
BrownianIncrement,
levy_tree_transpose,
RealScalarLike,
SpaceTimeLevyArea,
)
from .._misc import (
force_bitcast_convert_type,
is_tuple_of_ints,
Expand Down Expand Up @@ -42,26 +47,24 @@ class UnsafeBrownianPath(AbstractBrownianPath):
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: LevyArea = eqx.field(static=True)
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]] = eqx.field(
static=True
)
key: PRNGKeyArray

def __init__(
self,
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: PRNGKeyArray,
levy_area: LevyArea = "",
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]],
):
self.shape = (
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
if is_tuple_of_ints(shape)
else shape
)
self.key = key
if levy_area not in ["", "space-time"]:
raise ValueError(
f"levy_area must be one of '', 'space-time', but got {levy_area}."
)
self.levy_area = levy_area
self.levy_area = levy_area # pyright: ignore[reportIncompatibleVariableOverride]

if any(
not jnp.issubdtype(x.dtype, jnp.inexact)
Expand All @@ -70,11 +73,11 @@ def __init__(
raise ValueError("UnsafeBrownianPath dtypes all have to be floating-point.")

@property
def t0(self):
def t0(self): # pyright: ignore[reportIncompatibleVariableOverride]
return -jnp.inf

@property
def t1(self):
def t1(self): # pyright: ignore[reportIncompatibleVariableOverride]
return jnp.inf

@eqx.filter_jit
Expand All @@ -84,7 +87,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]:
del left
if t1 is None:
dtype = jnp.result_type(t0)
Expand Down Expand Up @@ -112,7 +115,7 @@ def evaluate(
)
if use_levy:
out = levy_tree_transpose(self.shape, out)
assert isinstance(out, LevyVal)
assert isinstance(out, (BrownianIncrement, SpaceTimeLevyArea))
return out

@staticmethod
Expand All @@ -121,25 +124,26 @@ def _evaluate_leaf(
t1: RealScalarLike,
key,
shape: jax.ShapeDtypeStruct,
levy_area: str,
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]],
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
w = jr.normal(key, shape.shape, shape.dtype) * w_std
dt = t1 - t0

if levy_area == "space-time":
if levy_area is SpaceTimeLevyArea:
key, key_hh = jr.split(key, 2)
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
elif levy_area == "":
hh = None
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None)
elif levy_area is BrownianIncrement:
levy_val = BrownianIncrement(dt=dt, W=w)
else:
assert False
w = jr.normal(key, shape.shape, shape.dtype) * w_std

if use_levy:
return LevyVal(dt=t1 - t0, W=w, H=hh, K=None)
else:
return w
return levy_val
return w


UnsafeBrownianPath.__init__.__doc__ = """
Expand Down
Loading

0 comments on commit 34cbe5c

Please sign in to comment.