
Serialization to diffrax #441

Open
matthiaskoenig opened this issue Apr 30, 2024 · 3 comments
@matthiaskoenig (Owner)

Serialize SBML ODEs to diffrax/Jax format.

johannahaffner commented May 1, 2024

Hi Matthias!

I coded up something very simple to illustrate some of the differences with diffrax: four different callables, with and without attributes, with a traditional parameter, and with a more generalized argument.
(The argument could also be a combination of the examples used below, such as a tuple (k, u), as long as the callable knows how to unpack whatever is passed as args; see the sketch after the example below.)

import jax.numpy as jnp
import equinox as eqx
import diffrax as dfx

def monoexponential_decay(t, y, k):
    """Format t, y, args: args is a parameter."""
    return -k * y

# Making it an eqx.Module does the following: 
# - registers the class as a PyTree (so that JAX can handle it) 
# - turns it into a dataclass (not so relevant in a tiny thing without attributes)
class MonoexponentialDecay(eqx.Module):  
    """Same thing, but as a callable PyTree."""
    def __call__(self, t, y, k):
        return -k * y

def i1ffl(t, state, u):
    """ODE is a function, with format t, y, args. 
    args = u is now a dfx.LinearInterpolation, evaluated at t.

    Implements a type-1 incoherent feedforward loop with:

    u: input (interpolation over discrete data)
    x: internal node (tracks input)
    y: output node

    ODE system: 

    x' = u - x
    y' = u/x - y
    """
    x, y = state
    
    d_x = u.evaluate(t) - x
    d_y = u.evaluate(t)/x - y
    d_state = d_x, d_y
    
    return jnp.array(d_state)

class I1FFL(eqx.Module):
    """Same as above, but as a callable PyTree with an attribute u."""
    _u: dfx.LinearInterpolation
    
    def __init__(self, experimental_input):
        self._u = experimental_input

    def __call__(self, t, state, args):
        """Implements the ODE. Parameter args exists for API compatibility."""
        x, y = state

        d_x = self._u.evaluate(t) - x
        d_y = self._u.evaluate(t)/x - y
        d_state = d_x, d_y

        return jnp.array(d_state)

solver = dfx.Tsit5()
t0 = 0.
t1 = 10.
dt0 = 0.01
y0 = 10.
k = 0.5
initial_state = jnp.array([1., 1.])
u = dfx.LinearInterpolation(ts=jnp.array([0, 5-0.001, 5+0.001, 10]), ys=jnp.array([1, 1, 5, 5]))

# Solve the ODEs (for demo purposes: no saved time series)
sol_monoexponential_decay = dfx.diffeqsolve(dfx.ODETerm(monoexponential_decay), solver, t0, t1, dt0, y0, args=k)
sol_MonoexponentialDecay = dfx.diffeqsolve(dfx.ODETerm(MonoexponentialDecay()), solver, t0, t1, dt0, y0, args=k)
sol_i1ffl = dfx.diffeqsolve(dfx.ODETerm(i1ffl), solver, t0, t1, dt0, initial_state, args=u)
sol_I1FFL = dfx.diffeqsolve(dfx.ODETerm(I1FFL(u)), solver, t0, t1, dt0, initial_state)
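To illustrate the tuple case mentioned above, here is a minimal sketch (not from the original comment) that combines a parameter and an input in args = (k, u), reusing solver, t0, t1, dt0, y0, k and u from the example:

# Hypothetical extra example: args is a tuple (k, u) that the callable unpacks itself.
def driven_decay(t, y, args):
    """Format t, y, args: args combines a parameter and an interpolated input."""
    k, u = args
    return -k * y + u.evaluate(t)

sol_driven_decay = dfx.diffeqsolve(dfx.ODETerm(driven_decay), solver, t0, t1, dt0, y0, args=(k, u))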

One of the core differences from something like scipy is that diffrax can take a PyTree as an argument. This can be a neural network (the original use case), or a callable class, as demonstrated with the LinearInterpolation above.
I was wondering whether this presents an issue?
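As a concrete (hypothetical) instance of the neural-network case, an eqx.nn.MLP can be passed through args just like the interpolation; a minimal sketch, reusing solver, t0, t1, dt0 and initial_state from above:

import jax.random as jr

# Hypothetical sketch: the vector field is a small neural network passed via args.
mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=16, depth=2, key=jr.PRNGKey(0))

def neural_field(t, y, net):
    """args is a neural network; the vector field is learned."""
    return net(y)

sol_neural = dfx.diffeqsolve(dfx.ODETerm(neural_field), solver, t0, t1, dt0, initial_state, args=mlp)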

I looked at the SBML2ODE class and ran the odefac_example; it looks very neat! A workaround to get something more general (like a control term) could of course be to convert the ODE to Python and then edit the generated definition, for example to include a method called on an object, such as LinearInterpolation.evaluate(t).
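To make the workaround concrete, a hypothetical before/after sketch (the shape of the generated function is illustrative only, not the actual SBML2ODE output):

# Hypothetical output of an SBML-to-Python export (shape is illustrative only):
def dxdt_generated(t, y, p):
    (x,) = y
    k1, u = p["k1"], p["u"]  # u exported as a constant parameter
    return jnp.array([-k1 * x + u])

# Hand-edited variant: u is replaced by a dfx.LinearInterpolation and
# evaluated at the current time t.
def dxdt_edited(t, y, p):
    (x,) = y
    k1, u = p["k1"], p["u"]  # u is now a dfx.LinearInterpolation
    return jnp.array([-k1 * x + u.evaluate(t)])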

@matthiaskoenig (Owner, Author)

Hi @johannahaffner,
thanks for the great example. I will create some SBML examples for these and create a first version of the exporter.

I am still trying to understand why diffrax for these use cases, rather than libroadrunner with SBML, which is super fast and probably much faster for stiff problems. Could you provide some insights on that?

The GPU ODE solvers only seem to make sense for solving the same ODE with many u0 and p, e.g. using DiffEqGPU.jl's EnsembleGPUArray and EnsembleGPUKernel.
See: Automated translation and accelerated solving of differential equations on multiple GPU platforms. Even then this only holds for simple non-stiff systems (see the benchmarks in the paper for stiff systems, where the CPU methods are faster). The Julia methods are much faster than diffrax (20x) for these use cases, so if speed is the goal you should probably look towards Julia or CPU solvers in C++ (i.e. libroadrunner with Python bindings). The speedup of diffrax is basically measured against native Python in scipy (but basically nobody uses this today to integrate ODEs; everybody uses wrappers around SUNDIALS CVODE/LSODA or other C++/Fortran libraries).
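For reference, the "same ODE with many u0 and p" pattern also maps directly onto jax.vmap in diffrax; a minimal sketch, reusing monoexponential_decay and the solver setup from above (the batch size is made up):

import jax

def solve_one(y0_i, k_i):
    sol = dfx.diffeqsolve(dfx.ODETerm(monoexponential_decay), solver, t0, t1, dt0, y0_i, args=k_i)
    return sol.ys

# One batched (and GPU-friendly) call over 1000 initial conditions and parameters.
y0s = jnp.linspace(1.0, 10.0, 1000)
ks = jnp.linspace(0.1, 1.0, 1000)
batched_ys = jax.vmap(solve_one)(y0s, ks)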

Best Matthias

johannahaffner commented May 6, 2024

Hi Matthias,

Thank you, that is amazing!

I have no experience with libroadrunner; always happy to learn!

I use diffrax in combination with optimistix to estimate parameters across hundreds of individuals, with different initial conditions and experimental inputs, and partial observations of the ODE system (e.g. states A, B, C in one individual and A, B, D in another). My systems are also not too stiff, and generally not very large.
I'm really happy with the reliable autodiff in JAX/diffrax/optimistix, and with the compile times for the nonlinear solve + ODE solver.

Speed-wise, I'm also very happy with diffrax (can share more on benchmarks if you're interested). What I could get with DifferentialEquations.jl or the ensemble options in Julia would likely be on par in terms of speed, at least for the ODE part.
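For illustration, a minimal sketch of this diffrax + optimistix pattern, fitting the monoexponential model from above to synthetic data (the observation times, data, initial guess, and tolerances are made up; this is not her actual pipeline):

import optimistix as optx

def residuals(log_k, data):
    """Residuals between the model solution and observations; k is fit on a log scale."""
    ts_obs, ys_obs = data
    sol = dfx.diffeqsolve(
        dfx.ODETerm(monoexponential_decay), solver, t0, t1, dt0, y0,
        args=jnp.exp(log_k), saveat=dfx.SaveAt(ts=ts_obs),
    )
    return sol.ys - ys_obs

ts_obs = jnp.linspace(0.0, 10.0, 20)
ys_obs = y0 * jnp.exp(-k * ts_obs)  # synthetic observations of the true model
lm = optx.LevenbergMarquardt(rtol=1e-6, atol=1e-6)
fit = optx.least_squares(residuals, lm, jnp.log(0.1), args=(ts_obs, ys_obs))
k_hat = jnp.exp(fit.value)  # recovers k = 0.5 up to solver tolerance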
