Skip to content

Commit

Permalink
Use an exact Jacobian using JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
Felipe S. S. Schneider committed Nov 26, 2020
1 parent 967f4eb commit cdae256
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 63 deletions.
14 changes: 8 additions & 6 deletions INSTALL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ The package, together with the above dependencies, can be installed from
Optionally, extra functionality is provided such as a command-line interface
and solvent properties::

pip install 'overreact[cli,solvents]'
pip install 'overreact[cli,fast,solvents]'

This last line installs `Rich <https://github.com/willmcgugan/rich>`_
and `thermo <https://github.com/CalebBell/thermo>`_ as well.
Rich is used in the command-line interface, and thermo is used
to calculate the dynamic viscosity of solvents in the context of the
:doc:`tutorials/collins-kimball` for diffusion-limited reactions.
This last line installs `Rich <https://github.com/willmcgugan/rich>`_,
`JAX <https://jax.readthedocs.io/en/latest/index.html>`_ and
`thermo <https://github.com/CalebBell/thermo>`_ as well.
Rich is used in the command-line interface, JAX helps speedup calculations,
and thermo is used to calculate the dynamic viscosity of solvents in the
context of the :doc:`tutorials/collins-kimball` for diffusion-limited
reactions.
Binary file modified docs/_static/drc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/first-order.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/michaelis-menten-dydt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/michaelis-menten-tof.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/michaelis-menten.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/simple-first-order.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 50 additions & 33 deletions docsrc/notebooks/#1 Solving complex kinetics schemes.ipynb

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions docsrc/tutorials/simulation/1-basic-reaction-simulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,14 @@ Fortunately, overreact is able to give you a function that, giving time and
concentrations, calculates the derivative with respect to time:

>>> dydt = simulate.get_dydt(scheme, k)
>>> dydt(0.0, y0) # t = 0.0
DeviceArray([-83333.3, -83333.3, 83333.3, 0. , 0. ], dtype=float64)
>>> dydt(0.0, y0) # t = 0.0 # doctest: +SKIP
array([-83333.3, -83333.3, 83333.3, 0. , 0. ])

From the above we see that the equilibrium will likely be rapidly satisfied,
while no product is being created at time zero, since there's no
enzyme-substrate complex yet.

Let's now do a one minute simulation with ``get_y`` (methods Radau or BDF are
recommended for likely stiff equations such as those):
Let's now do a one minute simulation with ``get_y``:

>>> y, r = simulate.get_y(dydt, y0)
>>> t = np.linspace(y.t_min, 60.0) # seconds
Expand All @@ -167,15 +166,15 @@ Text(...)

.. figure:: ../../_static/michaelis-menten.png

A one minute simulation of the Michaelis-Menten model for the enzyme Pepsin,
A one-minute simulation of the Michaelis-Menten model for the enzyme Pepsin,
an endopeptidase that breaks down proteins into smaller peptides. Observe
that the rapid equilibrium justifies the commonly applied steady state
that the rapid equilibrium justifies the commonly applied steady-state
approximation.

The simulation time was enough to convert all substrate into products and
regenerate the initial enzyme molecules:

>>> y(y.t_max)
>>> y(y.t_max) # doctest: +SKIP
array([0.05, 0.00, 0.00, 0.00, 1.00])

Getting rates back
Expand All @@ -198,7 +197,7 @@ Text(...)

.. figure:: ../../_static/michaelis-menten-dydt.png

Time derivative of concentrations for the one minute simulation of the Michaelis-Menten model for the enzyme Pepsin above.
The time derivative of concentrations for the one-minute simulation of the Michaelis-Menten model for the enzyme Pepsin above.

Furthermore, we can get the turnover frequency (TOF) as:

Expand All @@ -217,4 +216,4 @@ Text(...)

.. figure:: ../../_static/michaelis-menten-tof.png

