Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Events #387

Merged
merged 28 commits into from
Jun 29, 2024
Merged
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2b6d8ff
Changes to how events are handled in diffrax.
cholberg Feb 21, 2024
3e367f2
Test now fails when no root finder is provided
cholberg May 7, 2024
3f67a27
Saving events with `SubSaveAt`s
cholberg May 15, 2024
76762c1
Accounting for `SubSaveAt.fn` returning a PyTree
cholberg May 15, 2024
76cd083
Adjustments to #387 (events):
patrick-kidger May 19, 2024
e2ab3ce
Save values returned by root find when
cholberg May 23, 2024
0853e49
now returns condition function
cholberg May 24, 2024
1488206
Fixed error for . All tests pass now.
cholberg May 24, 2024
4872009
Added additional tests
cholberg May 26, 2024
a1f577c
Fixed save_index update and shape+dtype check for cond_fn
cholberg May 27, 2024
95ac30f
Added PyTree check in _outer_cond_fn
cholberg May 27, 2024
0c820f3
Added tests for checking that events error out correctly under misspe…
cholberg May 27, 2024
57d90c5
Fixed small error in the save_index update for events
cholberg May 27, 2024
1bdf1d2
Updated how events are saved
cholberg May 27, 2024
55e04e8
Added tests for different configurations of saveat
cholberg May 27, 2024
d8a8ba7
Changed to ValueError when cond_fn returns non-boolean/float.
cholberg May 28, 2024
70e044f
Added docstring to Event class
cholberg May 28, 2024
4c509b6
Updated docstring for steady_state_event
cholberg May 28, 2024
fbea794
Updated docstring for ImplicitAdjoint
cholberg May 28, 2024
b158800
Added example to Event docstring
cholberg May 28, 2024
812e5c6
Updated steady state example to use the new syntax
cholberg May 28, 2024
eb16e57
Fixed weird type checker error
cholberg Jun 1, 2024
09c92c3
Updated steady state test to use the new syntax
cholberg Jun 10, 2024
d07c8f4
Doc tweaks for events
patrick-kidger Jun 15, 2024
883841f
Typo in comment
cholberg Jun 16, 2024
0c62f4c
Simplified unsaving
cholberg Jun 18, 2024
7588482
Deleted extra unnecessary argument
cholberg Jun 25, 2024
e4935ae
Changed to strict inequality to be in line with the usual saving behv…
cholberg Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added tests for different configurations of saveat
cholberg committed Jun 25, 2024
commit 55e04e816f12e3595baf1a38625b53a2fb0e6879
149 changes: 149 additions & 0 deletions test/test_event.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from typing import cast

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
@@ -560,3 +561,151 @@ def cond_fn_2(t, y, args, **kwargs):
event = diffrax.Event(cond_fn=cond_fn)
with pytest.raises(AssertionError):
diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)


@pytest.mark.parametrize("steps", (1, 2, 3, 4, 5))
def test_event_save_steps(steps):
term = diffrax.ODETerm(lambda t, y, args: (1.0, 1.0))
solver = diffrax.Tsit5()
t0 = 0
t1 = 10
dt0 = 1
thr = steps - 0.5
y0 = (0.0, -thr)
ts = jnp.array([0.5, 3.5, 5.5])

def cond_fn(t, y, args, **kwargs):
del t, args, kwargs
x, _ = y
return x - thr

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)

def run(saveat):
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
event=event,
saveat=saveat,
)
return cast(Array, sol.ts), cast(tuple, sol.ys)

saveats = [
diffrax.SaveAt(steps=True),
diffrax.SaveAt(steps=True, t1=True),
diffrax.SaveAt(steps=True, t1=True, t0=True),
diffrax.SaveAt(steps=True, fn=lambda t, y, args: (y[0], y[1] + thr)),
]
num_steps = [steps, steps, steps + 1, steps]
yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr)]

