v0.11.4 #654

Merged
merged 20 commits into main from dev
Apr 14, 2024

patrick-kidger
Owner

@patrick-kidger patrick-kidger commented Feb 6, 2024

Features

Compatibility

  • eqx.error_if is now compatible with JAX 0.4.26, which slightly changed JAX's own reporting of error messages.

  • Added a warning for code like the following:

    class MyModule(eqx.Module):
        fn: Callable

        def __init__(self, ...):
            self.fn = jax.vmap(some_fn)

    This is an easy source of bugs: the vmap'd function is not a PyTree, so it will not propagate anything in the PyTree structure of some_fn.

Technical internal stuff

  • eqx.internal.while_loop(..., kind="checkpointed") now propagates forward JVP tracers only for those outputs that are perturbed because an input to the loop is perturbed, rather than for all outputs. As a result, later calls to a nondifferentiable operation, like jax.pure_callback or eqx.internal.nondifferentiable, no longer crash at trace time. (See "Problems with progress bar and jax.grad", diffrax#396.)

  • eqx.internal.while_loop(..., kind="bounded") will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See "pytree output structure mismatch error in backprop during vmap", optimistix#48 (comment).)

  • The transpose rule for eqx.internal.create_vprim now understands symbolic zeros, fixing a crash for grad-of-vmap-of-<lineax.linear_solve that we only use some outputs from>. (See "pytree output structure mismatch error in backprop during vmap", optimistix#48.)

  • The type annotation for the input of any converter function used in eqx.field(converter=...) will now be used as the type annotation in any dataclass-autogenerated __init__ functions. In particular this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)

patrick-kidger and others added 17 commits February 6, 2024 17:00
* add new padding options for Conv and ConvTranspose

* Update _conv.py

* Add tests for the padding of `Conv` and `ConvTranspose`

* Fix some type hints

* Fix the type of padding_t
* rope embeddings added

* added sinusoidal embedding

* added rope to mha

* added caching and compute-on-the-fly approach if no max_seq_len given and added process heads to MHA

* remove `use_rope_embedding` flag

* fixed merge related errors

* removed unnecessary state_len flag and placed shape checking in if-clause

* worked in review

* export new embeddings

* removed state len again, oops

* add ensure_compile_time_eval

* remove max_seq_len completely

* removed unnecessary if check

* improved docstrings

* better mem, adhering to strict jax config

* fixed dtype promotion

* removed dtype float and use float(seq_len) instead

* jnp.arange(0.0, ...) to force floats

* Adjustments to RoPE:

- Changed how the rotation is done to match the ESM2 implementation.
- Lots of doc tidy-ups.
- Removed SinusoidalPositionalEmbedding. I think I want to be more certain that this is correct before merging it.

* added rope tests

* typo

* fixed tests and annotations

* removed internal_sinus cache

---------

Co-authored-by: Patrick Kidger <[email protected]>
* add dtype and format code

* add a simple test for checking dtype other than float32

* fix default dtype and format code

* refine documentation for the dtype argument
* added dtypes

* fixed norms and added info to _spectral_norm
…erter` annotation if available.

In particular this should mean that Equinox modules are now compatible with beartype decorators.
homerjed and others added 3 commits April 7, 2024 10:46
* eqx.filter_shard; test + example

* fixed line lengths

* fixed?

* double checking..
The main improvement here is that a checkpointed while loop now propagates perturbations only for those outputs that are actually perturbed. This doesn't affect the backward pass at all (we were already trimming the cotangents according to this criterion), but it does mean that calls to `eqx.nondifferentiable`, or to any primitive without a JVP rule, no longer throw an error.

In addition, this commit includes a couple of crash fixes (needed to pass the new test).
@patrick-kidger patrick-kidger changed the title from "Dev" to "v0.11.4" Apr 14, 2024
@patrick-kidger patrick-kidger merged commit b88edca into main Apr 14, 2024
2 checks passed
@patrick-kidger patrick-kidger deleted the dev branch April 14, 2024 12:54
6 participants