Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transpose rule for stop_gradient #68

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions optimistix/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import lineax as lx
from equinox.internal import ω
from jaxtyping import PyTree
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p

from ._misc import tree_full_like

Expand Down Expand Up @@ -110,3 +112,12 @@ def _for_jvp(_diff):
t_residual = tree_full_like(residual, 0)

return (root, residual), (t_root, t_residual)


# Work arond JAX issue #22011,
# as well as https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2174488365
def stop_gradient_transpose(ct, x):
return ct,


ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose