Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

in closure_convert, hoist all tracers involved in staging #14760

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

froystig
Copy link
Member

@froystig froystig commented Mar 2, 2023

with @mattjj

We need to hoist all DynamicJaxprTracers in jax.closure_convert, not only those that are float-dtyped. Leaving any DynamicJaxprTracer in the closure of a custom_{jvp,vjp}'ed function might result in a tracer leak, as staging to jaxpr could already be complete by the time the closure is consumed.

@froystig froystig added the pull ready Ready for copybara import and testing label Mar 2, 2023
@froystig froystig requested a review from mattjj March 2, 2023 23:56
@froystig froystig self-assigned this Mar 2, 2023
@froystig froystig force-pushed the closure-convert-hoist-everything branch from cdf14a7 to b43e6e7 Compare March 8, 2023 17:13
@froystig froystig force-pushed the closure-convert-hoist-everything branch 2 times, most recently from 10d9a3c to ddc1a5b Compare March 8, 2023 20:56
@froystig froystig force-pushed the closure-convert-hoist-everything branch 2 times, most recently from a1b6bc6 to 904cd9d Compare May 15, 2023 15:34
We need to hoist all `DynamicJaxprTracer`s in `jax.closure_convert`,
not only those that are float-dtyped. Leaving any `DynamicJaxprTracer`
in the closure of a `custom_{jvp,vjp}`'ed function might result in a
tracer leak, as staging to jaxpr could already be complete by the time
the closure is consumed.

Co-authored-by: Matthew Johnson <[email protected]>
@froystig froystig force-pushed the closure-convert-hoist-everything branch from 904cd9d to 0e6fea4 Compare May 15, 2023 15:55
@froystig
Copy link
Member Author

We can't quite merge this because it breaks this test:

https://github.com/google/jax/blob/15caafd93726f2f991053699c0d0ab530aa88b05/tests/api_test.py#L7354-L7362

Skimming discussion #9951 to which it links, the approach we'd intend going forward is to use custom_jvp (or custom_vjp) with the symbolic_zeros option. My guess is that we should try doing this ourselves in odeint, and then rewrite the test at a higher level, to test derivatives of odeint with integer arguments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants