Skip to content

Commit

Permalink
Optionally use JAX for speed
Browse files Browse the repository at this point in the history
  • Loading branch information
Felipe S. S. Schneider committed Nov 26, 2020
1 parent 17f8f8b commit 967f4eb
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 17 deletions.
Binary file modified docs/_static/drc.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 3 additions & 4 deletions docsrc/tutorials/simulation/1-basic-reaction-simulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ units, as long as they match with the reaction rate constants) and solve the
initial value problem. Below we set the substrate to one molar and the enzyme
to 5% of it:

>>> y0 = [0.05, 1.00, 0.00, 0.00, 0.00]
>>> import numpy as np
>>> y0 = np.array([0.05, 1.00, 0.00, 0.00, 0.00])

One return value that ``core.parse_reactions`` has given us was the :math:`A`
matrix, whose entry :math:`A_{ij}` stores the coefficient of the i-th compound
Expand Down Expand Up @@ -136,7 +137,7 @@ concentrations, calculates the derivative with respect to time:

>>> dydt = simulate.get_dydt(scheme, k)
>>> dydt(0.0, y0) # t = 0.0
array([-83333.3, -83333.3, 83333.3, 0. , 0. ])
DeviceArray([-83333.3, -83333.3, 83333.3, 0. , 0. ], dtype=float64)

From the above we see that the equilibrium will likely be rapidly satisfied,
while no product is being created at time zero, since there's no
Expand All @@ -146,8 +147,6 @@ Let's now do a one minute simulation with ``get_y`` (methods Radau or BDF are
recommended for likely stiff equations such as those):

>>> y, r = simulate.get_y(dydt, y0)

>>> import numpy as np
>>> t = np.linspace(y.t_min, 60.0) # seconds

We can graph concentrations over time with ``t`` and ``y``:
Expand Down
31 changes: 26 additions & 5 deletions overreact/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,20 @@
from scipy.integrate import solve_ivp as _solve_ivp

from overreact import core as _core
from overreact import misc as _misc

logger = logging.getLogger(__name__)

_found_jax = _misc._find_package("thermo")
if _found_jax:
import jax.numpy as jnp
from jax import jit
from jax.config import config

config.update("jax_enable_x64", True)
else:
jnp = np


def get_y(dydt, y0, t_span=None, method="Radau"):
"""Simulate a reaction scheme from its rate function.
Expand Down Expand Up @@ -147,11 +158,16 @@ def get_dydt(scheme, k, ef=1.0e3):
Examples
--------
>>> import numpy as np
>>> from overreact import core
>>> scheme = core.parse_reactions("A <=> B")
>>> dydt = get_dydt(scheme, [1, 1])
>>> dydt(0.0, [1., 1.])
array([0., 0.])
>>> dydt(0.0, np.array([1., 1.]))
DeviceArray([0., 0.], dtype=float64)
If available, JAX is used for JIT compilation. This will make `dydt`
complain if given lists instead of numpy arrays. So stick to the safer,
faster side as above.
The actually used reaction rate constants can be inspected with the `k`
attribute of `dydt`:
Expand All @@ -177,9 +193,14 @@ def get_dydt(scheme, k, ef=1.0e3):
k_adj[~is_half_equilibrium].max() / k_adj[is_half_equilibrium].min()
)

def _dydt(t, y, k=k_adj, A=A):
r = k * np.prod(np.power(y, np.where(A > 0, 0, -A).T), axis=1)
return np.dot(A, r)
M = np.where(A > 0, 0, -A).T

def _dydt(t, y, k=k_adj, M=M):
r = k * jnp.prod(jnp.power(y, M), axis=1)
return jnp.dot(A, r)

if _found_jax:
_dydt = jit(_dydt)

_dydt.k = k_adj
return _dydt
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
install_requires=["cclib>=1.6.3", "scipy>=1.4.0"],
extras_require={
"cli": ["rich>=9.2.0"],
"fast": ["jax>=0.2.6", "jaxlib>=0.1.57"],
"solvents": ["thermo>=0.1.39"],
},
tests_require=["matplotlib>=2.1.1", "pytest>=5.2.1"],
Expand Down
17 changes: 10 additions & 7 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ def test_get_dydt_calculates_reaction_rate():

dydt = simulate.get_dydt(scheme, [2.0])

assert dydt(0.0, [1.0, 0.0]) == pytest.approx([-2.0, 2.0])
assert dydt(5.0, [1.0, 0.0]) == pytest.approx([-2.0, 2.0])
assert dydt(0.0, [1.0, 1.0]) == pytest.approx([-2.0, 2.0])
assert dydt(0.0, [10.0, 0.0]) == pytest.approx([-20.0, 20.0])
# TODO(schneiderfelipe): if jax is used, dydt won't accept lists, only
# ndarrays. Should we wrap the jitted code and use np.asanyarray in the
# wrapping function?
assert dydt(0.0, np.array([1.0, 0.0])) == pytest.approx([-2.0, 2.0])
assert dydt(5.0, np.array([1.0, 0.0])) == pytest.approx([-2.0, 2.0])
assert dydt(0.0, np.array([1.0, 1.0])) == pytest.approx([-2.0, 2.0])
assert dydt(0.0, np.array([10.0, 0.0])) == pytest.approx([-20.0, 20.0])


def test_get_y_propagates_reaction_automatically():
Expand All @@ -44,7 +47,7 @@ def test_get_y_propagates_reaction_automatically():
assert y.t_max == 1000.0
assert y(y.t_min) == pytest.approx(y0)
assert y(y.t_max) == pytest.approx(
[1.668212890625, 0.6728515625, 0.341787109375], 7e-5
[1.668212890625, 0.6728515625, 0.341787109375], 9e-5
)
assert r(y.t_min) == pytest.approx([-31.99, -127.96, 31.99])
assert r(y.t_max) == pytest.approx([0.0, 0.0, 0.0], abs=2e-4)
Expand Down Expand Up @@ -85,9 +88,9 @@ def test_get_y_conservation_in_equilibria():
assert y.t_min == 0.0
assert y.t_max == 10.0
assert y(y.t_min) == pytest.approx(y0)
assert y(y.t_max) == pytest.approx([0.5, 0.5])
assert y(y.t_max) == pytest.approx([0.5, 0.5], 3e-5)
assert r(y.t_min) == pytest.approx([-1, 1])
assert r(y.t_max) == pytest.approx([0.0, 0.0], abs=2e-7)
assert r(y.t_max) == pytest.approx([0.0, 0.0], abs=3e-5)

assert y.t_min == t[0]
assert y.t_max == t[-1]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_thermo_solv.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_translational_entropy_liquid_phase():
data.atomnos, data.atomcoords, method="izato"
)
assert free_volume / (constants.angstrom ** 3 * constants.N_A) == pytest.approx(
0.0993, 1.12847e-1
0.0993, 1.13105e-1
)

assert _thermo.calc_trans_entropy(
Expand Down

0 comments on commit 967f4eb

Please sign in to comment.