Skip to content

Commit

Permalink
Saving events with SubSaveAts
Browse files Browse the repository at this point in the history
Previously, updating the last element of ys and ts did not handle the case where multiple `SubSaveAt`s were used. This is now fixed by adding a `jtu.tree_map` in the appropriate place.
  • Loading branch information
cholberg committed May 15, 2024
1 parent 5e9e1fa commit 8fd300c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,9 +1195,12 @@ def _call_real(cond_fn_i, event_mask_i):
tevent = roots.value

# We might need to change this in order to get more accurate derivatives
yevent = jnp.where(event_happened, interpolator.evaluate(tevent), ys[-1])
yevent = jtu.tree_map(
lambda _y: jnp.where(event_happened, interpolator.evaluate(tevent), _y[-1]),
ys,
)
ys = jtu.tree_map(lambda _y, _yevent: _y.at[-1].set(_yevent), ys, yevent)
ts = ts.at[-1].set(tevent)
ts = jtu.tree_map(lambda _t: _t.at[-1].set(tevent), ts)
else:
event_mask = None

Expand Down

0 comments on commit 8fd300c

Please sign in to comment.