diff --git a/optimistix/_ad.py b/optimistix/_ad.py index 38d6bcd..a881cf5 100644 --- a/optimistix/_ad.py +++ b/optimistix/_ad.py @@ -10,6 +10,8 @@ import lineax as lx from equinox.internal import ω from jaxtyping import PyTree +from jax.interpreters import ad +from jax._src.ad_util import stop_gradient_p from ._misc import tree_full_like @@ -110,3 +112,12 @@ def _for_jvp(_diff): t_residual = tree_full_like(residual, 0) return (root, residual), (t_root, t_residual) + + +# Work arond JAX issue #22011, +# as well as https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2174488365 +def stop_gradient_transpose(ct, x): + return ct, + + +ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose