From 16fedb25640c2cb8bb66de1714ac81483f433931 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 21 Jan 2025 19:16:10 -0800 Subject: [PATCH] ULD fix --- diffrax/_solver/foster_langevin_srk.py | 17 ++++++++++++----- test/test_underdamped_langevin.py | 4 ++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/diffrax/_solver/foster_langevin_srk.py b/diffrax/_solver/foster_langevin_srk.py index 47e89ee7..044b7438 100644 --- a/diffrax/_solver/foster_langevin_srk.py +++ b/diffrax/_solver/foster_langevin_srk.py @@ -13,6 +13,7 @@ from .._custom_types import ( AbstractBrownianIncrement, + Args, BoolScalarLike, DenseInfo, RealScalarLike, @@ -50,7 +51,7 @@ def _get_args_from_terms( PyTree, PyTree, PyTree, - Callable[[UnderdampedLangevinX], UnderdampedLangevinX], + Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX], ]: drift, diffusion = terms.terms if isinstance(drift, WrapTerm): @@ -320,7 +321,7 @@ def shape_check_fun(_x, _g, _u, _fx): coeffs = self._recompute_coeffs(h, gamma, tay_coeffs) rho = jtu.tree_map(lambda c, _u: jnp.sqrt(2 * c * _u), gamma, u) - prev_f = grad_f(x0) if self._is_fsal else None + prev_f = grad_f(x0, args) if self._is_fsal else None state_out = SolverState( gamma=gamma, @@ -386,7 +387,6 @@ def step( _PathState, RESULTS, ]: - del args st = solver_state drift, diffusion = terms.terms drift_path, diffusion_path = path_state @@ -422,12 +422,19 @@ def step( prev_f = st.prev_f else: prev_f = lax.cond( - eqxi.unvmap_any(made_jump), lambda: grad_f(x0), lambda: st.prev_f + eqxi.unvmap_any(made_jump), lambda: grad_f(x0, args), lambda: st.prev_f ) # The actual step computation, handled by the subclass x_out, v_out, f_fsal, error = self._compute_step( - h, levy, x0, v0, (gamma, u, grad_f), coeffs, rho, prev_f + h, + levy, + x0, + v0, + (gamma, u, lambda inp: grad_f(inp, args)), + coeffs, + rho, + prev_f, ) def check_shapes_dtypes(arg, *args): diff --git a/test/test_underdamped_langevin.py b/test/test_underdamped_langevin.py index e945cad5..53a43a24 100644 --- a/test/test_underdamped_langevin.py +++ b/test/test_underdamped_langevin.py @@ -59,7 +59,7 @@ def make_pytree(array_factory): "qq": jnp.ones((), dtype), } - def grad_f(x): + def grad_f(x, _): xa = x["rr"] xb = x["qq"] return {"rr": jtu.tree_map(lambda _x: 0.2 * _x, xa), "qq": xb} @@ -242,7 +242,7 @@ def test_different_args(): u1 = (jnp.array([1, 2]), 1) g2 = (jnp.array([1, 2]), jnp.array([1, 3])) u2 = (jnp.array([1, 2]), jnp.ones((2,))) - grad_f = lambda x: x + grad_f = lambda x, _: x w_shape = ( jax.ShapeDtypeStruct((2,), jnp.float64),