ULD fix
lockwo committed Jan 22, 2025
1 parent d0f161c commit 16fedb2
Showing 2 changed files with 14 additions and 7 deletions.
17 changes: 12 additions & 5 deletions diffrax/_solver/foster_langevin_srk.py
@@ -13,6 +13,7 @@
 
 from .._custom_types import (
     AbstractBrownianIncrement,
+    Args,
     BoolScalarLike,
     DenseInfo,
     RealScalarLike,
@@ -50,7 +51,7 @@ def _get_args_from_terms(
     PyTree,
     PyTree,
     PyTree,
-    Callable[[UnderdampedLangevinX], UnderdampedLangevinX],
+    Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX],
 ]:
     drift, diffusion = terms.terms
     if isinstance(drift, WrapTerm):
@@ -320,7 +321,7 @@ def shape_check_fun(_x, _g, _u, _fx):
 
         coeffs = self._recompute_coeffs(h, gamma, tay_coeffs)
         rho = jtu.tree_map(lambda c, _u: jnp.sqrt(2 * c * _u), gamma, u)
-        prev_f = grad_f(x0) if self._is_fsal else None
+        prev_f = grad_f(x0, args) if self._is_fsal else None
 
         state_out = SolverState(
             gamma=gamma,
@@ -386,7 +387,6 @@ def step(
         _PathState,
         RESULTS,
     ]:
-        del args
         st = solver_state
         drift, diffusion = terms.terms
         drift_path, diffusion_path = path_state
@@ -422,12 +422,19 @@
             prev_f = st.prev_f
         else:
             prev_f = lax.cond(
-                eqxi.unvmap_any(made_jump), lambda: grad_f(x0), lambda: st.prev_f
+                eqxi.unvmap_any(made_jump), lambda: grad_f(x0, args), lambda: st.prev_f
             )
 
         # The actual step computation, handled by the subclass
         x_out, v_out, f_fsal, error = self._compute_step(
-            h, levy, x0, v0, (gamma, u, grad_f), coeffs, rho, prev_f
+            h,
+            levy,
+            x0,
+            v0,
+            (gamma, u, lambda inp: grad_f(inp, args)),
+            coeffs,
+            rho,
+            prev_f,
         )
 
         def check_shapes_dtypes(arg, *args):
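
With this fix, the `args` supplied to `diffeqsolve` are threaded through to the user's gradient of the potential as `grad_f(x, args)`, rather than being discarded by `step` (the old `del args`). A minimal usage sketch, not part of the commit: the quadratic potential and its stiffness `k` are hypothetical, and it assumes diffrax's `UnderdampedLangevinDriftTerm`, `UnderdampedLangevinDiffusionTerm`, `SpaceTimeLevyArea`, and the `ALIGN` solver with its `taylor_threshold` argument.

import jax
import jax.numpy as jnp
import diffrax

gamma, u = 2.0, 1.0  # friction and mass parameters of the Langevin SDE

# Gradient of a hypothetical quadratic potential f(x) = k * |x|^2 / 2.
# The stiffness k now arrives through `args` instead of a Python closure.
grad_f = lambda x, k: k * x

bm = diffrax.VirtualBrownianTree(
    0.0, 1.0, tol=1e-3, shape=(2,), key=jax.random.PRNGKey(0),
    levy_area=diffrax.SpaceTimeLevyArea,  # space-time Levy area for ALIGN
)
terms = diffrax.MultiTerm(
    diffrax.UnderdampedLangevinDriftTerm(gamma, u, grad_f),
    diffrax.UnderdampedLangevinDiffusionTerm(gamma, u, bm),
)
y0 = (jnp.zeros((2,)), jnp.zeros((2,)))  # (position, velocity)
sol = diffrax.diffeqsolve(
    terms, diffrax.ALIGN(taylor_threshold=0.1), 0.0, 1.0, dt0=0.01,
    y0=y0, args=0.5,  # args reaches grad_f as its second argument
)

Since `args` is captured inside both `lax.cond` branches and forwarded to `_compute_step` via a closure, a solve like this should also compose with `jax.vmap` over `args` to sweep the potential's parameters.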
4 changes: 2 additions & 2 deletions test/test_underdamped_langevin.py
@@ -59,7 +59,7 @@ def make_pytree(array_factory):
         "qq": jnp.ones((), dtype),
     }
 
-    def grad_f(x):
+    def grad_f(x, _):
        xa = x["rr"]
        xb = x["qq"]
        return {"rr": jtu.tree_map(lambda _x: 0.2 * _x, xa), "qq": xb}
@@ -242,7 +242,7 @@ def test_different_args():
     u1 = (jnp.array([1, 2]), 1)
     g2 = (jnp.array([1, 2]), jnp.array([1, 3]))
     u2 = (jnp.array([1, 2]), jnp.ones((2,)))
-    grad_f = lambda x: x
+    grad_f = lambda x, _: x
 
     w_shape = (
         jax.ShapeDtypeStruct((2,), jnp.float64),
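
The test updates mirror the required change on the user side: gradient functions supplied to the Underdamped Langevin terms now take `(x, args)`, matching the `Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX]` annotation above, so an args-independent gradient simply ignores its second argument. A before/after sketch:

# Before: args never reached the gradient, so it took only x.
grad_f_old = lambda x: x
# After: args is always passed; discard it when unused.
grad_f_new = lambda x, _: x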
