
Equinox v0.11.5

github-actions released this on 18 Aug at 19:11.

JAX compatibility

Recent versions of JAX (0.4.28+) have made some changes to:

  • Hashing of tracers;
  • Tree-map'ing over Nones;
  • Callbacks;
  • Pretty-printing.

With this update, Equinox should now be compatible with both older and newer versions of JAX. This fixes some new crashes as well as some new warnings. (#719, #724, #753, #758, thanks @jakevdp, @hawkinsp!)

Better errors

  • The error messages from eqx.error_if are now substantially more informative: they include traceback information (including the stack) and mention that the EQX_ON_ERROR environment variable is available. We also do a much better job of hiding the large, unhelpful printouts that XLA produces by default. (#785, #803)

  • The default value of EQX_ON_ERROR_BREAKPOINT_FRAMES is now 1. (#777) As a result, using eqx.error_if alongside EQX_ON_ERROR=breakpoint will now:

    • reliably open a debugger, rather than sometimes crashing at trace time due to upstream JAX bug #16732;
    • however, by default, no longer make any additional stack frames above the current one available in the debugger (these would be accessed via u);
    • print an informative message explaining much of the above before the debugger opens.

Bugfixes

  • eqx.filter_{jacfwd, jacrev} now apply filtering only to their inputs, not to their outputs. Previously, filtering the outputs was problematic: there was no way to represent static-input-by-static-output entries in the returned Jacobian, so those pieces were silently dropped. (#734, thanks @lockwo!)

  • eqx.tree_at can now be used to replace empty tuples. (#715, #717, #722, thanks @lockwo!)

  • eqx.filter_custom_jvp no longer crashes at trace time in some scenarios in which its **kwargs were erroneously treated as having tangents. (#745 (comment), #749)

  • Fixed a trace-time crash when using a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using optimistix.BFGS around diffrax.diffeqsolve. (#777)

  • Fixed a trace-time crash when:

    • using a checkpointed while loop...
    • ...with a body function that has a closed-over tracer...
    • ...and that closed-over tracer is differentiated...
    • ...and there are no other closed-over tracers that are differentiated...
    • ...and the dependency on that tracer is only linear.
    • (patrick-kidger/diffrax#387 (comment), #752, thanks @dkweiss31!)

  • Fixed a trace-time crash when composing the grad of vmap of lineax.linear_solve. (patrick-kidger/lineax#101, #795, thanks @rhacking!)

  • eqx.nn.RMSNorm now uses at least 32-bit precision for numerical stability (#723, thanks @AakashKumarNain!)
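The idea behind the eqx.nn.RMSNorm change can be sketched in plain JAX. The computation is upcast to at least float32 before reducing and then cast back to the input dtype. The function rms_norm_at_least_fp32 is hypothetical. It is not Equinox's implementation and omits the learned weight and bias.

```python
import jax
import jax.numpy as jnp


def rms_norm_at_least_fp32(x, eps=1e-5):
    # Upcast to at least float32; float64 inputs stay float64.
    dtype = jnp.promote_types(x.dtype, jnp.float32)
    x_up = x.astype(dtype)
    # Normalise by the root-mean-square over all elements.
    y = x_up * jax.lax.rsqrt(jnp.mean(x_up**2) + eps)
    # Cast back to the caller's dtype.
    return y.astype(x.dtype)


x = jnp.array([3.0, 4.0], dtype=jnp.bfloat16)
y = rms_norm_at_least_fp32(x)
print(y, y.dtype)
```

Squaring and averaging in bfloat16 would lose precision (and can overflow in float16), which is why the reduction is done in the wider type.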

New features

Other changes

New Contributors

Full Changelog: v0.11.4...v0.11.5