Turnover frequency for the enzyme Pepsin above, in the one minute simulation of the Michaelis-Menten model.
The turnover frequency for the enzyme Pepsin above, in the one-minute simulation of the Michaelis-Menten model.
2 changes: 1 addition & 1 deletion overreact/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def get_drc(
compounds,
y0,
t_span=None,
method="Radau",
method="BDF",
qrrho=True,
scale="l mol-1 s-1",
temperature=298.15,
Expand Down
46 changes: 33 additions & 13 deletions overreact/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

logger = logging.getLogger(__name__)

_found_jax = _misc._find_package("thermo")
_found_jax = _misc._find_package("jax")
if _found_jax:
import jax.numpy as jnp
from jax import jacfwd
from jax import jit
from jax.config import config

Expand All @@ -27,7 +28,7 @@
jnp = np


def get_y(dydt, y0, t_span=None, method="Radau"):
def get_y(dydt, y0, t_span=None, method="BDF"):
"""Simulate a reaction scheme from its rate function.
This uses scipy's ``solve_ivp`` under the hood.
Expand All @@ -45,9 +46,8 @@ def get_y(dydt, y0, t_span=None, method="Radau"):
any zeroth-, first- or second-order reactions).
method : str, optional
Integration method to use. See `scipy.integrade.solve_ivp` for details.
If not sure, first try to run "RK45". If it makes unusually many
iterations, diverges, or fails, your problem is likely to be stiff and
you should use "BDF" or "Radau" (default).
Kinetics problems are very often stiff and, as such, "RK45" is
normally unsuited. "Radau", "BDF" or "LSODA" are good choices.
Returns
-------
Expand Down Expand Up @@ -78,10 +78,10 @@ def get_y(dydt, y0, t_span=None, method="Radau"):
Both `y` and `r` can be used to check concentrations and rates in any
point in time. In particular, both are vectorized:
>>> y(t)
array([[1. , 0.83243215, ..., 0.50000008, 0.50000005],
[0. , 0.16756785, ..., 0.49999992, 0.49999995]])
>>> r(t)
>>> y(t) # doctest: +SKIP
array([[1. , 0.83244929, ..., 0.49999842, 0.49999888],
[0. , 0.16755071, ..., 0.50000158, 0.50000112]])
>>> r(t) # doctest: +SKIP
array([[-1.00000000e+00, ..., -1.01639971e-07],
[ 1.00000000e+00, ..., 1.01639971e-07]])
"""
Expand Down Expand Up @@ -110,7 +110,11 @@ def get_y(dydt, y0, t_span=None, method="Radau"):
]
logger.info(f"simulation time span = {t_span} s")

res = _solve_ivp(dydt, t_span, y0, method=method, dense_output=True)
jac = None
if hasattr(dydt, "jac"):
jac = dydt.jac

res = _solve_ivp(dydt, t_span, y0, method=method, dense_output=True, jac=jac)
y = res.sol

def r(t):
Expand Down Expand Up @@ -141,7 +145,9 @@ def get_dydt(scheme, k, ef=1.0e3):
-------
dydt : callable
Reaction rate function. The actual reaction rate constants employed
are stored in the attribute `k` of the returned function.
are stored in the attribute `k` of the returned function. If JAX is
available, the attribute `jac` will hold the Jacobian function of
`dydt`.
Warns
-----
Expand All @@ -162,8 +168,8 @@ def get_dydt(scheme, k, ef=1.0e3):
>>> from overreact import core
>>> scheme = core.parse_reactions("A <=> B")
>>> dydt = get_dydt(scheme, [1, 1])
>>> dydt(0.0, np.array([1., 1.]))
DeviceArray([0., 0.], dtype=float64)
>>> dydt(0.0, np.array([1., 1.])) # doctest: +SKIP
array([0., 0.])
If available, JAX is used for JIT compilation. This will make `dydt`
complain if given lists instead of numpy arrays. So stick to the safer,
Expand All @@ -175,6 +181,13 @@ def get_dydt(scheme, k, ef=1.0e3):
>>> dydt.k
array([1, 1])
If JAX is available, the Jacobian function will be available as
`dydt.jac`:
>>> dydt.jac(0.0, np.array([1., 1.])) # doctest: +SKIP
DeviceArray([[-1., 1.],
[ 1., -1.]], dtype=float64)
"""
scheme = _core._check_scheme(scheme)
is_half_equilibrium = np.asanyarray(scheme.is_half_equilibrium)
Expand Down Expand Up @@ -202,5 +215,12 @@ def _dydt(t, y, k=k_adj, M=M):
if _found_jax:
_dydt = jit(_dydt)

def _jac(t, y, k=k_adj, M=M):
# _jac(t, y)[i, j] == d f_i / d y_j
# shape is (n_compounds, n_compounds)
return jacfwd(lambda _y: _dydt(t, _y, k, M))(y)

_dydt.jac = _jac

_dydt.k = k_adj
return _dydt
2 changes: 1 addition & 1 deletion tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_get_y_propagates_reaction_with_fixed_time():
[1.668212890625, 0.6728515625, 0.341787109375], 9e-5
)
assert r(y.t_min) == pytest.approx([-31.99, -127.96, 31.99])
assert r(y.t_max) == pytest.approx([0.0, 0.0, 0.0], abs=3e-6)
assert r(y.t_max) == pytest.approx([0.0, 0.0, 0.0], abs=4e-5)


def test_get_y_conservation_in_equilibria():
Expand Down

0 comments on commit cdae256

Please sign in to comment.