Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adjoit
Browse files Browse the repository at this point in the history
lockwo committed Jan 24, 2025
1 parent 9a19d68 commit 4994982
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions test/test_adjoint.py
Original file line number Diff line number Diff line change
@@ -221,17 +221,17 @@ def test_direct_brownian():
key, subkey = jax.random.split(key)
driftkey, diffusionkey, ykey = jr.split(subkey, 3)
drift_mlp = eqx.nn.MLP(
in_size=3,
out_size=3,
in_size=2,
out_size=2,
width_size=8,
depth=2,
activation=jax.nn.swish,
final_activation=jnp.tanh,
key=driftkey,
)
diffusion_mlp = eqx.nn.MLP(
in_size=3,
out_size=3,
in_size=2,
out_size=2,
width_size=8,
depth=2,
activation=jax.nn.swish,
@@ -251,16 +251,16 @@ class DiffusionField(eqx.Module):
def __call__(self, t, y, args):
return lx.DiagonalLinearOperator(self.force(y))

y0 = jr.normal(ykey, (3,))
y0 = jr.normal(ykey, (2,))

k1, k2, k3 = jax.random.split(key, 3)

vbt = diffrax.VirtualBrownianTree(
0.3, 9.5, 1e-4, (3,), k1, levy_area=diffrax.SpaceTimeLevyArea
0.3, 9.5, 1e-4, (2,), k1, levy_area=diffrax.SpaceTimeLevyArea
)
dbp = diffrax.UnsafeBrownianPath((3,), k2, levy_area=diffrax.SpaceTimeLevyArea)
dbp = diffrax.UnsafeBrownianPath((2,), k2, levy_area=diffrax.SpaceTimeLevyArea)
dbp_pre = diffrax.UnsafeBrownianPath(
(3,), k3, levy_area=diffrax.SpaceTimeLevyArea, precompute=int(9.5 / 0.1)
(2,), k3, levy_area=diffrax.SpaceTimeLevyArea, precompute=int(9.5 / 0.1)
)

vbt_terms = diffrax.MultiTerm(
@@ -301,17 +301,14 @@ def _run(y0__args__term, saveat, adjoint):
# Only does gradients with respect to y0
def _run_finite_diff(y0__args__term, saveat, adjoint):
y0, args, term = y0__args__term
y0_a = y0 + jnp.array([1e-5, 0, 0])
y0_b = y0 + jnp.array([0, 1e-5, 0])
y0_c = y0 + jnp.array([0, 0, 1e-5])
y0_a = y0 + jnp.array([1e-5, 0])
y0_b = y0 + jnp.array([0, 1e-5])
val = _run((y0, args, term), saveat, adjoint)
val_a = _run((y0_a, args, term), saveat, adjoint)
val_b = _run((y0_b, args, term), saveat, adjoint)
val_c = _run((y0_c, args, term), saveat, adjoint)
out_a = (val_a - val) / 1e-5
out_b = (val_b - val) / 1e-5
out_c = (val_c - val) / 1e-5
return jnp.stack([out_a, out_b, out_c])
return jnp.stack([out_a, out_b])

for t0 in (True, False):
for t1 in (True, False):

0 comments on commit 4994982

Please sign in to comment.