Take the following simple example:
Note the last line with
Am I misunderstanding this function? I am not entirely sure I need to use it in these circumstances, but I am having trouble forming a mental model of why these calls behave differently.
@jlperla your example now works in jax.vmap(residual, in_axes=(None, 0, 0), state_axes={...: None})(model, X, Y) But I'd suggest upgrading to the new API instead and keeping your example as is.
Verified that this works on 0.9 on both Windows and macOS!
I'm very happy we made the change to JAX-style transforms. Starting from 0.9.0, the mental model for Flax transforms is the same as for JAX transforms, which should help users a lot.
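For anyone landing here later, this is a minimal sketch of that mental model in plain JAX. It is not the original example from this thread: the `residual` function, the `params` dict standing in for the model, and the data are all hypothetical. It only illustrates what `in_axes=(None, 0, 0)` means: the first argument is broadcast to every call, while the second and third are mapped over their leading axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the model: a plain dict (pytree) of parameters.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

def residual(params, x, y):
    # Residual for a single (x, y) pair: x has shape (2,), y is a scalar.
    return jnp.dot(params["w"], x) + params["b"] - y

X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # batch of 3 inputs
Y = jnp.array([1.5, 2.5, 3.0])                       # batch of 3 targets

# in_axes=(None, 0, 0): params are broadcast (not mapped),
# X and Y are mapped over axis 0.
batched = jax.vmap(residual, in_axes=(None, 0, 0))(params, X, Y)

# Reference computation: an explicit Python loop over the batch.
looped = jnp.stack([residual(params, x, y) for x, y in zip(X, Y)])
```

Here `batched` and `looped` should agree element by element, which is the whole contract of `vmap`: same per-example function, applied along the mapped axes.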