Zodiax is a lightweight extension to the object-oriented Jax framework Equinox. Equinox allows for differentiable classes that are recognised as a valid Jax type, and Zodiax adds lightweight methods to simplify interfacing with these classes! Zodiax was originally built during the development of dLux and was designed to make working with large, nested class structures simple and flexible.
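To make "recognised as a valid Jax type" concrete: a class is a valid Jax type when it is registered as a pytree, that is, when Jax knows how to split it into leaves and rebuild it. Equinox automates this for its classes, and Zodiax builds on top of that. The plain-JAX sketch below does the registration by hand with `jax.tree_util.register_pytree_node_class`. The `Point` class and the function `f` are hypothetical illustrations and are not part of Zodiax or Equinox.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Point:
    """Toy class registered as a pytree by hand (Equinox automates this step)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Children are the (differentiable) leaves; there is no static aux data here
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild a Point from its leaves
        return cls(*children)


def f(p):
    # A scalar function of the object's attributes: x^2 + 3y
    return p.x ** 2 + 3.0 * p.y


p = Point(jnp.array(2.0), jnp.array(1.0))
g = jax.grad(f)(p)  # the gradient is returned as a Point with the same structure
print(type(g).__name__, g.x, g.y)  # Point 4.0 3.0
```

Because `jax.grad` only sees leaves, the gradient comes back in the same container type as the input, which is the behaviour Zodiax classes rely on.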
Zodiax is directly integrated with both Jax and Equinox, gaining all of their core features:
- Accelerated Numpy: A Numpy-like API that can run on GPU and TPU
- Automatic Differentiation: Allows for optimisation and inference in extremely high-dimensional spaces
- Just-In-Time Compilation: Compiles code into XLA at runtime, optimising execution across hardware
- Automatic Vectorisation: Allows for simple parallelism across hardware and asynchronous execution
- Object Oriented Jax: Allows for differentiable classes that are recognised as a valid Jax type
- Inbuilt Neural Networks: Provides pre-built neural network layer classes

On top of these, Zodiax adds:

- Path-Based Pytree Interface: Path-based indexing allows for easy interfacing with large and highly nested physical models
- Leaf Manipulation Methods: Inbuilt methods allow for easy manipulation of Pytrees, mirroring the Jax Array API
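Zodiax exposes path-based access through methods on its classes; check the documentation for the exact method names and signatures. The idea can be sketched in plain Python, assuming immutable (frozen) dataclasses as stand-ins for Zodiax classes. The helpers `get_path`, `set_path` and `multiply_path`, and the `Telescope`/`Detector` classes, are hypothetical and are not Zodiax's implementation. A dotted string such as `"detector.pixel_scale"` addresses a leaf deep inside the structure. Updates return a new object instead of mutating in place, in the same spirit as Jax's `x.at[...].multiply(...)`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Detector:
    pixel_scale: float


@dataclass(frozen=True)
class Telescope:
    diameter: float
    detector: Detector


def get_path(tree, path):
    """Hypothetical helper: follow a dotted path such as 'detector.pixel_scale'."""
    for name in path.split("."):
        tree = getattr(tree, name)
    return tree


def set_path(tree, path, value):
    """Hypothetical helper: return a copy of `tree` with the leaf at `path` replaced.

    Nothing is mutated; each object along the path is rebuilt with
    dataclasses.replace, mirroring Jax's out-of-place updates.
    """
    head, _, rest = path.partition(".")
    if not rest:
        return replace(tree, **{head: value})
    return replace(tree, **{head: set_path(getattr(tree, head), rest, value)})


def multiply_path(tree, path, factor):
    """Hypothetical helper mirroring `x.at[...].multiply(...)` for a single leaf."""
    return set_path(tree, path, get_path(tree, path) * factor)


scope = Telescope(diameter=1.0, detector=Detector(pixel_scale=0.5))
bigger = multiply_path(scope, "detector.pixel_scale", 2.0)
print(get_path(bigger, "detector.pixel_scale"))  # 1.0
print(scope.detector.pixel_scale)                # 0.5 (original is unchanged)
```

Addressing leaves by string path keeps user code short even when the model is nested many levels deep. The immutable update style is what allows the result to remain a valid input to `jax.jit` and `jax.grad`.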
- Documentation: louisdesdoigts.github.io/zodiax/
- Contributors: Louis Desdoigts
- Requires: Python 3.9+, Jax 0.4.25+
- Installation: `pip install zodiax`
- Docs installation: `pip install "zodiax[docs]"`
- Test installation: `pip install "zodiax[tests]"`
To get started, create a regular class that inherits from `zodiax.Base`:
```python
import jax
import jax.numpy as np
import zodiax as zdx


class Linear(zdx.Base):
    m: jax.Array
    b: jax.Array

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def model(self, x):
        # Straight-line model: y = m * x + b
        return self.m * x + self.b


linear = Linear(1., 1.)
```
It's that simple! The `Linear` class is now a fully differentiable object that gives us all the benefits of Jax with an object-oriented interface! Let's see how we can jit-compile and take gradients of this class.
```python
@jax.jit
@jax.grad
def loss_fn(model, xs, ys):
    # Sum of squared residuals between the model prediction and the data
    return np.square(model.model(xs) - ys).sum()


xs = np.arange(5)
ys = 2 * np.arange(5)
grads = loss_fn(linear, xs, ys)
print(grads)
print(grads.m, grads.b)
```

```
Linear(m=f32[], b=f32[])
-40.0 -10.0
```
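These values can be checked by hand. With $m = b = 1$, $x_i = i$ and $y_i = 2i$ for $i = 0, \dots, 4$, the residuals $r_i = m x_i + b - y_i = 1 - i$ are $(1, 0, -1, -2, -3)$, so

$$
\mathcal{L}(m, b) = \sum_{i=0}^{4} \left(m x_i + b - y_i\right)^2,
\qquad
\frac{\partial \mathcal{L}}{\partial m} = 2 \sum_{i} r_i x_i = 2\,(0 + 0 - 2 - 6 - 12) = -40,
\qquad
\frac{\partial \mathcal{L}}{\partial b} = 2 \sum_{i} r_i = 2\,(1 + 0 - 1 - 2 - 3) = -10.
$$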
The `grads` object is an instance of the `Linear` class, holding the gradient of the loss function with respect to each parameter!
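Because `grads` has exactly the same structure as the model, an optimisation step is a leaf-wise map over the two objects. The sketch below takes plain gradient-descent steps with `jax.tree_util.tree_map`. So that it runs without Zodiax installed, it assumes a hand-registered plain-JAX stand-in for `Linear`. The same `tree_map` call should apply to the `zdx.Base` version, since it is also a pytree. The learning rate and step count are arbitrary choices for this toy problem.

```python
import jax
import jax.numpy as np
from jax.tree_util import register_pytree_node_class, tree_map

LEARNING_RATE = 0.01  # arbitrary; small enough to be stable for this data


@register_pytree_node_class
class Linear:
    """Plain-JAX stand-in for the zdx.Base Linear class, registered by hand."""

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def model(self, x):
        return self.m * x + self.b

    def tree_flatten(self):
        return (self.m, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
@jax.grad
def loss_fn(model, xs, ys):
    return np.square(model.model(xs) - ys).sum()


@jax.jit
def step(model, xs, ys):
    grads = loss_fn(model, xs, ys)
    # grads mirrors model, so each parameter is updated with its own gradient
    return tree_map(lambda p, g: p - LEARNING_RATE * g, model, grads)


xs = np.arange(5.0)
ys = 2 * np.arange(5.0)
first_grads = loss_fn(Linear(np.array(1.0), np.array(1.0)), xs, ys)

linear = Linear(np.array(1.0), np.array(1.0))
for _ in range(500):
    linear = step(linear, xs, ys)
print(linear.m, linear.b)  # converges towards m = 2, b = 0
```

Since the data lie exactly on $y = 2x$, the parameters should approach $m = 2$ and $b = 0$.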