Serialization to diffrax #441
Hi Matthias! I coded up something very simple to illustrate some of the differences with diffrax: four different callables, with and without attributes, with a traditional parameter, and with a more generalized argument.

```python
import jax.numpy as jnp
import equinox as eqx
import diffrax as dfx
def monoexponential_decay(t, y, k):
    """Format t, y, args: args is a parameter."""
    return -k * y


# Making it an eqx.Module does the following:
# - registers the class as a PyTree (so that JAX can handle it)
# - turns it into a dataclass (not so relevant in a tiny thing without attributes)
class MonoexponentialDecay(eqx.Module):
    """Same thing, but as a callable PyTree."""

    def __call__(self, t, y, k):
        return -k * y
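# Since MonoexponentialDecay has no array attributes, an instance flattens to
# no PyTree leaves, so jit/grad/vmap can pass it around like any other container.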
def i1ffl(t, state, u):
    """ODE as a function, with format t, y, args.

    args = u is now a dfx.LinearInterpolation, evaluated at t.

    Implements a type-1 incoherent feedforward loop with:
        u: input (interpolation over discrete data)
        x: internal node (tracks input)
        y: output node

    ODE system:
        x' = u - x
        y' = u/x - y
    """
    x, y = state
    d_x = u.evaluate(t) - x
    d_y = u.evaluate(t) / x - y
    d_state = d_x, d_y
    return jnp.array(d_state)
class I1FFL(eqx.Module):
    """Same as above, but as a callable PyTree with an attribute u."""

    _u: dfx.LinearInterpolation

    def __init__(self, experimental_input):
        self._u = experimental_input

    def __call__(self, t, state, args):
        """Implements the ODE. Parameter args exists for API compatibility."""
        x, y = state
        d_x = self._u.evaluate(t) - x
        d_y = self._u.evaluate(t) / x - y
        d_state = d_x, d_y
        return jnp.array(d_state)
solver = dfx.Tsit5()
t0 = 0.
t1 = 10.
dt0 = 0.01
y0 = 10.
k = 0.5
initial_state = jnp.array([1., 1.])
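# Near-step experimental input: jumps from 1 to 5 at t = 5 via a very narrow
# linear ramp (width 0.002), expressed as a piecewise-linear interpolation.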
u = dfx.LinearInterpolation(ts=jnp.array([0, 5-0.001, 5+0.001, 10]), ys=jnp.array([1, 1, 5, 5]))
# Solve the ODEs (for demo purposes: no saved time series)
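# (A saved trajectory could be requested with, e.g.,
#  saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1, 101)); omitted here.)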
sol_monoexponential_decay = dfx.diffeqsolve(dfx.ODETerm(monoexponential_decay), solver, t0, t1, dt0, y0, args=k)
sol_MonoexponentialDecay = dfx.diffeqsolve(dfx.ODETerm(MonoexponentialDecay()), solver, t0, t1, dt0, y0, args=k)
sol_i1ffl = dfx.diffeqsolve(dfx.ODETerm(i1ffl), solver, t0, t1, dt0, initial_state, args=u)
sol_I1FFL = dfx.diffeqsolve(dfx.ODETerm(I1FFL(u)), solver, t0, t1, dt0, initial_state)
```

One of the core differences with something like SciPy is that diffrax can take a PyTree as an argument. This can be a neural network (the original use case), or a callable class, as demonstrated with the `I1FFL` example above.
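For illustration, here is a minimal sketch of the neural-network case; the names (`neural_vector_field`, the MLP sizes) are made up for this example and not part of the issue:

```python
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import diffrax as dfx

# A small MLP standing in for a learned vector field.
mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=16, depth=2, key=jr.PRNGKey(0))

def neural_vector_field(t, y, network):
    # args is the network itself: a PyTree, so it can be jit-ed,
    # differentiated, and vmapped along with everything else.
    return network(y)

sol_neural = dfx.diffeqsolve(
    dfx.ODETerm(neural_vector_field),
    dfx.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=jnp.array([1.0, 1.0]),
    args=mlp,
)
```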
Hi @johannahaffner, I am still trying to understand why diffrax for these use cases, rather than libroadrunner with SBML, which is super fast and probably much faster for stiff problems. Could you provide some insights on that? The GPU ODE solvers only seem to make sense for solving the same ODE with many u0 and p, e.g. using DiffEqGPU.jl's EnsembleGPUArray and EnsembleGPUKernel. Best, Matthias
Hi Matthias, thank you, that is amazing! I have no experience with libroadrunner, and am always happy to learn!

I use diffrax in combination with optimistix to estimate parameters across hundreds of individuals, with different initial conditions and experimental inputs, and partial observations of the ODE system (e.g. states A, B, C in one individual and A, B, D in another). My systems are also not too stiff, and generally not very large.

Speed-wise, I'm also very happy with diffrax (I can share more benchmarks if you're interested). What I could get with DifferentialEquations.jl or the ensemble options in Julia would likely be on par in terms of speed, at least for the ODE part.
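For reference on the ensemble point: in JAX, the "same ODE, many u0 and p" pattern is a `jax.vmap` over the solve. A minimal sketch (my own illustration, with a made-up `solve_one` helper):

```python
import jax
import jax.numpy as jnp
import diffrax as dfx

def decay(t, y, k):
    return -k * y

def solve_one(y0, k):
    sol = dfx.diffeqsolve(
        dfx.ODETerm(decay),
        dfx.Tsit5(),
        t0=0.0,
        t1=10.0,
        dt0=0.01,
        y0=y0,
        args=k,
    )
    return sol.ys  # final state only (the default SaveAt)

# One batched, jit-compatible solve over 100 (y0, k) pairs:
# the JAX analogue of DiffEqGPU.jl's ensemble interface.
y0s = jnp.linspace(1.0, 10.0, 100)
ks = jnp.linspace(0.1, 1.0, 100)
batched_ys = jax.vmap(solve_one)(y0s, ks)
```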
Serialize SBML ODEs to diffrax/JAX format.