-
-
Notifications
You must be signed in to change notification settings - Fork 150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
v0.11.4 #654
Merged
Merged
v0.11.4 #654
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
* add new padding options for Conv and ConvTranspose * Update _conv.py * Add tests for the padding of `Conv` and `ConvTranspose` * Fix some type hints * Fix the type of padding_t
…ssigning a jax-transformed layer.
* rope embeddings added * added sinusoidial embedding * added rope to mha * added caching and compute-on-the-fly approach if no max_seq_len given and added process heads to MHA * remove `use_rope_embedding` flag * fixed merge related errors * removed unnecessary state_len flag and placed shape checking in if-clause * rope embeddings added * added sinusoidial embedding * added rope to mha * added caching and compute-on-the-fly approach if no max_seq_len given and added process heads to MHA * remove `use_rope_embedding` flag * fixed merge related errors * removed unnecessary state_len flag and placed shape checking in if-clause * worked in review * export new embeddings * removed state len again, oops * add ensure_compile_time_eval * remove max_seq_len completely * removed unnecessary if check * improved docstrings * better mem, adhering to strict jax config * fixed dtype promotion * removed dtype float and use float(seq_len) instead * jnp.arange(0.0, ...) to force floats * Adjustments to RoPE: - Changed how the rotation is done to match the ESM2 implementation. - Lots of doc tidy-ups. - Removed SinusoidalPositionalEmbedding. I think I want to be more certain that this is correct before merging it. * added rope tests * typo * fixed tests and annotations * removed internal_sinus cache --------- Co-authored-by: Patrick Kidger <[email protected]>
* added dtypes * fixed norms and added info to _spectral_norm
…erter` annotation if available. In particular this should mean that Equinox modules are now compatible with beartype decorators.
* eqx.filter_shard; test + example * fixed line lengths * fixed? * double checking..
The main improvement here is that a checkpointed while loop will now only propagate perturbations for those outputs that are actually perturbed. Whilst this doesn't affect the backward pass at all (we were already trimming the cotangents according to this criterion), this now means that any calls to `eqx.nondifferentiable`, or any primitive without a JVP rule, will now no longer throw an error. In addition, this commit includes a couple of crash fixes (needed to pass the new test).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Features
Added
eqx.filter_shard
. This lowers tojax.lax.with_sharding_constraint
as a single way to transfer data, or reshard data, both inside and outside of JIT! (No morejax.device_put
.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks @homerjed and @dlwh! Sharding - shardeqx.Module
as well as inputs? #688,eqx.filter_shard
; test + updateexamples/parallelism.ipynb
#691)Added
eqx.filter_{jacfwd,jacrev,hessian}
. These do what you expect! (Thanks @lockwo! Add filter hessian #677)Added
eqx.nn.RotaryPostionalEmbedding
. This is designed to be used in conjunction with the existingeqx.nn.MultiheadAttention
. (Thanks @Artur-Galstyan! RoPE Embeddings #568)Added support for
padding='VALID'
,padding='SAME'
,padding='SAME_LOWER'
to the convolutional layers:eqx.nn.{Conv, ...}
. (Thanks @ChenAo-Phys! New padding options forConv
andConvTranspose
#658)Added support for
padding_mode='ZEROS'
,padding_mode='REFLECT'
,padding_mode='REPLICATE'
,padding_mode='CIRCULAR'
to the convolutional layers:eqx.nn.{Conv, ...}
. (Thanks @ChenAo-Phys! New padding options forConv
andConvTranspose
#658)Added a
dtype
argument toeqx.nn.{MultiheadAttention, Linear, Conv, ...}
for specifying the dtype of their parameters. In additioneqx.nn.BatchNorm
will now also uses itsdtype
argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks @Artur-Galstyan and @AakashKumarNain! Simpledtype
argument addition #680, Add dtypes to the rest of eqx.nn #689)Compatibility
eqx.error_if
is now compatible with JAX 0.4.26. (Which changed JAX's own reporting of error messages slightly.)Added a warning that checks for doing something like:
As this is an easy source of bugs. (The vmap'd function is not a PyTree so will not propagate anything in the PyTree stucture of
some_fn
.)Technical internal stuff
eqx.internal.while_loop(..., kind="checkpointed")
will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed. (Rather than all of them.) This change just means that later calls to a nondifferentiable operation, likejax.pure_callback
oreqx.internal.nondifferentiable
, will no longer crash at trace time. (See Problems with progress bar andjax.grad
diffrax#396.)eqx.internal.while_loop(..., kind="bounded")
will now handle certain vmap+grad combinations without crashing. (It seems like AJX is adding some spurious batch tracers.) (See pytree output structure mismatch error in backprop during vmap optimistix#48 (comment))the transpose rule for
eqx.internal.create_vprim
now understands symbolic zeros, fixing a crash forgrad-of-vmap-of-<lineax.linear_solve that we only use some outputs from>
. (See pytree output structure mismatch error in backprop during vmap optimistix#48.)The type annotation for the input of any converter function used in
eqx.field(converter=...)
will now be used as the type annotation in anydataclass
-autogenerated__init__
functions. In particular this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)