for saveat, n, yevent in zip(saveats, num_steps, yevents):
ts, ys = run(saveat)
xs, zs = ys
xevent, zevent = yevent
assert jnp.sum(jnp.isfinite(ts)) == n
assert jnp.sum(jnp.isfinite(xs)) == n
assert jnp.sum(jnp.isfinite(zs)) == n
assert jnp.all(jnp.isclose(ts[n - 1], thr, atol=1e-5))
assert jnp.all(jnp.isclose(xs[n - 1], xevent, atol=1e-5))
assert jnp.all(jnp.isclose(zs[n - 1], zevent, atol=1e-5))


@pytest.mark.parametrize("steps", (1, 2, 3, 4, 5))
def test_event_save_ts(steps):
term = diffrax.ODETerm(lambda t, y, args: (1.0, 1.0))
solver = diffrax.Tsit5()
t0 = 0
t1 = 10
dt0 = 1
thr = steps - 0.5
y0 = (0.0, -thr)
ts = jnp.array([0.5, 3.5, 5.5])

def cond_fn(t, y, args, **kwargs):
del t, args, kwargs
x, _ = y
return x - thr

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)

def run(saveat):
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
event=event,
saveat=saveat,
)
return cast(Array, sol.ts), cast(tuple, sol.ys)

saveats = [
diffrax.SaveAt(ts=ts),
diffrax.SaveAt(ts=ts, t1=True),
diffrax.SaveAt(ts=ts, t0=True),
diffrax.SaveAt(ts=ts, steps=True),
diffrax.SaveAt(ts=ts, fn=lambda t, y, args: (y[0], y[1] + thr)),
]
save_finals = [False, True, False, True, False]
yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, 0), (thr, thr)]
for saveat, save_final, yevent in zip(saveats, save_finals, yevents):
ts, ys = run(saveat)
xs, zs = ys
xevent, zevent = yevent
if save_final:
assert jnp.all(jnp.isclose(ts[jnp.isfinite(ts)][-1], thr, atol=1e-5))
assert jnp.all(jnp.isclose(xs[jnp.isfinite(xs)][-1], xevent, atol=1e-5))
assert jnp.all(jnp.isclose(zs[jnp.isfinite(zs)][-1], zevent, atol=1e-5))
else:
assert jnp.all(ts[jnp.isfinite(ts)] <= thr)


@pytest.mark.parametrize("steps", (1, 2, 3, 4, 5))
def test_event_save_subsaveat(steps):
term = diffrax.ODETerm(lambda t, y, args: jnp.array([1.0, 1.0]))
solver = diffrax.Tsit5()
t0 = 0.0
t1 = 10.0
dt0 = 1.0
thr = steps - 0.5
y0 = jnp.array([0.0, -thr])
ts = jnp.arange(t0, t1, 3.0)
ts_event = jnp.sum(ts <= thr)
last_t = jnp.array(ts[ts_event - 1])

def cond_fn(t, y, args, **kwargs):
del t, args, kwargs
x, _ = y
return x - thr

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)

class Saved(eqx.Module):
y: Array

def save_fn(t, y, args):
del t, args
ynorm = jnp.einsum("i,i->", y, y)
return Saved(jnp.array([ynorm]))

last_save = save_fn(None, y0 + last_t, None).y
subsaveat_a = diffrax.SubSaveAt(ts=ts, fn=save_fn)
subsaveat_b = diffrax.SubSaveAt(steps=True)
saveat = diffrax.SaveAt(subs=[subsaveat_a, subsaveat_b])
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event, saveat=saveat)
ts_1, ts_2 = cast(list, sol.ts)
ys_1, ys_2 = cast(list, sol.ys)
assert jnp.sum(jnp.isfinite(ts_1)) == ts_event
assert jnp.sum(jnp.isfinite(ts_2)) == steps
assert jnp.all(jnp.isclose(ys_2[steps - 1], jnp.array([thr, 0]), atol=1e-5))
assert jnp.all(jnp.isclose(ys_1.y[ts_event - 1], last_save, atol=1e-5))