diff --git a/.github/release_checklist.md b/.github/release_checklist.md
index b62824793f..11daa8ff70 100644
--- a/.github/release_checklist.md
+++ b/.github/release_checklist.md
@@ -4,4 +4,4 @@
- `CITATION.cff`
- `vcpkg.json`
- Update CHANGELOG.md with a summary of the release
-- Update (and pin) jax and jaxlib to latest version and fix any bugs that arise
+- Update jax and jaxlib to latest version in `pybamm.util.install_jax` and fix any bugs that arise
diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml
index 74f712a1b1..c663fdba2f 100644
--- a/.github/workflows/build_wheels.yml
+++ b/.github/workflows/build_wheels.yml
@@ -76,7 +76,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Download wheels
- uses: actions/download-artifact@v1
+ uses: actions/download-artifact@v2
with:
name: wheels
diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml
index cc137d93ad..d6d4b1509f 100644
--- a/.github/workflows/test_on_push.yml
+++ b/.github/workflows/test_on_push.yml
@@ -2,7 +2,7 @@ name: PyBaMM
on:
push:
-
+
pull_request:
# everyday at 3 am UTC
@@ -10,13 +10,13 @@ on:
- cron: '0 3 * * *'
jobs:
-
+
style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Setup python
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: 3.7
@@ -37,19 +37,19 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
-
+
- name: Install Linux system dependencies
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt install gfortran gcc libopenblas-dev graphviz
sudo apt install texlive-full
-
+
# Added fixes to homebrew installs:
- # rm -f /usr/local/bin/2to3
+ # rm -f /usr/local/bin/2to3
# (see https://github.com/actions/virtual-environments/issues/2322)
- name: Install MacOS system dependencies
if: matrix.os == 'macos-latest'
@@ -62,7 +62,7 @@ jobs:
- name: Install Windows system dependencies
if: matrix.os == 'windows-latest'
run: choco install graphviz --version=2.38.0.20190211
-
+
- name: Install standard python dependencies
run: |
python -m pip install --upgrade pip wheel setuptools
@@ -72,9 +72,17 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: tox -e pybamm-requires
- - name: Run unit tests for GNU/Linux
+ - name: Run unit tests for GNU/Linux with Python 3.7 and 3.8
+ if: matrix.os == 'ubuntu-latest' && matrix.python-version != 3.9
+ run: python -m tox -e quick
+
+ - name: Run unit tests for GNU/Linux with Python 3.9 and generate coverage report
+ if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9
+ run: tox -e coverage
+
+ - name: Run integration tests for GNU/Linux
if: matrix.os == 'ubuntu-latest'
- run: python -m tox -e tests
+ run: python -m tox -e integration
- name: Run unit tests for Windows and MacOS
if: matrix.os != 'ubuntu-latest'
@@ -83,16 +91,11 @@ jobs:
- name: Install docs dependencies and run doctests
if: matrix.os == 'ubuntu-latest'
run: tox -e doctests
-
+
- name: Install dev dependencies and run example tests
if: matrix.os == 'ubuntu-latest'
run: tox -e examples
-
- - name: Install and run coverage
- if: success() && (matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9)
- run: tox -e coverage
- name: Upload coverage report
if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9
- uses: codecov/codecov-action@v1
-
+ uses: codecov/codecov-action@v2.1.0
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7194793bc..3d75b89be7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
## Features
+- `Experiment`s with drive cycles can be solved ([#1793](https://github.com/pybamm-team/PyBaMM/pull/1793))
- Added surface area to volume ratio as a factor to the SEI equations ([#1790](https://github.com/pybamm-team/PyBaMM/pull/1790))
- Half-cell SPM and SPMe have been implemented ([#1731](https://github.com/pybamm-team/PyBaMM/pull/1731))
@@ -14,6 +15,7 @@
- Raise error when trying to convert an `Interpolant` with the "pchip" interpolator to CasADI ([#1791](https://github.com/pybamm-team/PyBaMM/pull/1791))
- Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789))
+- Made jax and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767))
# [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e1ecfbd91c..78928d7d60 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -280,7 +280,7 @@ Major PyBaMM features are showcased in [Jupyter notebooks](https://jupyter.org/)
All example notebooks should be listed in [examples/README.md](https://github.com/pybamm-team/PyBaMM/blob/develop/examples/notebooks/README.md). Please follow the (naming and writing) style of existing notebooks where possible.
-Where possible, notebooks are tested daily. A list of slow notebooks (which time-out and fail tests) is maintained in `.slow-books`, these notebooks will be excluded from daily testing.
+All the notebooks are tested daily.
## Citations
diff --git a/README.md b/README.md
index d2f77757c6..83a18d1c42 100644
--- a/README.md
+++ b/README.md
@@ -87,7 +87,9 @@ conda install -c conda-forge pybamm
```
### Optional solvers
-On GNU/Linux and MacOS, an optional [scikits.odes](https://scikits-odes.readthedocs.io/en/latest/)-based solver is available, see [the documentation](https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#scikits-odes-label).
+Following GNU/Linux and macOS solvers are optionally available:
+- [scikits.odes](https://scikits-odes.readthedocs.io/en/latest/)-based solver, see [the documentation](https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-scikits-odes-solver).
+- [jax](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)-based solver, see [the documentation](https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver).
## 📖 Citing PyBaMM
diff --git a/docs/index.rst b/docs/index.rst
index 0545c825da..a96bd47510 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -46,7 +46,10 @@ PyBaMM is available as a conda package through the conda-forge channel.
Optional solvers
-----------------
-On GNU/Linux and MacOS, an optional `scikits.odes `_ -based solver is available, see :ref:`scikits.odes-label`.
+Following GNU/Linux and macOS solvers are optionally available:
+
+* `scikits.odes `_ -based solver, see `Optional - scikits.odes solver `_.
+* `jax `_ -based solver, see `Optional - JaxSolver `_.
Installation
============
diff --git a/docs/install/GNU-linux.rst b/docs/install/GNU-linux.rst
index b008955c03..cbb06cab42 100644
--- a/docs/install/GNU-linux.rst
+++ b/docs/install/GNU-linux.rst
@@ -45,8 +45,8 @@ User install
------------
We recommend to install PyBaMM within a virtual environment, in order
-not to alter any distribution python files.
-First, make sure you are using python 3.7, 3.8, or 3.9.
+not to alter any distribution python files.
+First, make sure you are using python 3.7, 3.8, or 3.9.
To create a virtual environment ``env`` within your current directory type:
.. code:: bash
@@ -116,9 +116,24 @@ macOS
.. code:: bash
pip install scikits.odes
-
+
Assuming that the SUNDIALS were installed as described :ref:`above`.
+Optional - JaxSolver
+--------------------
+
+Users can install ``jax`` and ``jaxlib`` to use the Jax solver.
+Currently, only GNU/Linux and macOS are supported.
+
+GNU/Linux and macOS
+~~~~~~~~~~~~~~~~~~~
+
+.. code:: bash
+
+ pybamm_install_jax
+
+The ``pybamm_install_jax`` command is installed with PyBaMM. It automatically downloads and installs jax and jaxlib on your system.
+
Developer install
-----------------
diff --git a/docs/requirements.txt b/docs/requirements.txt
index b2754b438d..2dab106555 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -7,8 +7,6 @@ autograd >= 1.2
scikit-fem >= 0.2.0
casadi >= 3.5.0
imageio>=2.9.0
-jax==0.2.12
-jaxlib==0.1.70
jupyter # For example notebooks
pybtex
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs
@@ -19,4 +17,4 @@ matplotlib >= 2.0
#
guzzle-sphinx-theme
sphinx>4.0
-sympy==1.8
+sympy==1.9
diff --git a/examples/scripts/experiment_drive_cycle.py b/examples/scripts/experiment_drive_cycle.py
new file mode 100644
index 0000000000..2ead334dad
--- /dev/null
+++ b/examples/scripts/experiment_drive_cycle.py
@@ -0,0 +1,52 @@
+#
+# Constant-current constant-voltage charge with US06 Drive Cycle using Experiment Class.
+#
+import pybamm
+import pandas as pd
+import os
+
+os.chdir(pybamm.__path__[0] + "/..")
+
+pybamm.set_logging_level("INFO")
+
+# import drive cycle from file
+drive_cycle_current = pd.read_csv(
+ "pybamm/input/drive_cycles/US06.csv", comment="#", header=None
+).to_numpy()
+
+
+# Map Drive Cycle
+def map_drive_cycle(x, min_op_value, max_op_value):
+ min_ip_value = x[:, 1].min()
+ max_ip_value = x[:, 1].max()
+ x[:, 1] = (x[:, 1] - min_ip_value) / (max_ip_value - min_ip_value) * (
+ max_op_value - min_op_value
+ ) + min_op_value
+ return x
+
+
+# Map current drive cycle to voltage and power
+drive_cycle_power = map_drive_cycle(drive_cycle_current, 1.5, 3.5)
+
+experiment = pybamm.Experiment(
+ [
+ "Charge at 1 A until 4.0 V",
+ "Hold at 4.0 V until 50 mA",
+ "Rest for 30 minutes",
+ "Run US06_A (A)",
+ "Rest for 30 minutes",
+ "Run US06_W (W)",
+ "Rest for 30 minutes",
+ ],
+ drive_cycles={
+ "US06_A": drive_cycle_current,
+ "US06_W": drive_cycle_power,
+ },
+)
+
+model = pybamm.lithium_ion.DFN()
+sim = pybamm.Simulation(model, experiment=experiment, solver=pybamm.CasadiSolver())
+sim.solve()
+
+# Show all plots
+sim.plot()
diff --git a/pybamm/CITATIONS.txt b/pybamm/CITATIONS.txt
index f2e797b7d5..d9177d8bba 100644
--- a/pybamm/CITATIONS.txt
+++ b/pybamm/CITATIONS.txt
@@ -192,17 +192,17 @@
}
@article{Kirk2021,
- author = {Toby L. Kirk and Colin P. Please and S. Jon Chapman},
- title = {Physical Modelling of the Slow Voltage Relaxation Phenomenon in Lithium-Ion Batteries},
- journal = {Journal of The Electrochemical Society},
- year = 2021,
- month = {jun},
- publisher = {The Electrochemical Society},
- volume = {168},
- number = {6},
- pages = {060554},
- doi = {10.1149/1945-7111/ac0bf7},
- url = {https://doi.org/10.1149/1945-7111/ac0bf7},
+ author = {Toby L. Kirk and Colin P. Please and S. Jon Chapman},
+ title = {Physical Modelling of the Slow Voltage Relaxation Phenomenon in Lithium-Ion Batteries},
+ journal = {Journal of The Electrochemical Society},
+ year = 2021,
+ month = {jun},
+ publisher = {The Electrochemical Society},
+ volume = {168},
+ number = {6},
+ pages = {060554},
+ doi = {10.1149/1945-7111/ac0bf7},
+ url = {https://doi.org/10.1149/1945-7111/ac0bf7},
}
@article{Lain2019,
@@ -283,17 +283,17 @@
}
@article{OKane2020,
- doi = {10.1149/1945-7111/ab90ac},
- url = {https://doi.org/10.1149/1945-7111/ab90ac},
- year = {2020},
- month = {may},
- publisher = {The Electrochemical Society},
- volume = {167},
- number = {9},
- pages = {090540},
- author = {Simon E. J. O'Kane and Ian D. Campbell and Mohamed W. J. Marzook and Gregory J. Offer and Monica Marinescu},
- title = {Physical Origin of the Differential Voltage Minimum Associated with Lithium Plating in Li-Ion Batteries},
- journal = {Journal of The Electrochemical Society}
+ doi = {10.1149/1945-7111/ab90ac},
+ url = {https://doi.org/10.1149/1945-7111/ab90ac},
+ year = {2020},
+ month = {may},
+ publisher = {The Electrochemical Society},
+ volume = {167},
+ number = {9},
+ pages = {090540},
+ author = {Simon E. J. O'Kane and Ian D. Campbell and Mohamed W. J. Marzook and Gregory J. Offer and Monica Marinescu},
+ title = {Physical Origin of the Differential Voltage Minimum Associated with Lithium Plating in Li-Ion Batteries},
+ journal = {Journal of The Electrochemical Society}
}
@article{ORegan2021,
diff --git a/pybamm/__init__.py b/pybamm/__init__.py
index aab8a54c4a..b5dfd69d98 100644
--- a/pybamm/__init__.py
+++ b/pybamm/__init__.py
@@ -7,7 +7,7 @@
#
import sys
import os
-import platform
+
#
# Version info
@@ -66,7 +66,7 @@ def version(formatted=False):
#
from .util import Timer, TimerTime, FuzzyDict
from .util import root_dir, load_function, rmse, get_infinite_nested_dict, load
-from .util import get_parameters_filepath
+from .util import get_parameters_filepath, have_jax, install_jax
from .logger import logger, set_logging_level
from .settings import settings
from .citations import Citations, citations, print_citations
@@ -102,12 +102,8 @@ def version(formatted=False):
EvaluatorPython,
)
-if not (
- platform.system() == "Windows"
- or (platform.system() == "Darwin" and "ARM64" in platform.version())
-):
- from .expression_tree.operations.evaluate_python import EvaluatorJax
- from .expression_tree.operations.evaluate_python import JaxCooMatrix
+from .expression_tree.operations.evaluate_python import EvaluatorJax
+from .expression_tree.operations.evaluate_python import JaxCooMatrix
from .expression_tree.operations.jacobian import Jacobian
from .expression_tree.operations.convert_to_casadi import CasadiConverter
@@ -226,13 +222,8 @@ def version(formatted=False):
from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes
from .solvers.scipy_solver import ScipySolver
-# Jax not supported under windows
-if not (
- platform.system() == "Windows"
- or (platform.system() == "Darwin" and "ARM64" in platform.version())
-):
- from .solvers.jax_solver import JaxSolver
- from .solvers.jax_bdf_solver import jax_bdf_integrate
+from .solvers.jax_solver import JaxSolver
+from .solvers.jax_bdf_solver import jax_bdf_integrate
from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu
diff --git a/pybamm/experiments/experiment.py b/pybamm/experiments/experiment.py
index ab2f3d709d..c175a94815 100644
--- a/pybamm/experiments/experiment.py
+++ b/pybamm/experiments/experiment.py
@@ -195,6 +195,7 @@ def read_string(self, cond, drive_cycles):
"electric": op_CC["electric"] + op_CV["electric"],
"time": op_CV["time"],
"period": op_CV["period"],
+ "dc_data": None,
}, event_CV
# Read period
if " period)" in cond:
@@ -223,26 +224,29 @@ def read_string(self, cond, drive_cycles):
drive_cycles[cond_list[1]], end_time
)
# Drive cycle as numpy array
+ dc_name = cond_list[1] + "_ext_{}".format(end_time)
dc_data = ext_drive_cycle
# Find the type of drive cycle ("A", "V", or "W")
typ = cond_list[2][1]
- electric = (dc_data, typ)
+ electric = (dc_name, typ)
time = ext_drive_cycle[:, 0][-1]
period = np.min(np.diff(ext_drive_cycle[:, 0]))
events = None
else:
# e.g. Run US06
# Drive cycle as numpy array
+ dc_name = cond_list[1]
dc_data = drive_cycles[cond_list[1]]
# Find the type of drive cycle ("A", "V", or "W")
typ = cond_list[2][1]
- electric = (dc_data, typ)
+ electric = (dc_name, typ)
# Set time and period to 1 second for first step and
# then calculate the difference in consecutive time steps
time = drive_cycles[cond_list[1]][:, 0][-1]
period = np.min(np.diff(drive_cycles[cond_list[1]][:, 0]))
events = None
else:
+ dc_data = None
if "for" in cond and "or until" in cond:
# e.g. for 3 hours or until 4.2 V
cond_list = cond.split()
@@ -273,7 +277,12 @@ def read_string(self, cond, drive_cycles):
)
)
- return {"electric": electric, "time": time, "period": period}, events
+ return {
+ "electric": electric,
+ "time": time,
+ "period": period,
+ "dc_data": dc_data,
+ }, events
def extend_drive_cycle(self, drive_cycle, end_time):
"Extends the drive cycle to enable for event"
diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py
index ff8bf92853..15c919858f 100644
--- a/pybamm/expression_tree/operations/evaluate_python.py
+++ b/pybamm/expression_tree/operations/evaluate_python.py
@@ -1,113 +1,110 @@
#
# Write a symbol to python
#
-import pybamm
+import numbers
+from collections import OrderedDict
import numpy as np
import scipy.sparse
-from collections import OrderedDict
-import numbers
-from platform import system, version
+import pybamm
-if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
+if pybamm.have_jax():
import jax
-
from jax.config import config
config.update("jax_enable_x64", True)
- class JaxCooMatrix:
- """
- A sparse matrix in COO format, with internal arrays using jax device arrays
- This matrix only has two operations supported, a multiply with a scalar, and a
- dot product with a dense vector. It can also be converted to a dense 2D jax
- device array
+class JaxCooMatrix:
+ """
+ A sparse matrix in COO format, with internal arrays using jax device arrays
- Parameters
- ----------
+ This matrix only has two operations supported, a multiply with a scalar, and a
+ dot product with a dense vector. It can also be converted to a dense 2D jax
+ device array
- row: arraylike
- 1D array holding row indices of non-zero entries
- col: arraylike
- 1D array holding col indices of non-zero entries
- data: arraylike
- 1D array holding non-zero entries
- shape: 2-element tuple (x, y)
- where x is the number of rows, and y the number of columns of the matrix
- """
+ Parameters
+ ----------
- def __init__(self, row, col, data, shape):
- self.row = jax.numpy.array(row)
- self.col = jax.numpy.array(col)
- self.data = jax.numpy.array(data)
- self.shape = shape
- self.nnz = len(self.data)
-
- def toarray(self):
- """convert sparse matrix to a dense 2D array"""
- result = jax.numpy.zeros(self.shape, dtype=self.data.dtype)
- return result.at[self.row, self.col].add(self.data)
-
- def dot_product(self, b):
- """
- dot product of matrix with a dense column vector b
-
- Parameters
- ----------
- b: jax device array
- must have shape (n, 1)
- """
- # assume b is a column vector
- result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype)
- return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col])
-
- def scalar_multiply(self, b):
- """
- multiply of matrix with a scalar b
-
- Parameters
- ----------
- b: Number or 1 element jax device array
- scalar value to multiply
- """
- # assume b is a scalar or ndarray with 1 element
- return JaxCooMatrix(
- self.row, self.col, (self.data * b).reshape(-1), self.shape
+ row: arraylike
+ 1D array holding row indices of non-zero entries
+ col: arraylike
+ 1D array holding col indices of non-zero entries
+ data: arraylike
+ 1D array holding non-zero entries
+ shape: 2-element tuple (x, y)
+ where x is the number of rows, and y the number of columns of the matrix
+ """
+
+ def __init__(self, row, col, data, shape):
+ if not pybamm.have_jax():
+ raise ModuleNotFoundError(
+ "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)
- def multiply(self, b):
- """
- general matrix multiply not supported
- """
- raise NotImplementedError
+ self.row = jax.numpy.array(row)
+ self.col = jax.numpy.array(col)
+ self.data = jax.numpy.array(data)
+ self.shape = shape
+ self.nnz = len(self.data)
+
+ def toarray(self):
+ """convert sparse matrix to a dense 2D array"""
+ result = jax.numpy.zeros(self.shape, dtype=self.data.dtype)
+ return result.at[self.row, self.col].add(self.data)
+
+ def dot_product(self, b):
+ """
+ dot product of matrix with a dense column vector b
- def __matmul__(self, b):
- """see self.dot_product"""
- return self.dot_product(b)
+ Parameters
+ ----------
+ b: jax device array
+ must have shape (n, 1)
+ """
+ # assume b is a column vector
+ result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype)
+ return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col])
- def create_jax_coo_matrix(value):
+ def scalar_multiply(self, b):
"""
- Creates a JaxCooMatrix from a scipy.sparse matrix
+ multiply of matrix with a scalar b
Parameters
----------
+ b: Number or 1 element jax device array
+ scalar value to multiply
+ """
+ # assume b is a scalar or ndarray with 1 element
+ return JaxCooMatrix(self.row, self.col, (self.data * b).reshape(-1), self.shape)
- value: scipy.sparse matrix
- the sparse matrix to be converted
+ def multiply(self, b):
"""
- scipy_coo = value.tocoo()
- row = jax.numpy.asarray(scipy_coo.row)
- col = jax.numpy.asarray(scipy_coo.col)
- data = jax.numpy.asarray(scipy_coo.data)
- return JaxCooMatrix(row, col, data, value.shape)
+ general matrix multiply not supported
+ """
+ raise NotImplementedError
+ def __matmul__(self, b):
+ """see self.dot_product"""
+ return self.dot_product(b)
-else:
- def create_jax_coo_matrix(value): # pragma: no cover
- raise NotImplementedError("Jax is not available on Windows")
+def create_jax_coo_matrix(value):
+ """
+ Creates a JaxCooMatrix from a scipy.sparse matrix
+
+ Parameters
+ ----------
+
+ value: scipy.sparse matrix
+ the sparse matrix to be converted
+ """
+ scipy_coo = value.tocoo()
+ row = jax.numpy.asarray(scipy_coo.row)
+ col = jax.numpy.asarray(scipy_coo.col)
+ data = jax.numpy.asarray(scipy_coo.data)
+ return JaxCooMatrix(row, col, data, value.shape)
def id_to_python_variable(symbol_id, constant=False):
@@ -548,6 +545,11 @@ class EvaluatorJax:
"""
def __init__(self, symbol):
+ if not pybamm.have_jax():
+ raise ModuleNotFoundError(
+ "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
+ )
+
constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)
# replace numpy function calls to jax numpy calls
@@ -582,9 +584,7 @@ def __init__(self, symbol):
args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
if self._arg_list:
args = ",".join(self._arg_list) + ", " + args
- python_str = (
- "def evaluate_jax({}):\n".format(args) + python_str
- )
+ python_str = "def evaluate_jax({}):\n".format(args) + python_str
# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
@@ -609,8 +609,9 @@ def __init__(self, symbol):
exec(compiled_function)
self._static_argnums = tuple(static_argnums)
- self._jit_evaluate = jax.jit(self._evaluate_jax,
- static_argnums=self._static_argnums)
+ self._jit_evaluate = jax.jit(
+ self._evaluate_jax, static_argnums=self._static_argnums
+ )
def get_jacobian(self):
n = len(self._arg_list)
@@ -618,8 +619,9 @@ def get_jacobian(self):
# forward mode autodiff wrt y, which is argument 1 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
- self._jac_evaluate = jax.jit(jacobian_evaluate,
- static_argnums=self._static_argnums)
+ self._jac_evaluate = jax.jit(
+ jacobian_evaluate, static_argnums=self._static_argnums
+ )
return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)
@@ -629,8 +631,9 @@ def get_sensitivities(self):
# forward mode autodiff wrt inputs, which is argument 3 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)
- self._sens_evaluate = jax.jit(jacobian_evaluate,
- static_argnums=self._static_argnums)
+ self._sens_evaluate = jax.jit(
+ jacobian_evaluate, static_argnums=self._static_argnums
+ )
return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)
diff --git a/pybamm/simulation.py b/pybamm/simulation.py
index 58f435868e..589b762cd9 100644
--- a/pybamm/simulation.py
+++ b/pybamm/simulation.py
@@ -21,7 +21,7 @@ def is_notebook():
return False # Terminal running IPython
elif shell == "Shell": # pragma: no cover
return True # Google Colab notebook
- else:
+ else: # pragma: no cover
return False # Other type (?)
except NameError:
return False # Probably standard Python interpreter
@@ -170,45 +170,76 @@ def set_up_experiment(self, model, experiment):
"Power switch": 0,
"CCCV switch": 0,
"Current input [A]": 0,
- "Voltage input [V]": 0, # doesn't matter
- "Power input [W]": 0, # doesn't matter
+ "Voltage input [V]": 0,
+ "Power input [W]": 0,
}
op_control = op["electric"][1]
- if op_control in ["A", "C"]:
- capacity = self._parameter_values["Nominal cell capacity [A.h]"]
+ if op["dc_data"] is not None:
+ # If operating condition includes a drive cycle, define the interpolant
+ timescale = self._parameter_values.evaluate(model.timescale)
+ drive_cycle_interpolant = pybamm.Interpolant(
+ op["dc_data"][:, 0],
+ op["dc_data"][:, 1],
+ timescale * (pybamm.t - pybamm.InputParameter("start time")),
+ )
if op_control == "A":
- I = op["electric"][0]
- Crate = I / capacity
- else:
- # Scale C-rate with capacity to obtain current
- Crate = op["electric"][0]
- I = Crate * capacity
- if len(op["electric"]) == 4:
- # Update inputs for CCCV
- op_control = "CCCV" # change to CCCV
- V = op["electric"][2]
operating_inputs.update(
{
- "CCCV switch": 1,
- "Current input [A]": I,
- "Voltage input [V]": V,
+ "Current switch": 1,
+ "Current input [A]": drive_cycle_interpolant,
}
)
- else:
- # Update inputs for constant current
+ if op_control == "V":
+ operating_inputs.update(
+ {
+ "Voltage switch": 1,
+ "Voltage input [V]": drive_cycle_interpolant,
+ }
+ )
+ if op_control == "W":
+ operating_inputs.update(
+ {"Power switch": 1, "Power input [W]": drive_cycle_interpolant}
+ )
+ else:
+ if op_control in ["A", "C"]:
+ capacity = self._parameter_values["Nominal cell capacity [A.h]"]
+ if op_control == "A":
+ I = op["electric"][0]
+ Crate = I / capacity
+ else:
+ # Scale C-rate with capacity to obtain current
+ Crate = op["electric"][0]
+ I = Crate * capacity
+ if len(op["electric"]) == 4:
+ # Update inputs for CCCV
+ op_control = "CCCV" # change to CCCV
+ V = op["electric"][2]
+ operating_inputs.update(
+ {
+ "CCCV switch": 1,
+ "Current input [A]": I,
+ "Voltage input [V]": V,
+ }
+ )
+ else:
+ # Update inputs for constant current
+ operating_inputs.update(
+ {"Current switch": 1, "Current input [A]": I}
+ )
+ elif op_control == "V":
+ # Update inputs for constant voltage
+ V = op["electric"][0]
operating_inputs.update(
- {"Current switch": 1, "Current input [A]": I}
+ {"Voltage switch": 1, "Voltage input [V]": V}
)
- elif op_control == "V":
- # Update inputs for constant voltage
- V = op["electric"][0]
- operating_inputs.update({"Voltage switch": 1, "Voltage input [V]": V})
- elif op_control == "W":
- # Update inputs for constant power
- P = op["electric"][0]
- operating_inputs.update({"Power switch": 1, "Power input [W]": P})
+ elif op_control == "W":
+ # Update inputs for constant power
+ P = op["electric"][0]
+ operating_inputs.update({"Power switch": 1, "Power input [W]": P})
+
# Update period
operating_inputs["period"] = op["period"]
+
# Update events
if events is None:
# make current and voltage values that won't be hit
@@ -851,6 +882,11 @@ def solve(
f"step {step_num}/{cycle_length}: {op_conds_str}"
)
inputs.update(exp_inputs)
+ if current_solution is None:
+ start_time = 0
+ else:
+ start_time = current_solution.t[-1]
+ inputs.update({"start time": start_time})
kwargs["inputs"] = inputs
# Make sure we take at least 2 timesteps
npts = max(int(round(dt / exp_inputs["period"])) + 1, 2)
diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py
index ca2b43dc1a..98fa02dd59 100644
--- a/pybamm/solvers/base_solver.py
+++ b/pybamm/solvers/base_solver.py
@@ -1160,7 +1160,7 @@ def step(
external_variables : dict
A dictionary of external variables and their corresponding
values at the current time
- inputs_dict : dict, optional
+ inputs : dict, optional
Any input parameters to pass to the model when solving
save : bool
Turn on to store the solution of all previous timesteps
@@ -1201,6 +1201,14 @@ def step(
# Set up external variables and inputs
external_variables = external_variables or {}
inputs = inputs or {}
+
+ # Remove interpolant inputs as Casadi can't handle them
+ if isinstance(inputs.get("Current input [A]"), pybamm.Interpolant):
+ del inputs["Current input [A]"]
+ elif isinstance(inputs.get("Voltage input [V]"), pybamm.Interpolant):
+ del inputs["Voltage input [V]"]
+ elif isinstance(inputs.get("Power input [W]"), pybamm.Interpolant):
+ del inputs["Power input [W]"]
ext_and_inputs = {**external_variables, **inputs}
# Check that any inputs that may affect the scaling have not changed
diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py
index 1dd765e16a..4d4c40a76f 100644
--- a/pybamm/solvers/jax_bdf_solver.py
+++ b/pybamm/solvers/jax_bdf_solver.py
@@ -1,792 +1,992 @@
-import operator as op
-import numpy as onp
import collections
-
-import jax
-import jax.numpy as jnp
-from jax import core
-from jax import dtypes
-from jax.util import safe_map, cache, split_list
-from jax.api_util import flatten_fun_nokwargs
-from jax.flatten_util import ravel_pytree
-from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_multimap
-from jax.interpreters import partial_eval as pe
+import operator as op
from functools import partial
-from jax import linear_util as lu
-from jax.config import config
-from absl import logging
-
-logging.set_verbosity(logging.ERROR)
-
-config.update("jax_enable_x64", True)
-
-MAX_ORDER = 5
-NEWTON_MAXITER = 4
-ROOT_SOLVE_MAXITER = 15
-MIN_FACTOR = 0.2
-MAX_FACTOR = 10
-
-
-# https://github.com/google/jax/issues/4572#issuecomment-709809897
-def some_hash_function(x):
- return hash(x.tobytes())
-
-
-class HashableArrayWrapper:
- """wrapper for a numpy array to make it hashable"""
- def __init__(self, val):
- self.val = val
-
- def __hash__(self):
- return some_hash_function(self.val)
-
- def __eq__(self, other):
- return (isinstance(other, HashableArrayWrapper) and
- onp.all(onp.equal(self.val, other.val)))
-
-def gnool_jit(fun, static_array_argnums=(), static_argnums=()):
- """redefinition of jax jit to allow static array args"""
- @partial(
- jax.jit,
- static_argnums=static_array_argnums + static_argnums
- )
- def callee(*args):
- args = list(args)
- for i in static_array_argnums:
- args[i] = args[i].val
- return fun(*args)
-
- def caller(*args):
- args = list(args)
- for i in static_array_argnums:
- args[i] = HashableArrayWrapper(args[i])
- return callee(*args)
-
- return caller
-
-
-@jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
-def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args):
- """
- This implements a Backward Difference formula (BDF) implicit multistep integrator.
- The basic algorithm is derived in [2]_. This particular implementation follows that
- implemented in the Matlab routine ode15s described in [1]_ and the SciPy
- implementation [3]_, which features the NDF formulas for improved stability, with
- associated differences in the error constants, and calculates the jacobian at
- J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in the
- scipy library [3]_, which also mainly follows [1]_ but uses the more standard
- jacobian update.
-
- Parameters
- ----------
-
- func: callable
- function to evaluate the time derivative of the solution `y` at time
- `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
- mass: ndarray
- diagonal of the mass matrix with shape (n,)
- y0: ndarray
- initial state vector, has shape (n,)
- t_eval: ndarray
- time points to evaluate the solution, has shape (m,)
- args: (optional)
- tuple of additional arguments for `fun`, which must be arrays
- scalars, or (nested) standard Python containers (tuples, lists, dicts,
- namedtuples, i.e. pytrees) of those types.
- rtol: (optional) float
- relative tolerance for the solver
- atol: (optional) float
- absolute tolerance for the solver
+import numpy as onp
- Returns
- -------
- y: ndarray with shape (n, m)
- calculated state vector at each of the m time points
+import pybamm
- References
- ----------
- .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI.
- COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997.
- .. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical
- Solution of Ordinary Differential Equations", ACM Transactions on
- Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975.
- .. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy,
- T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0:
- fundamental algorithms for scientific computing in Python.
- Nature methods, 17(3), 261-272.
- """
+if pybamm.have_jax():
+ import jax
+ import jax.numpy as jnp
+ from absl import logging
+ from jax import core, dtypes
+ from jax import linear_util as lu
+ from jax.api_util import flatten_fun_nokwargs
+ from jax.config import config
+ from jax.flatten_util import ravel_pytree
+ from jax.interpreters import partial_eval as pe
+ from jax.tree_util import tree_flatten, tree_map, tree_multimap, tree_unflatten
+ from jax.util import cache, safe_map, split_list
- def fun_bind_inputs(y, t):
- return fun(y, t, *args)
+ config.update("jax_enable_x64", True)
- jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0)
+ logging.set_verbosity(logging.ERROR)
- t0 = t_eval[0]
- h0 = t_eval[1] - t0
+ MAX_ORDER = 5
+ NEWTON_MAXITER = 4
+ ROOT_SOLVE_MAXITER = 15
+ MIN_FACTOR = 0.2
+ MAX_FACTOR = 10
- stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol)
- i = 0
- y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype)
+ # https://github.com/google/jax/issues/4572#issuecomment-709809897
+ def some_hash_function(x):
+ return hash(x.tobytes())
- init_state = [stepper, t_eval, i, y_out]
+ class HashableArrayWrapper:
+ """wrapper for a numpy array to make it hashable"""
- def cond_fun(state):
- _, t_eval, i, _ = state
- return i < len(t_eval)
+ def __init__(self, val):
+ self.val = val
- def body_fun(state):
- stepper, t_eval, i, y_out = state
- stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs)
- index = jnp.searchsorted(t_eval, stepper.t)
+ def __hash__(self):
+ return some_hash_function(self.val)
- def for_body(j, y_out):
- t = t_eval[j]
- y_out = jax.ops.index_update(
- y_out, jax.ops.index[j, :], _bdf_interpolate(stepper, t)
+ def __eq__(self, other):
+ return isinstance(other, HashableArrayWrapper) and onp.all(
+ onp.equal(self.val, other.val)
)
- return y_out
-
- y_out = jax.lax.fori_loop(i, index, for_body, y_out)
- return [stepper, t_eval, index, y_out]
-
- stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state)
- return y_out
-
-
-BDFInternalStates = [
- "t",
- "atol",
- "rtol",
- "M",
- "newton_tol",
- "order",
- "h",
- "n_equal_steps",
- "D",
- "y0",
- "scale_y0",
- "kappa",
- "gamma",
- "alpha",
- "c",
- "error_const",
- "J",
- "LU",
- "U",
- "psi",
- "n_function_evals",
- "n_jacobian_evals",
- "n_lu_decompositions",
- "n_steps",
- "consistent_y0_failed",
-]
-BDFState = collections.namedtuple("BDFState", BDFInternalStates)
-
-jax.tree_util.register_pytree_node(
- BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs)
-)
-
-
-def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
- """
- Initiation routine for Backward Difference formula (BDF) implicit multistep
- integrator.
-
- See _bdf_odeint function above for details, this function returns a dict with the
- initial state of the solver
-
- Parameters
- ----------
-
- fun: callable
- function with signature (y, t), where t is a scalar time and y is a ndarray with
- shape (n,), returns the rhs of the system of ODE equations as an nd array with
- shape (n,)
- jac: callable
- function with signature (y, t), where t is a scalar time and y is a ndarray with
- shape (n,), returns the jacobian matrix of fun as an ndarray with shape (n,n)
- mass: ndarray
- diagonal of the mass matrix with shape (n,)
- t0: float
- initial time
- y0: ndarray
- initial state vector with shape (n,)
- h0: float
- initial step size
- rtol: (optional) float
- relative tolerance for the solver
- atol: (optional) float
- absolute tolerance for the solver
- """
-
- state = {}
- state["t"] = t0
- state["atol"] = atol
- state["rtol"] = rtol
- state["M"] = mass
- EPS = jnp.finfo(y0.dtype).eps
- state["newton_tol"] = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol ** 0.5))
-
- scale_y0 = atol + rtol * jnp.abs(y0)
- y0, not_converged = _select_initial_conditions(
- fun, mass, t0, y0, state["newton_tol"], scale_y0
- )
- state["consistent_y0_failed"] = not_converged
-
- f0 = fun(y0, t0)
- order = 1
- state["order"] = order
- state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0)
- state["n_equal_steps"] = 0
- D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype)
- D = jax.ops.index_update(D, jax.ops.index[0, :], y0)
- D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state["h"])
- state["D"] = D
- state["y0"] = y0
- state["scale_y0"] = scale_y0
-
- # kappa values for difference orders, taken from Table 1 of [1]
- kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
- gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1))))
- alpha = 1.0 / ((1 - kappa) * gamma)
- c = state["h"] * alpha[order]
- error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2)
-
- state["kappa"] = kappa
- state["gamma"] = gamma
- state["alpha"] = alpha
- state["c"] = c
- state["error_const"] = error_const
-
- J = jac(y0, t0)
- state["J"] = J
-
- state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J)
-
- state["U"] = _compute_R(order, 1)
- state["psi"] = None
-
- state["n_function_evals"] = 2
- state["n_jacobian_evals"] = 1
- state["n_lu_decompositions"] = 1
- state["n_steps"] = 0
-
- tuple_state = BDFState(*[state[k] for k in BDFInternalStates])
- y0, scale_y0 = _predict(tuple_state, D)
- psi = _update_psi(tuple_state, D)
- return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi)
-
-
-def _compute_R(order, factor):
- """
- computes the R matrix with entries
- given by the first equation on page 8 of [1]
-
- This is used to update the differences matrix when step size h is varied according
- to factor = h_{n+1} / h_n
-
- Note that the U matrix also defined in the same section can be also be
- found using factor = 1, which corresponds to R with a constant step size
- """
- I = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1)
- J = jnp.arange(1, MAX_ORDER + 1)
- M = jnp.empty((MAX_ORDER + 1, MAX_ORDER + 1))
- M = jax.ops.index_update(M, jax.ops.index[1:, 1:], (I - 1 - factor * J) / I)
- M = jax.ops.index_update(M, jax.ops.index[0], 1)
- R = jnp.cumprod(M, axis=0)
-
- return R
+ def gnool_jit(fun, static_array_argnums=(), static_argnums=()):
+ """redefinition of jax jit to allow static array args"""
-def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
- # identify algebraic variables as zeros on diagonal
- algebraic_variables = onp.diag(M) == 0.0
+ @partial(jax.jit, static_argnums=static_array_argnums + static_argnums)
+ def callee(*args):
+ args = list(args)
+ for i in static_array_argnums:
+ args[i] = args[i].val
+ return fun(*args)
- # if all differentiable variables then return y0 (can use normal python if since M
- # is static)
- if not onp.any(algebraic_variables):
- return y0, False
+ def caller(*args):
+ args = list(args)
+ for i in static_array_argnums:
+ args[i] = HashableArrayWrapper(args[i])
+ return callee(*args)
- # calculate consistent initial conditions via a newton on -J_a @ delta = f_a This
- # follows this reference:
- #
- # Shampine, L. F., Reichelt, M. W., & Kierzenka, J. A. (1999). Solving index-1 DAEs
- # in MATLAB and Simulink. SIAM review, 41(3), 538-552.
-
- # calculate fun_a, function of algebraic variables
- def fun_a(y_a):
- y_full = jax.ops.index_update(y0, algebraic_variables, y_a)
- return fun(y_full, t0)[algebraic_variables]
-
- y0_a = y0[algebraic_variables]
- scale_y0_a = scale_y0[algebraic_variables]
-
- d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype)
- y_a = jnp.array(y0_a, copy=True)
+ return caller
- # calculate neg jacobian of fun_a
- J_a = jax.jacfwd(fun_a)(y_a)
- LU = jax.scipy.linalg.lu_factor(-J_a)
+ @jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
+ def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args):
+ """
+ Implements a Backward Difference formula (BDF) implicit multistep integrator.
+ The basic algorithm is derived in [2]_. This particular implementation follows
+ that implemented in the Matlab routine ode15s described in [1]_ and the SciPy
+ implementation [3]_, which features the NDF formulas for improved stability,
+ with associated differences in the error constants, and calculates the jacobian
+ at J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in
+ the scipy library [3]_, which also mainly follows [1]_ but uses the more
+ standard jacobian update.
+
+ Parameters
+ ----------
+
+ func: callable
+ function to evaluate the time derivative of the solution `y` at time
+ `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
+ mass: ndarray
+ diagonal of the mass matrix with shape (n,)
+ y0: ndarray
+ initial state vector, has shape (n,)
+ t_eval: ndarray
+ time points to evaluate the solution, has shape (m,)
+ args: (optional)
+ tuple of additional arguments for `fun`, which must be arrays
+ scalars, or (nested) standard Python containers (tuples, lists, dicts,
+ namedtuples, i.e. pytrees) of those types.
+ rtol: (optional) float
+ relative tolerance for the solver
+ atol: (optional) float
+ absolute tolerance for the solver
+
+ Returns
+ -------
+ y: ndarray with shape (n, m)
+ calculated state vector at each of the m time points
+
+ References
+ ----------
+ .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI.
+ COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997.
+ .. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical
+ Solution of Ordinary Differential Equations", ACM Transactions on
+ Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975.
+ .. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy,
+ T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0:
+ fundamental algorithms for scientific computing in Python.
+ Nature methods, 17(3), 261-272.
+ """
- converged = False
- dy_norm_old = -1.0
- k = 0
- while_state = [k, converged, dy_norm_old, d, y_a]
+ def fun_bind_inputs(y, t):
+ return fun(y, t, *args)
- def while_cond(while_state):
- k, converged, _, _, _ = while_state
- return (converged == False) * (k < ROOT_SOLVE_MAXITER) # noqa: E712
+ jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0)
- def while_body(while_state):
- k, converged, dy_norm_old, d, y_a = while_state
- f_eval = fun_a(y_a)
- dy = jax.scipy.linalg.lu_solve(LU, f_eval)
- dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0_a) ** 2))
- rate = dy_norm / dy_norm_old
+ t0 = t_eval[0]
+ h0 = t_eval[1] - t0
- d += dy
- y_a = y0_a + d
+ stepper = _bdf_init(
+ fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol
+ )
+ i = 0
+ y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype)
+
+ init_state = [stepper, t_eval, i, y_out]
+
+ def cond_fun(state):
+ _, t_eval, i, _ = state
+ return i < len(t_eval)
+
+ def body_fun(state):
+ stepper, t_eval, i, y_out = state
+ stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs)
+ index = jnp.searchsorted(t_eval, stepper.t)
+
+ def for_body(j, y_out):
+ t = t_eval[j]
+ y_out = jax.ops.index_update(
+ y_out, jax.ops.index[j, :], _bdf_interpolate(stepper, t)
+ )
+ return y_out
+
+ y_out = jax.lax.fori_loop(i, index, for_body, y_out)
+ return [stepper, t_eval, index, y_out]
+
+ stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state)
+ return y_out
+
+ BDFInternalStates = [
+ "t",
+ "atol",
+ "rtol",
+ "M",
+ "newton_tol",
+ "order",
+ "h",
+ "n_equal_steps",
+ "D",
+ "y0",
+ "scale_y0",
+ "kappa",
+ "gamma",
+ "alpha",
+ "c",
+ "error_const",
+ "J",
+ "LU",
+ "U",
+ "psi",
+ "n_function_evals",
+ "n_jacobian_evals",
+ "n_lu_decompositions",
+ "n_steps",
+ "consistent_y0_failed",
+ ]
+ BDFState = collections.namedtuple("BDFState", BDFInternalStates)
- # if converged then break out of iteration early
- pred = dy_norm_old >= 0.0
- pred *= rate / (1 - rate) * dy_norm < tol
- converged = (dy_norm == 0.0) + pred
+ jax.tree_util.register_pytree_node(
+ BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs)
+ )
- dy_norm_old = dy_norm
+ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
+ """
+ Initiation routine for Backward Difference formula (BDF) implicit multistep
+ integrator.
+
+ See _bdf_odeint function above for details, this function returns a dict with
+ the initial state of the solver
+
+ Parameters
+ ----------
+
+ fun: callable
+ function with signature (y, t), where t is a scalar time and y is a ndarray
+ with shape (n,), returns the rhs of the system of ODE equations as an nd
+ array with shape (n,)
+ jac: callable
+ function with signature (y, t), where t is a scalar time and y is a ndarray
+ with shape (n,), returns the jacobian matrix of fun as an ndarray with
+ shape (n,n)
+ mass: ndarray
+ diagonal of the mass matrix with shape (n,)
+ t0: float
+ initial time
+ y0: ndarray
+ initial state vector with shape (n,)
+ h0: float
+ initial step size
+ rtol: (optional) float
+ relative tolerance for the solver
+ atol: (optional) float
+ absolute tolerance for the solver
+ """
- return [k + 1, converged, dy_norm_old, d, y_a]
+ state = {}
+ state["t"] = t0
+ state["atol"] = atol
+ state["rtol"] = rtol
+ state["M"] = mass
+ EPS = jnp.finfo(y0.dtype).eps
+ state["newton_tol"] = jnp.maximum(
+ 10 * EPS / rtol, jnp.minimum(0.03, rtol ** 0.5)
+ )
- k, converged, dy_norm_old, d, y_a = jax.lax.while_loop(
- while_cond, while_body, while_state
- )
- y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a)
+ scale_y0 = atol + rtol * jnp.abs(y0)
+ y0, not_converged = _select_initial_conditions(
+ fun, mass, t0, y0, state["newton_tol"], scale_y0
+ )
+ state["consistent_y0_failed"] = not_converged
+
+ f0 = fun(y0, t0)
+ order = 1
+ state["order"] = order
+ state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0)
+ state["n_equal_steps"] = 0
+ D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype)
+ D = jax.ops.index_update(D, jax.ops.index[0, :], y0)
+ D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state["h"])
+ state["D"] = D
+ state["y0"] = y0
+ state["scale_y0"] = scale_y0
+
+ # kappa values for difference orders, taken from Table 1 of [1]
+ kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
+ gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1))))
+ alpha = 1.0 / ((1 - kappa) * gamma)
+ c = state["h"] * alpha[order]
+ error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2)
+
+ state["kappa"] = kappa
+ state["gamma"] = gamma
+ state["alpha"] = alpha
+ state["c"] = c
+ state["error_const"] = error_const
+
+ J = jac(y0, t0)
+ state["J"] = J
+
+ state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J)
+
+ state["U"] = _compute_R(order, 1)
+ state["psi"] = None
+
+ state["n_function_evals"] = 2
+ state["n_jacobian_evals"] = 1
+ state["n_lu_decompositions"] = 1
+ state["n_steps"] = 0
+
+ tuple_state = BDFState(*[state[k] for k in BDFInternalStates])
+ y0, scale_y0 = _predict(tuple_state, D)
+ psi = _update_psi(tuple_state, D)
+ return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi)
+
+ def _compute_R(order, factor):
+ """
+ computes the R matrix with entries
+ given by the first equation on page 8 of [1]
- return y_tilde, converged
+ This is used to update the differences matrix when step size h is varied
+ according to factor = h_{n+1} / h_n
+ Note that the U matrix also defined in the same section can be also be
+ found using factor = 1, which corresponds to R with a constant step size
+ """
+ I = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1)
+ J = jnp.arange(1, MAX_ORDER + 1)
+ M = jnp.empty((MAX_ORDER + 1, MAX_ORDER + 1))
+ M = jax.ops.index_update(M, jax.ops.index[1:, 1:], (I - 1 - factor * J) / I)
+ M = jax.ops.index_update(M, jax.ops.index[0], 1)
+ R = jnp.cumprod(M, axis=0)
+
+ return R
+
+ def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
+ # identify algebraic variables as zeros on diagonal
+ algebraic_variables = onp.diag(M) == 0.0
+
+ # if all differentiable variables then return y0 (can use normal python if
+ # since M is static)
+ if not onp.any(algebraic_variables):
+ return y0, False
+
+ # calculate consistent initial conditions via a newton on -J_a @ delta = f_a
+ # This follows this reference:
+ #
+ # Shampine, L. F., Reichelt, M. W., & Kierzenka, J. A. (1999).
+ # Solving index-1 DAEs in MATLAB and Simulink. SIAM review, 41(3), 538-552.
-def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0):
- """
- Select a good initial step by stepping forward one step of forward euler, and
- comparing the predicted state against that using the provided function.
+ # calculate fun_a, function of algebraic variables
+ def fun_a(y_a):
+ y_full = jax.ops.index_update(y0, algebraic_variables, y_a)
+ return fun(y_full, t0)[algebraic_variables]
- Optimal step size based on the selected order is obtained using formula (4.12)
- in [1]
+ y0_a = y0[algebraic_variables]
+ scale_y0_a = scale_y0[algebraic_variables]
- References
- ----------
- .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
- Equations I: Nonstiff Problems", Sec. II.4.
- """
- scale = atol + jnp.abs(y0) * rtol
- y1 = y0 + h0 * f0
- f1 = fun(y1, t0 + h0)
- d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale) ** 2))
- order = 1
- h1 = h0 * d2 ** (-1 / (order + 1))
- return jnp.minimum(100 * h0, h1)
+ d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype)
+ y_a = jnp.array(y0_a, copy=True)
+ # calculate neg jacobian of fun_a
+ J_a = jax.jacfwd(fun_a)(y_a)
+ LU = jax.scipy.linalg.lu_factor(-J_a)
-def _predict(state, D):
- """
- predict forward to new step (eq 2 in [1])
- """
- n = len(state.y0)
- order = state.order
- orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1)
- subD = jnp.where(orders <= order, D, 0)
- y0 = jnp.sum(subD, axis=0)
- scale_y0 = state.atol + state.rtol * jnp.abs(state.y0)
- return y0, scale_y0
+ converged = False
+ dy_norm_old = -1.0
+ k = 0
+ while_state = [k, converged, dy_norm_old, d, y_a]
+ def while_cond(while_state):
+ k, converged, _, _, _ = while_state
+ return (converged == False) * (k < ROOT_SOLVE_MAXITER) # noqa: E712
-def _update_psi(state, D):
- """
- update psi term as defined in second equation on page 9 of [1]
- """
- order = state.order
- n = len(state.y0)
- orders = jnp.arange(MAX_ORDER + 1)
- subGamma = jnp.where(orders > 0, jnp.where(orders <= order, state.gamma, 0), 0)
- orders = jnp.repeat(orders.reshape(-1, 1), n, axis=1)
- subD = jnp.where(orders > 0, jnp.where(orders <= order, D, 0), 0)
- psi = jnp.dot(subD.T, subGamma) * state.alpha[order]
- return psi
+ def while_body(while_state):
+ k, converged, dy_norm_old, d, y_a = while_state
+ f_eval = fun_a(y_a)
+ dy = jax.scipy.linalg.lu_solve(LU, f_eval)
+ dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0_a) ** 2))
+ rate = dy_norm / dy_norm_old
+ d += dy
+ y_a = y0_a + d
-def _update_difference_for_next_step(state, d):
- """
- update of difference equations can be done efficiently
- by reusing d and D.
+ # if converged then break out of iteration early
+ pred = dy_norm_old >= 0.0
+ pred *= rate / (1 - rate) * dy_norm < tol
+ converged = (dy_norm == 0.0) + pred
- From first equation on page 4 of [1]:
- d = y_n - y^0_n = D^{k + 1} y_n
+ dy_norm_old = dy_norm
- Standard backwards difference gives
- D^{j + 1} y_n = D^{j} y_n - D^{j} y_{n - 1}
+ return [k + 1, converged, dy_norm_old, d, y_a]
- Combining these gives the following algorithm
- """
- order = state.order
- D = state.D
- D = jax.ops.index_update(D, jax.ops.index[order + 2], d - D[order + 1])
- D = jax.ops.index_update(D, jax.ops.index[order + 1], d)
- i = order
- while_state = [i, D]
+ k, converged, dy_norm_old, d, y_a = jax.lax.while_loop(
+ while_cond, while_body, while_state
+ )
+ y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a)
- def while_cond(while_state):
- i, _ = while_state
- return i >= 0
+ return y_tilde, converged
- def while_body(while_state):
- i, D = while_state
- D = jax.ops.index_add(D, jax.ops.index[i], D[i + 1])
- i -= 1
- return [i, D]
+ def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0):
+ """
+ Select a good initial step by stepping forward one step of forward euler, and
+ comparing the predicted state against that using the provided function.
- i, D = jax.lax.while_loop(while_cond, while_body, while_state)
+ Optimal step size based on the selected order is obtained using formula (4.12)
+ in [1]
- return D
+ References
+ ----------
+ .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
+ Equations I: Nonstiff Problems", Sec. II.4.
+ """
+ scale = atol + jnp.abs(y0) * rtol
+ y1 = y0 + h0 * f0
+ f1 = fun(y1, t0 + h0)
+ d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale) ** 2))
+ order = 1
+ h1 = h0 * d2 ** (-1 / (order + 1))
+ return jnp.minimum(100 * h0, h1)
+
+ def _predict(state, D):
+ """
+ predict forward to new step (eq 2 in [1])
+ """
+ n = len(state.y0)
+ order = state.order
+ orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1)
+ subD = jnp.where(orders <= order, D, 0)
+ y0 = jnp.sum(subD, axis=0)
+ scale_y0 = state.atol + state.rtol * jnp.abs(state.y0)
+ return y0, scale_y0
+
+ def _update_psi(state, D):
+ """
+ update psi term as defined in second equation on page 9 of [1]
+ """
+ order = state.order
+ n = len(state.y0)
+ orders = jnp.arange(MAX_ORDER + 1)
+ subGamma = jnp.where(orders > 0, jnp.where(orders <= order, state.gamma, 0), 0)
+ orders = jnp.repeat(orders.reshape(-1, 1), n, axis=1)
+ subD = jnp.where(orders > 0, jnp.where(orders <= order, D, 0), 0)
+ psi = jnp.dot(subD.T, subGamma) * state.alpha[order]
+ return psi
+
+ def _update_difference_for_next_step(state, d):
+ """
+ update of difference equations can be done efficiently
+ by reusing d and D.
+ From first equation on page 4 of [1]:
+ d = y_n - y^0_n = D^{k + 1} y_n
-def _update_step_size_and_lu(state, factor):
- state = _update_step_size(state, factor)
+ Standard backwards difference gives
+ D^{j + 1} y_n = D^{j} y_n - D^{j} y_{n - 1}
- # redo lu (c has changed)
- LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J)
- n_lu_decompositions = state.n_lu_decompositions + 1
+ Combining these gives the following algorithm
+ """
+ order = state.order
+ D = state.D
+ D = jax.ops.index_update(D, jax.ops.index[order + 2], d - D[order + 1])
+ D = jax.ops.index_update(D, jax.ops.index[order + 1], d)
+ i = order
+ while_state = [i, D]
- return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions)
+ def while_cond(while_state):
+ i, _ = while_state
+ return i >= 0
+ def while_body(while_state):
+ i, D = while_state
+ D = jax.ops.index_add(D, jax.ops.index[i], D[i + 1])
+ i -= 1
+ return [i, D]
-def _update_step_size(state, factor):
- """
- If step size h is changed then also need to update the terms in
- the first equation of page 9 of [1]:
+ i, D = jax.lax.while_loop(while_cond, while_body, while_state)
- - constant c = h / (1-kappa) gamma_k term
- - lu factorisation of (M - c * J) used in newton iteration (same equation)
- - psi term
- """
- order = state.order
- h = state.h * factor
- n_equal_steps = 0
- c = h * state.alpha[order]
-
- # update D using equations in section 3.2 of [1]
- RU = _compute_R(order, factor).dot(state.U)
- I = jnp.arange(0, MAX_ORDER + 1).reshape(-1, 1)
- J = jnp.arange(0, MAX_ORDER + 1)
-
- # only update order+1, order+1 entries of D
- RU = jnp.where(
- jnp.logical_and(I <= order, J <= order), RU, jnp.identity(MAX_ORDER + 1)
- )
- D = state.D
- D = jnp.dot(RU.T, D)
- # D = jax.ops.index_update(D, jax.ops.index[:order + 1],
- # jnp.dot(RU.T, D[:order + 1]))
+ return D
- # update psi (D has changed)
- psi = _update_psi(state, D)
+ def _update_step_size_and_lu(state, factor):
+ state = _update_step_size(state, factor)
- # update y0 (D has changed)
- y0, scale_y0 = _predict(state, D)
+ # redo lu (c has changed)
+ LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J)
+ n_lu_decompositions = state.n_lu_decompositions + 1
- return state._replace(
- n_equal_steps=n_equal_steps, h=h, c=c, D=D, psi=psi, y0=y0, scale_y0=scale_y0
- )
+ return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions)
+ def _update_step_size(state, factor):
+ """
+ If step size h is changed then also need to update the terms in
+ the first equation of page 9 of [1]:
-def _update_jacobian(state, jac):
- """
- we update the jacobian using J(t_{n+1}, y^0_{n+1})
- following the scipy bdf implementation rather than J(t_n, y_n) as per [1]
- """
- J = jac(state.y0, state.t + state.h)
- n_jacobian_evals = state.n_jacobian_evals + 1
- LU = jax.scipy.linalg.lu_factor(state.M - state.c * J)
- n_lu_decompositions = state.n_lu_decompositions + 1
- return state._replace(
- J=J,
- n_jacobian_evals=n_jacobian_evals,
- LU=LU,
- n_lu_decompositions=n_lu_decompositions,
- )
+ - constant c = h / (1-kappa) gamma_k term
+ - lu factorisation of (M - c * J) used in newton iteration (same equation)
+ - psi term
+ """
+ order = state.order
+ h = state.h * factor
+ n_equal_steps = 0
+ c = h * state.alpha[order]
+
+ # update D using equations in section 3.2 of [1]
+ RU = _compute_R(order, factor).dot(state.U)
+ I = jnp.arange(0, MAX_ORDER + 1).reshape(-1, 1)
+ J = jnp.arange(0, MAX_ORDER + 1)
+
+ # only update order+1, order+1 entries of D
+ RU = jnp.where(
+ jnp.logical_and(I <= order, J <= order), RU, jnp.identity(MAX_ORDER + 1)
+ )
+ D = state.D
+ D = jnp.dot(RU.T, D)
+ # D = jax.ops.index_update(D, jax.ops.index[:order + 1],
+ # jnp.dot(RU.T, D[:order + 1]))
+
+ # update psi (D has changed)
+ psi = _update_psi(state, D)
+
+ # update y0 (D has changed)
+ y0, scale_y0 = _predict(state, D)
+
+ return state._replace(
+ n_equal_steps=n_equal_steps,
+ h=h,
+ c=c,
+ D=D,
+ psi=psi,
+ y0=y0,
+ scale_y0=scale_y0,
+ )
+ def _update_jacobian(state, jac):
+ """
+ we update the jacobian using J(t_{n+1}, y^0_{n+1})
+ following the scipy bdf implementation rather than J(t_n, y_n) as per [1]
+ """
+ J = jac(state.y0, state.t + state.h)
+ n_jacobian_evals = state.n_jacobian_evals + 1
+ LU = jax.scipy.linalg.lu_factor(state.M - state.c * J)
+ n_lu_decompositions = state.n_lu_decompositions + 1
+ return state._replace(
+ J=J,
+ n_jacobian_evals=n_jacobian_evals,
+ LU=LU,
+ n_lu_decompositions=n_lu_decompositions,
+ )
-def _newton_iteration(state, fun):
- tol = state.newton_tol
- c = state.c
- psi = state.psi
- y0 = state.y0
- LU = state.LU
- M = state.M
- scale_y0 = state.scale_y0
- t = state.t + state.h
- d = jnp.zeros(y0.shape, dtype=y0.dtype)
- y = jnp.array(y0, copy=True)
- n_function_evals = state.n_function_evals
-
- converged = False
- dy_norm_old = -1.0
- k = 0
- while_state = [k, converged, dy_norm_old, d, y, n_function_evals]
-
- def while_cond(while_state):
- k, converged, _, _, _, _ = while_state
- return (converged == False) * (k < NEWTON_MAXITER) # noqa: E712
-
- def while_body(while_state):
- k, converged, dy_norm_old, d, y, n_function_evals = while_state
- f_eval = fun(y, t)
- n_function_evals += 1
- b = c * f_eval - M @ (psi + d)
- dy = jax.scipy.linalg.lu_solve(LU, b)
- dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2))
- rate = dy_norm / dy_norm_old
-
- # if iteration is not going to converge in NEWTON_MAXITER
- # (assuming the current rate), then abort
- pred = rate >= 1
- pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
- pred *= dy_norm_old >= 0
- k += pred * (NEWTON_MAXITER - k - 1)
-
- d += dy
- y = y0 + d
-
- # if converged then break out of iteration early
- pred = dy_norm_old >= 0.0
- pred *= rate / (1 - rate) * dy_norm < tol
- converged = (dy_norm == 0.0) + pred
-
- dy_norm_old = dy_norm
-
- return [k + 1, converged, dy_norm_old, d, y, n_function_evals]
-
- k, converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop(
- while_cond, while_body, while_state
- )
- return converged, k, y, d, state._replace(n_function_evals=n_function_evals)
+ def _newton_iteration(state, fun):
+ tol = state.newton_tol
+ c = state.c
+ psi = state.psi
+ y0 = state.y0
+ LU = state.LU
+ M = state.M
+ scale_y0 = state.scale_y0
+ t = state.t + state.h
+ d = jnp.zeros(y0.shape, dtype=y0.dtype)
+ y = jnp.array(y0, copy=True)
+ n_function_evals = state.n_function_evals
+
+ converged = False
+ dy_norm_old = -1.0
+ k = 0
+ while_state = [k, converged, dy_norm_old, d, y, n_function_evals]
+
+ def while_cond(while_state):
+ k, converged, _, _, _, _ = while_state
+ return (converged == False) * (k < NEWTON_MAXITER) # noqa: E712
+
+ def while_body(while_state):
+ k, converged, dy_norm_old, d, y, n_function_evals = while_state
+ f_eval = fun(y, t)
+ n_function_evals += 1
+ b = c * f_eval - M @ (psi + d)
+ dy = jax.scipy.linalg.lu_solve(LU, b)
+ dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2))
+ rate = dy_norm / dy_norm_old
+
+ # if iteration is not going to converge in NEWTON_MAXITER
+ # (assuming the current rate), then abort
+ pred = rate >= 1
+ pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
+ pred *= dy_norm_old >= 0
+ k += pred * (NEWTON_MAXITER - k - 1)
+
+ d += dy
+ y = y0 + d
+
+ # if converged then break out of iteration early
+ pred = dy_norm_old >= 0.0
+ pred *= rate / (1 - rate) * dy_norm < tol
+ converged = (dy_norm == 0.0) + pred
+
+ dy_norm_old = dy_norm
+
+ return [k + 1, converged, dy_norm_old, d, y, n_function_evals]
+
+ k, converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop(
+ while_cond, while_body, while_state
+ )
+ return converged, k, y, d, state._replace(n_function_evals=n_function_evals)
+ def rms_norm(arg):
+ return jnp.sqrt(jnp.mean(arg ** 2))
-def rms_norm(arg):
- return jnp.sqrt(jnp.mean(arg ** 2))
+ def _prepare_next_step(state, d):
+ D = _update_difference_for_next_step(state, d)
+ psi = _update_psi(state, D)
+ y0, scale_y0 = _predict(state, D)
+ return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0)
+ def _prepare_next_step_order_change(state, d, y, n_iter):
+ order = state.order
-def _prepare_next_step(state, d):
- D = _update_difference_for_next_step(state, d)
- psi = _update_psi(state, D)
- y0, scale_y0 = _predict(state, D)
- return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0)
+ D = _update_difference_for_next_step(state, d)
+ # Note: we are recalculating these from the while loop above, could re-use?
+ scale_y = state.atol + state.rtol * jnp.abs(y)
+ error = state.error_const[order] * d
+ error_norm = rms_norm(error / scale_y)
+ safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
-def _prepare_next_step_order_change(state, d, y, n_iter):
- order = state.order
+ # similar to the optimal step size factor we calculated above for the current
+ # order k, we need to calculate the optimal step size factors for orders
+ # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n
+ error_m_norm = jnp.where(
+ order > 1,
+ rms_norm(state.error_const[order - 1] * D[order] / scale_y),
+ jnp.inf,
+ )
+ error_p_norm = jnp.where(
+ order < MAX_ORDER,
+ rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y),
+ jnp.inf,
+ )
- D = _update_difference_for_next_step(state, d)
+ error_norms = jnp.array([error_m_norm, error_norm, error_p_norm])
+ factors = error_norms ** (-1 / (jnp.arange(3) + order))
+
+ # now we have the three factors for orders k-1, k and k+1, pick the maximum in
+ # order to maximise the resultant step size
+ max_index = jnp.argmax(factors)
+ order += max_index - 1
+
+ factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index])
+
+ new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor)
+ return new_state
+
+ def _bdf_step(state, fun, jac):
+ # print('bdf_step', state.t, state.h)
+ # we will try and use the old jacobian unless convergence of newton iteration
+ # fails
+ updated_jacobian = False
+ # initialise step size and try to make the step,
+ # iterate, reducing step size until error is in bounds
+ step_accepted = False
+ y = jnp.empty_like(state.y0)
+ d = jnp.empty_like(state.y0)
+ n_iter = -1
+
+ # loop until step is accepted
+ while_state = [state, step_accepted, updated_jacobian, y, d, n_iter]
+
+ def while_cond(while_state):
+ _, step_accepted, _, _, _, _ = while_state
+ return step_accepted == False # noqa: E712
+
+ def while_body(while_state):
+ state, step_accepted, updated_jacobian, y, d, n_iter = while_state
+
+ # solve BDF equation using y0 as starting point
+ converged, n_iter, y, d, state = _newton_iteration(state, fun)
+ not_converged = converged == False # noqa: E712
+
+ # newton iteration did not converge, but jacobian has already been
+ # evaluated so reduce step size by 0.3 (as per [1]) and try again
+ state = tree_multimap(
+ partial(jnp.where, not_converged * updated_jacobian),
+ _update_step_size_and_lu(state, 0.3),
+ state,
+ )
- # Note: we are recalculating these from the while loop above, could re-use?
- scale_y = state.atol + state.rtol * jnp.abs(y)
- error = state.error_const[order] * d
- error_norm = rms_norm(error / scale_y)
- safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
+ # if not_converged * updated_jacobian:
+ # print('not converged, update step size by 0.3')
+ # if not_converged * (updated_jacobian == False):
+ # print('not converged, update jacobian')
- # similar to the optimal step size factor we calculated above for the current
- # order k, we need to calculate the optimal step size factors for orders
- # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n
- error_m_norm = jnp.where(
- order > 1, rms_norm(state.error_const[order - 1] * D[order] / scale_y), jnp.inf
- )
- error_p_norm = jnp.where(
- order < MAX_ORDER,
- rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y),
- jnp.inf,
- )
+ # if not converged and jacobian not updated, then update the jacobian and
+ # try again
+ (state, updated_jacobian) = tree_multimap(
+ partial(
+ jnp.where, not_converged * (updated_jacobian == False) # noqa: E712
+ ),
+ (_update_jacobian(state, jac), True),
+ (state, False + updated_jacobian),
+ )
- error_norms = jnp.array([error_m_norm, error_norm, error_p_norm])
- factors = error_norms ** (-1 / (jnp.arange(3) + order))
+ safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
+ scale_y = state.atol + state.rtol * jnp.abs(y)
- # now we have the three factors for orders k-1, k and k+1, pick the maximum in
- # order to maximise the resultant step size
- max_index = jnp.argmax(factors)
- order += max_index - 1
+ # combine eq 3, 4 and 6 from [1] to obtain error
+ # Note that error = C_k * h^{k+1} y^{k+1}
+ # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1}
+ error = state.error_const[state.order] * d
- factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index])
+ error_norm = rms_norm(error / scale_y)
- new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor)
- return new_state
+ # calculate optimal step size factor as per eq 2.46 of [2]
+ factor = jnp.maximum(
+ MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1))
+ )
+ # if converged * (error_norm > 1):
+ # print(
+ # "converged, but error is too large",
+ # error_norm,
+ # factor,
+ # d,
+ # scale_y,
+ # )
+
+ (state, step_accepted) = tree_multimap(
+ partial(jnp.where, converged * (error_norm > 1)), # noqa: E712
+ (_update_step_size_and_lu(state, factor), False),
+ (state, converged),
+ )
-def _bdf_step(state, fun, jac):
- # print('bdf_step', state.t, state.h)
- # we will try and use the old jacobian unless convergence of newton iteration
- # fails
- updated_jacobian = False
- # initialise step size and try to make the step,
- # iterate, reducing step size until error is in bounds
- step_accepted = False
- y = jnp.empty_like(state.y0)
- d = jnp.empty_like(state.y0)
- n_iter = -1
+ return [state, step_accepted, updated_jacobian, y, d, n_iter]
- # loop until step is accepted
- while_state = [state, step_accepted, updated_jacobian, y, d, n_iter]
+ state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop(
+ while_cond, while_body, while_state
+ )
- def while_cond(while_state):
- _, step_accepted, _, _, _, _ = while_state
- return step_accepted == False # noqa: E712
+ # take the accepted step
+ n_steps = state.n_steps + 1
+ t = state.t + state.h
- def while_body(while_state):
- state, step_accepted, updated_jacobian, y, d, n_iter = while_state
+ # a change in order is only done after running at order k for k + 1 steps
+ # (see page 83 of [2])
+ n_equal_steps = state.n_equal_steps + 1
- # solve BDF equation using y0 as starting point
- converged, n_iter, y, d, state = _newton_iteration(state, fun)
- not_converged = converged == False # noqa: E712
+ state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)
- # newton iteration did not converge, but jacobian has already been
- # evaluated so reduce step size by 0.3 (as per [1]) and try again
state = tree_multimap(
- partial(jnp.where, not_converged * updated_jacobian),
- _update_step_size_and_lu(state, 0.3),
- state,
- )
-
- # if not_converged * updated_jacobian:
- # print('not converged, update step size by 0.3')
- # if not_converged * (updated_jacobian == False):
- # print('not converged, update jacobian')
-
- # if not converged and jacobian not updated, then update the jacobian and try
- # again
- (state, updated_jacobian) = tree_multimap(
- partial(
- jnp.where, not_converged * (updated_jacobian == False) # noqa: E712
- ),
- (_update_jacobian(state, jac), True),
- (state, False + updated_jacobian),
+ partial(jnp.where, n_equal_steps < state.order + 1),
+ _prepare_next_step(state, d),
+ _prepare_next_step_order_change(state, d, y, n_iter),
)
- safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
- scale_y = state.atol + state.rtol * jnp.abs(y)
-
- # combine eq 3, 4 and 6 from [1] to obtain error
- # Note that error = C_k * h^{k+1} y^{k+1}
- # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1}
- error = state.error_const[state.order] * d
+ return state
- error_norm = rms_norm(error / scale_y)
+ def _bdf_interpolate(state, t_eval):
+ """
+ interpolate solution at time values t* where t-h < t* < t
- # calculate optimal step size factor as per eq 2.46 of [2]
- factor = jnp.maximum(
- MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1))
+ definition of the interpolating polynomial can be found on page 7 of [1]
+ """
+ order = state.order
+ t = state.t
+ h = state.h
+ D = state.D
+ j = 0
+ time_factor = 1.0
+ order_summation = D[0]
+ while_state = [j, time_factor, order_summation]
+
+ def while_cond(while_state):
+ j, _, _ = while_state
+ return j < order
+
+ def while_body(while_state):
+ j, time_factor, order_summation = while_state
+ time_factor *= (t_eval - (t - h * j)) / (h * (1 + j))
+ order_summation += D[j + 1] * time_factor
+ j += 1
+ return [j, time_factor, order_summation]
+
+ j, time_factor, order_summation = jax.lax.while_loop(
+ while_cond, while_body, while_state
)
+ return order_summation
+
+ def block_diag(lst):
+ def block_fun(i, j, Ai, Aj):
+ if i == j:
+ return Ai
+ else:
+ return onp.zeros(
+ (
+ Ai.shape[0] if Ai.ndim > 1 else 1,
+ Aj.shape[1] if Aj.ndim > 1 else 1,
+ ),
+ dtype=Ai.dtype,
+ )
+
+ blocks = [
+ [block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)]
+ for i, Ai in enumerate(lst)
+ ]
+
+ return onp.block(blocks)
+
+ # NOTE: the code below (except the docstring on jax_bdf_integrate and other minor
+ # edits), has been modified from the JAX library at https://github.com/google/jax.
+ # The main difference is the addition of support for semi-explicit dae index 1
+ # problems via the addition of a mass matrix.
+ # This is under an Apache license, a short form of which is given here:
+ #
+ # Copyright 2018 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License"); you may not use
+ # this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software distributed
+ # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ # CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations under the License.
- # if converged * (error_norm > 1):
- # print('converged, but error is too large',error_norm, factor, d, scale_y)
+ def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover
+ """
+ for debugging purposes, use this instead of jax.lax.while_loop
+ """
+ val = init_val
+ while cond_fun(val):
+ val = body_fun(val)
+ return val
- (state, step_accepted) = tree_multimap(
- partial(jnp.where, converged * (error_norm > 1)), # noqa: E712
- (_update_step_size_and_lu(state, factor), False),
- (state, converged),
- )
+ def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover
+ """
+ for debugging purposes, use this instead of jax.lax.fori_loop
+ """
+ val = init_val
+ for i in range(start, stop):
+ val = body_fun(i, val)
+ return val
- return [state, step_accepted, updated_jacobian, y, d, n_iter]
+ def flax_scan(f, init, xs, length=None): # pragma: no cover
+ """
+ for debugging purposes, use this instead of jax.lax.scan
+ """
+ if xs is None:
+ xs = [None] * length
+ carry = init
+ ys = []
+ for x in xs:
+ carry, y = f(carry, x)
+ ys.append(y)
+ return carry, onp.stack(ys)
+
+ @jax.partial(gnool_jit, static_array_argnums=(1,), static_argnums=(0, 2, 3))
+ def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
+ y0, unravel = ravel_pytree(y0)
+ func = ravel_first_arg(func, unravel)
+ out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
+ return jax.vmap(unravel)(out)
+
+ def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args):
+ ys = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
+ return ys, (ys, ts, args)
+
+ def _bdf_odeint_rev(func, mass, rtol, atol, res, g):
+ ys, ts, args = res
+
+ def aug_dynamics(augmented_state, t, *args):
+ """Original system augmented with vjp_y, vjp_t and vjp_args."""
+ y, y_bar, *_ = augmented_state
+ # `t` here is negative time, so we need to negate again to get back to
+ # normal time. See the `odeint` invocation in `scan_fun` below.
+ y_dot, vjpfun = jax.vjp(func, y, -t, *args)
+
+ # Adjoint equations for semi-explicit dae index 1 system from
+ #
+ # [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity
+ # analysis for differential-algebraic equations: The adjoint DAE system and
+ # its numerical solution.
+ # SIAM journal on scientific computing, 24(3), 1076-1089.
+ #
+ # y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a
+ # 0 = J_da^T y_bar_d + J_aa^T y_bar_d
+
+ y_bar_dot, *rest = vjpfun(y_bar)
+
+ return (-y_dot, y_bar_dot, *rest)
+
+ algebraic_variables = onp.diag(mass) == 0.0
+ differentiable_variables = algebraic_variables == False # noqa: E712
+ mass_is_I = onp.array_equal(mass, onp.eye(mass.shape[0]))
+ is_dae = onp.any(algebraic_variables)
+
+ if not mass_is_I:
+ M_dd = mass[onp.ix_(differentiable_variables, differentiable_variables)]
+ LU_invM_dd = jax.scipy.linalg.lu_factor(M_dd)
+
+ def initialise(g0, y0, t0):
+ # [1] gives init conditions for y_bar_a = g_d - J_ad^T (J_aa^T)^-1 g_a
+ if mass_is_I:
+ y_bar = g0
+ elif is_dae:
+ J = jax.jacfwd(func)(y0, t0, *args)
+
+ # boolean arguments not implemented in jnp.ix_
+ J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)]
+ J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)]
+ LU = jax.scipy.linalg.lu_factor(J_aa)
+ g0_a = g0[algebraic_variables]
+ invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a)
+ y_bar = jax.ops.index_update(
+ g0,
+ differentiable_variables,
+ jax.scipy.linalg.lu_solve(LU_invM_dd, g0_a - J_ad @ invJ_aa),
+ )
+ else:
+ y_bar = jax.scipy.linalg.lu_solve(LU_invM_dd, g0)
+ return y_bar
+
+ y_bar = initialise(g[-1], ys[-1], ts[-1])
+ ts_bar = []
+ t0_bar = 0.0
+
+ def arg_to_identity(arg):
+ return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)
+
+ def arg_dicts_to_values(args):
+ """
+ Note:JAX puts in empty arrays into args for some reason, we remove them here
+ """
+ return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ())
+
+ aug_mass = (mass, mass, onp.array(1.0)) + arg_dicts_to_values(
+ tree_map(arg_to_identity, args)
+ )
- state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop(
- while_cond, while_body, while_state
- )
+ def scan_fun(carry, i):
+ y_bar, t0_bar, args_bar = carry
+ # Compute effect of moving measurement time
+ t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i])
+ t0_bar = t0_bar - t_bar
+ # Run augmented system backwards to previous observation
+ _, y_bar, t0_bar, args_bar = jax_bdf_integrate(
+ aug_dynamics,
+ (ys[i], y_bar, t0_bar, args_bar),
+ jnp.array([-ts[i], -ts[i - 1]]),
+ *args,
+ mass=aug_mass,
+ rtol=rtol,
+ atol=atol,
+ )
+ y_bar, t0_bar, args_bar = tree_map(
+ op.itemgetter(1), (y_bar, t0_bar, args_bar)
+ )
+ # Add gradient from current output
+ y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1])
+ return (y_bar, t0_bar, args_bar), t_bar
- # take the accepted step
- n_steps = state.n_steps + 1
- t = state.t + state.h
+ init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args))
+ (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan(
+ scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1)
+ )
+ ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
+ return (y_bar, ts_bar, *args_bar)
- # a change in order is only done after running at order k for k + 1 steps
- # (see page 83 of [2])
- n_equal_steps = state.n_equal_steps + 1
+ _bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev)
- state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)
+ @cache()
+ def closure_convert(fun, in_tree, in_avals):
+ if config.omnistaging_enabled:
+ wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
+ jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
+ else:
+ in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
+ wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
+ with core.initial_style_staging(): # type: ignore
+ jaxpr, _, consts = pe.trace_to_jaxpr(
+ wrapped_fun, in_pvals, instantiate=True, stage_out=False
+ ) # type: ignore
+ out_tree = out_tree()
- state = tree_multimap(
- partial(jnp.where, n_equal_steps < state.order + 1),
- _prepare_next_step(state, d),
- _prepare_next_step_order_change(state, d, y, n_iter),
- )
+ # We only want to closure convert for constants with respect to which we're
+ # differentiating. As a proxy for that, we hoist consts with float dtype.
+ # TODO(mattjj): revise this approach
+ def is_float(c):
+ return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
- return state
+ (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
+ num_consts = len(hoisted_consts)
+ def converted_fun(y, t, *hconsts_args):
+ hoisted_consts, args = split_list(hconsts_args, [num_consts])
+ consts = merge(closure_consts, hoisted_consts)
+ all_args, _ = tree_flatten((y, t, *args))
+ out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
+ return tree_unflatten(out_tree, out_flat)
-def _bdf_interpolate(state, t_eval):
- """
- interpolate solution at time values t* where t-h < t* < t
+ return converted_fun, hoisted_consts
- definition of the interpolating polynomial can be found on page 7 of [1]
- """
- order = state.order
- t = state.t
- h = state.h
- D = state.D
- j = 0
- time_factor = 1.0
- order_summation = D[0]
- while_state = [j, time_factor, order_summation]
-
- def while_cond(while_state):
- j, _, _ = while_state
- return j < order
-
- def while_body(while_state):
- j, time_factor, order_summation = while_state
- time_factor *= (t_eval - (t - h * j)) / (h * (1 + j))
- order_summation += D[j + 1] * time_factor
- j += 1
- return [j, time_factor, order_summation]
-
- j, time_factor, order_summation = jax.lax.while_loop(
- while_cond, while_body, while_state
- )
- return order_summation
+ def partition_list(choice, lst):
+ out = [], []
+ which = [out[choice(elt)].append(elt) or choice(elt) for elt in lst]
+ def merge(l1, l2):
+ i1, i2 = iter(l1), iter(l2)
+ return [next(i2 if snd else i1) for snd in which]
-def block_diag(lst):
- def block_fun(i, j, Ai, Aj):
- if i == j:
- return Ai
- else:
- return onp.zeros(
- (
- Ai.shape[0] if Ai.ndim > 1 else 1,
- Aj.shape[1] if Aj.ndim > 1 else 1,
- ),
- dtype=Ai.dtype,
- )
+ return out, merge
- blocks = [
- [block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)]
- for i, Ai in enumerate(lst)
- ]
+ def abstractify(x):
+ return core.raise_to_shaped(core.get_aval(x))
- return onp.block(blocks)
+ def ravel_first_arg(f, unravel):
+ return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
-
-# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor
-# edits), has been modified from the JAX library at https://github.com/google/jax.
-# The main difference is the addition of support for semi-explicit dae index 1 problems
-# via the addition of a mass matrix.
-# This is under an Apache license, a short form of which is given here:
-#
-# Copyright 2018 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
-# file except in compliance with the License. You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software distributed under
-# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific language
-# governing permissions and limitations under the License.
+ @lu.transformation
+ def ravel_first_arg_(unravel, y_flat, *args):
+ y = unravel(y_flat)
+ ans = yield (y,) + args, {}
+ ans_flat, _ = ravel_pytree(ans)
+ yield ans_flat
def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
@@ -828,15 +1028,19 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
References
----------
.. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI.
- COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997.
+ COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997.
.. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical
- Solution of Ordinary Differential Equations", ACM Transactions on
- Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975.
+ Solution of Ordinary Differential Equations", ACM Transactions on
+ Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975.
.. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy,
- T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0:
- fundamental algorithms for scientific computing in Python.
- Nature methods, 17(3), 261-272.
+ T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0:
+ fundamental algorithms for scientific computing in Python.
+ Nature methods, 17(3), 261-272.
"""
+ if not pybamm.have_jax():
+ raise ModuleNotFoundError(
+ "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
+ )
def _check_arg(arg):
if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
@@ -854,213 +1058,3 @@ def _check_arg(arg):
else:
mass = block_diag(tree_flatten(mass)[0])
return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args)
-
-
-def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover
- """
- for debugging purposes, use this instead of jax.lax.while_loop
- """
- val = init_val
- while cond_fun(val):
- val = body_fun(val)
- return val
-
-
-def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover
- """
- for debugging purposes, use this instead of jax.lax.fori_loop
- """
- val = init_val
- for i in range(start, stop):
- val = body_fun(i, val)
- return val
-
-
-def flax_scan(f, init, xs, length=None): # pragma: no cover
- """
- for debugging purposes, use this instead of jax.lax.scan
- """
- if xs is None:
- xs = [None] * length
- carry = init
- ys = []
- for x in xs:
- carry, y = f(carry, x)
- ys.append(y)
- return carry, onp.stack(ys)
-
-
-@jax.partial(gnool_jit,
- static_array_argnums=(1,), static_argnums=(0, 2, 3))
-def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
- y0, unravel = ravel_pytree(y0)
- func = ravel_first_arg(func, unravel)
- out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
- return jax.vmap(unravel)(out)
-
-
-def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args):
- ys = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
- return ys, (ys, ts, args)
-
-
-def _bdf_odeint_rev(func, mass, rtol, atol, res, g):
- ys, ts, args = res
-
- def aug_dynamics(augmented_state, t, *args):
- """Original system augmented with vjp_y, vjp_t and vjp_args."""
- y, y_bar, *_ = augmented_state
- # `t` here is negative time, so we need to negate again to get back to
- # normal time. See the `odeint` invocation in `scan_fun` below.
- y_dot, vjpfun = jax.vjp(func, y, -t, *args)
-
- # Adjoint equations for semi-explicit dae index 1 system from
- #
- # [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity
- # analysis for differential-algebraic equations: The adjoint DAE system and its
- # numerical solution. SIAM journal on scientific computing, 24(3), 1076-1089.
- #
- # y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a
- # 0 = J_da^T y_bar_d + J_aa^T y_bar_d
-
- y_bar_dot, *rest = vjpfun(y_bar)
-
- return (-y_dot, y_bar_dot, *rest)
-
- algebraic_variables = onp.diag(mass) == 0.0
- differentiable_variables = algebraic_variables == False # noqa: E712
- mass_is_I = onp.array_equal(mass, onp.eye(mass.shape[0]))
- is_dae = onp.any(algebraic_variables)
-
- if not mass_is_I:
- M_dd = mass[onp.ix_(differentiable_variables, differentiable_variables)]
- LU_invM_dd = jax.scipy.linalg.lu_factor(M_dd)
-
- def initialise(g0, y0, t0):
- # [1] gives init conditions for y_bar_a = g_d - J_ad^T (J_aa^T)^-1 g_a
- if mass_is_I:
- y_bar = g0
- elif is_dae:
- J = jax.jacfwd(func)(y0, t0, *args)
-
- # boolean arguments not implemented in jnp.ix_
- J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)]
- J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)]
- LU = jax.scipy.linalg.lu_factor(J_aa)
- g0_a = g0[algebraic_variables]
- invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a)
- y_bar = jax.ops.index_update(
- g0,
- differentiable_variables,
- jax.scipy.linalg.lu_solve(LU_invM_dd, g0_a - J_ad @ invJ_aa),
- )
- else:
- y_bar = jax.scipy.linalg.lu_solve(LU_invM_dd, g0)
- return y_bar
-
- y_bar = initialise(g[-1], ys[-1], ts[-1])
- ts_bar = []
- t0_bar = 0.0
-
- def arg_to_identity(arg):
- return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)
-
- def arg_dicts_to_values(args):
- """
- Note: JAX puts in empty arrays into args for some reason, we remove them here
- """
- return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ())
-
- aug_mass = (mass, mass, onp.array(1.0)) + arg_dicts_to_values(
- tree_map(arg_to_identity, args)
- )
-
- def scan_fun(carry, i):
- y_bar, t0_bar, args_bar = carry
- # Compute effect of moving measurement time
- t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i])
- t0_bar = t0_bar - t_bar
- # Run augmented system backwards to previous observation
- _, y_bar, t0_bar, args_bar = jax_bdf_integrate(
- aug_dynamics,
- (ys[i], y_bar, t0_bar, args_bar),
- jnp.array([-ts[i], -ts[i - 1]]),
- *args,
- mass=aug_mass,
- rtol=rtol,
- atol=atol,
- )
- y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
- # Add gradient from current output
- y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1])
- return (y_bar, t0_bar, args_bar), t_bar
-
- init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args))
- (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan(
- scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1)
- )
- ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
- return (y_bar, ts_bar, *args_bar)
-
-
-_bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev)
-
-
-@cache()
-def closure_convert(fun, in_tree, in_avals):
- if config.omnistaging_enabled:
- wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
- jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
- else:
- in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
- wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
- with core.initial_style_staging(): # type: ignore
- jaxpr, _, consts = pe.trace_to_jaxpr(
- wrapped_fun, in_pvals, instantiate=True, stage_out=False
- ) # type: ignore
- out_tree = out_tree()
-
- # We only want to closure convert for constants with respect to which we're
- # differentiating. As a proxy for that, we hoist consts with float dtype.
- # TODO(mattjj): revise this approach
- def is_float(c):
- return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
-
- (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
- num_consts = len(hoisted_consts)
-
- def converted_fun(y, t, *hconsts_args):
- hoisted_consts, args = split_list(hconsts_args, [num_consts])
- consts = merge(closure_consts, hoisted_consts)
- all_args, _ = tree_flatten((y, t, *args))
- out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
- return tree_unflatten(out_tree, out_flat)
-
- return converted_fun, hoisted_consts
-
-
-def partition_list(choice, lst):
- out = [], []
- which = [out[choice(elt)].append(elt) or choice(elt) for elt in lst]
-
- def merge(l1, l2):
- i1, i2 = iter(l1), iter(l2)
- return [next(i2 if snd else i1) for snd in which]
-
- return out, merge
-
-
-def abstractify(x):
- return core.raise_to_shaped(core.get_aval(x))
-
-
-def ravel_first_arg(f, unravel):
- return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
-
-
-@lu.transformation
-def ravel_first_arg_(unravel, y_flat, *args):
- y = unravel(y_flat)
- ans = yield (y,) + args, {}
- ans_flat, _ = ravel_pytree(ans)
- yield ans_flat
diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py
index e0124294e9..a0d3008207 100644
--- a/pybamm/solvers/jax_solver.py
+++ b/pybamm/solvers/jax_solver.py
@@ -1,12 +1,14 @@
#
# Solver class using Scipy's adaptive time stepper
#
+import numpy as onp
+
import pybamm
-import jax
-from jax.experimental.ode import odeint
-import jax.numpy as jnp
-import numpy as onp
+if pybamm.have_jax():
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental.ode import odeint
class JaxSolver(pybamm.BaseSolver):
@@ -56,6 +58,11 @@ def __init__(
extrap_tol=0,
extra_options=None,
):
+ if not pybamm.have_jax():
+ raise ModuleNotFoundError(
+ "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
+ )
+
# note: bdf solver itself calculates consistent initial conditions so can set
# root_method to none, allow user to override this behavior
super().__init__(
diff --git a/pybamm/util.py b/pybamm/util.py
index cab5d995b3..a9a15ea1ea 100644
--- a/pybamm/util.py
+++ b/pybamm/util.py
@@ -4,16 +4,21 @@
# The code in this file is adapted from Pints
# (see https://github.com/pints-team/pints)
#
-import importlib
-import numpy as np
+import importlib.util
+import numbers
import os
-import timeit
import pathlib
import pickle
-import pybamm
-import numbers
+import subprocess
+import sys
+import timeit
import warnings
from collections import defaultdict
+from platform import system
+
+import numpy as np
+
+import pybamm
def root_dir():
@@ -336,3 +341,21 @@ def get_parameters_filepath(path):
return path
else:
return os.path.join(pybamm.__path__[0], path)
+
+
+def have_jax():
+ """Check if jax is installed"""
+ return importlib.util.find_spec("jax") is not None
+
+
+def install_jax():
+ """Install jax, jaxlib"""
+ jax_version = "jax==0.2.12"
+ jaxlib_version = "jaxlib==0.1.70"
+
+ if system() == "Windows":
+ raise NotImplementedError("Jax is not available on Windows")
+ else:
+ subprocess.check_call(
+ [sys.executable, "-m", "pip", "install", jax_version, jaxlib_version]
+ )
diff --git a/requirements.txt b/requirements.txt
index 7753d6ab6f..b23c8f2cf9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-numpy >= 1.16
+numpy >= 1.16
scipy >= 1.3
pandas >= 0.24
anytree >= 2.4.3
@@ -6,11 +6,9 @@ autograd >= 1.2
scikit-fem >= 0.2.0
casadi >= 3.5.0
imageio>=2.9.0
-jax==0.2.12
-jaxlib==0.1.70
jupyter # For example notebooks
pybtex
-sympy==1.8
+sympy==1.9
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs
# on systems without an attached display it should never be imported
# outside of plot() methods.
diff --git a/run-tests.py b/run-tests.py
index d2be5ffee0..6b220fd3ea 100755
--- a/run-tests.py
+++ b/run-tests.py
@@ -104,35 +104,20 @@ def run_doc_tests():
sys.exit(ret)
-def run_notebook_and_scripts(skip_slow_books=False, executable="python"):
+def run_notebook_and_scripts(executable="python"):
"""
Runs Jupyter notebook tests. Exits if they fail.
"""
- # Ignore slow books?
- ignore_list = []
- if skip_slow_books and os.path.isfile(".slow-books"):
- with open(".slow-books", "r") as f:
- for line in f.readlines():
- line = line.strip()
- if not line or line[:1] == "#":
- continue
- if not line.startswith("examples/"):
- line = "examples/" + line
- if not line.endswith(".ipynb"):
- line = line + ".ipynb"
- if not os.path.isfile(line):
- raise Exception("Slow notebook note found: " + line)
- ignore_list.append(line)
# Scan and run
print("Testing notebooks and scripts with executable `" + str(executable) + "`")
- if not scan_for_nb_and_scripts("examples", True, executable, ignore_list):
+ if not scan_for_nb_and_scripts("examples", True, executable):
print("\nErrors encountered in notebooks")
sys.exit(1)
print("\nOK")
-def scan_for_nb_and_scripts(root, recursive=True, executable="python", ignore_list=[]):
+def scan_for_nb_and_scripts(root, recursive=True, executable="python"):
"""
Scans for, and tests, all notebooks and scripts in a directory.
"""
@@ -142,9 +127,6 @@ def scan_for_nb_and_scripts(root, recursive=True, executable="python", ignore_li
# Scan path
for filename in os.listdir(root):
path = os.path.join(root, filename)
- if path in ignore_list:
- print("Skipping slow book: " + path)
- continue
# Recurse into subdirectories
if recursive and os.path.isdir(path):
@@ -354,11 +336,6 @@ def export_notebook(ipath, opath):
parser.add_argument(
"--examples",
action="store_true",
- help="Test only the fast Jupyter notebooks and scripts in `examples`.",
- )
- parser.add_argument(
- "--allexamples",
- action="store_true",
help="Test all Jupyter notebooks and scripts in `examples`.",
)
parser.add_argument(
@@ -416,12 +393,9 @@ def export_notebook(ipath, opath):
has_run = True
run_doc_tests()
# Notebook tests
- if args.allexamples:
- has_run = True
- run_notebook_and_scripts(executable=interpreter)
elif args.examples:
has_run = True
- run_notebook_and_scripts(True, interpreter)
+ run_notebook_and_scripts(interpreter)
if args.debook:
has_run = True
export_notebook(*args.debook)
diff --git a/setup.py b/setup.py
index 90ba21485b..55fd14a2bc 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
import logging
import subprocess
from pathlib import Path
-from platform import system, version
+from platform import system
import wheel.bdist_wheel as orig
import site
import shutil
@@ -162,11 +162,6 @@ def compile_KLU():
idaklu_ext = Extension("pybamm.solvers.idaklu", ["pybamm/solvers/c_solvers/idaklu.cpp"])
ext_modules = [idaklu_ext] if compile_KLU() else []
-jax_dependencies = []
-if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
- jax_dependencies = ["jax==0.2.12", "jaxlib==0.1.70"]
-
-
# Load text for description and license
with open("README.md", encoding="utf-8") as f:
readme = f.read()
@@ -198,10 +193,9 @@ def compile_KLU():
"scikit-fem>=0.2.0",
"casadi>=3.5.0",
"imageio>=2.9.0",
- *jax_dependencies,
"jupyter", # For example notebooks
"pybtex",
- "sympy==1.8",
+ "sympy==1.9",
# Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs
# on systems without an attached display, it should never be imported
# outside of plot() methods.
@@ -221,6 +215,7 @@ def compile_KLU():
"pybamm_add_parameter = pybamm.parameters_cli:add_parameter",
"pybamm_rm_parameter = pybamm.parameters_cli:remove_parameter",
"pybamm_install_odes = pybamm.install_odes:main",
+ "pybamm_install_jax = pybamm.util:install_jax",
]
},
)
diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py
index f538965b52..98d596dd00 100644
--- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py
+++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py
@@ -5,7 +5,6 @@
import tests
import numpy as np
import unittest
-from platform import system
class TestMPM(unittest.TestCase):
@@ -29,7 +28,7 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)
- if system() != "Windows":
+ if pybamm.have_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)
diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py
index a8bb0bd82b..22d4cfa6b4 100644
--- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py
+++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py
@@ -5,7 +5,6 @@
import tests
import numpy as np
import unittest
-from platform import system, version
class TestSPM(unittest.TestCase):
@@ -71,9 +70,7 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)
- if not (
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
- ):
+ if pybamm.have_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)
diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py
index d1c88be1c1..32093dd957 100644
--- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py
+++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py
@@ -3,10 +3,8 @@
#
import pybamm
import tests
-
import numpy as np
import unittest
-from platform import system, version
class TestSPMe(unittest.TestCase):
@@ -79,9 +77,7 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)
- if not (
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
- ):
+ if pybamm.have_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)
diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py
index 17c2ad2d5b..34401aa42a 100644
--- a/tests/unit/test_citations.py
+++ b/tests/unit/test_citations.py
@@ -3,7 +3,6 @@
#
import pybamm
import unittest
-from platform import system, version
class TestCitations(unittest.TestCase):
@@ -255,10 +254,7 @@ def test_solver_citations(self):
pybamm.IDAKLUSolver()
self.assertIn("Hindmarsh2005", citations._papers_to_cite)
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
diff --git a/tests/unit/test_experiments/test_experiment.py b/tests/unit/test_experiments/test_experiment.py
index db6f6196e2..b91645460e 100644
--- a/tests/unit/test_experiments/test_experiment.py
+++ b/tests/unit/test_experiments/test_experiment.py
@@ -46,19 +46,44 @@ def test_read_strings(self):
self.assertEqual(
experiment.operating_conditions[:-3],
[
- {"electric": (1, "C"), "time": 1800.0, "period": 20.0},
- {"electric": (0.05, "C"), "time": 1800.0, "period": 20.0},
- {"electric": (-0.5, "C"), "time": 2700.0, "period": 20.0},
- {"electric": (1, "A"), "time": 1800.0, "period": 20.0},
- {"electric": (-0.2, "A"), "time": 2700.0, "period": 60.0},
- {"electric": (1, "W"), "time": 1800.0, "period": 20.0},
- {"electric": (-0.2, "W"), "time": 2700.0, "period": 20.0},
- {"electric": (0, "A"), "time": 600.0, "period": 300.0},
- {"electric": (1, "V"), "time": 20.0, "period": 20.0},
- {"electric": (-1, "C"), "time": None, "period": 20.0},
- {"electric": (4.1, "V"), "time": None, "period": 20.0},
- {"electric": (3, "V"), "time": None, "period": 20.0},
- {"electric": (1 / 3, "C"), "time": 7200.0, "period": 20.0},
+ {"electric": (1, "C"), "time": 1800.0, "period": 20.0, "dc_data": None},
+ {
+ "electric": (0.05, "C"),
+ "time": 1800.0,
+ "period": 20.0,
+ "dc_data": None,
+ },
+ {
+ "electric": (-0.5, "C"),
+ "time": 2700.0,
+ "period": 20.0,
+ "dc_data": None,
+ },
+ {"electric": (1, "A"), "time": 1800.0, "period": 20.0, "dc_data": None},
+ {
+ "electric": (-0.2, "A"),
+ "time": 2700.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {"electric": (1, "W"), "time": 1800.0, "period": 20.0, "dc_data": None},
+ {
+ "electric": (-0.2, "W"),
+ "time": 2700.0,
+ "period": 20.0,
+ "dc_data": None,
+ },
+ {"electric": (0, "A"), "time": 600.0, "period": 300.0, "dc_data": None},
+ {"electric": (1, "V"), "time": 20.0, "period": 20.0, "dc_data": None},
+ {"electric": (-1, "C"), "time": None, "period": 20.0, "dc_data": None},
+ {"electric": (4.1, "V"), "time": None, "period": 20.0, "dc_data": None},
+ {"electric": (3, "V"), "time": None, "period": 20.0, "dc_data": None},
+ {
+ "electric": (1 / 3, "C"),
+ "time": 7200.0,
+ "period": 20.0,
+ "dc_data": None,
+ },
],
)
# Calculation for operating conditions of drive cycle
@@ -72,19 +97,19 @@ def test_read_strings(self):
period_2 = np.min(np.diff(drive_cycle_2[:, 0]))
# Check drive cycle operating conditions
np.testing.assert_array_equal(
- experiment.operating_conditions[-3]["electric"][0], drive_cycle
+ experiment.operating_conditions[-3]["dc_data"], drive_cycle
)
self.assertEqual(experiment.operating_conditions[-3]["electric"][1], "A")
self.assertEqual(experiment.operating_conditions[-3]["time"], time_0)
self.assertEqual(experiment.operating_conditions[-3]["period"], period_0)
np.testing.assert_array_equal(
- experiment.operating_conditions[-2]["electric"][0], drive_cycle_1
+ experiment.operating_conditions[-2]["dc_data"], drive_cycle_1
)
self.assertEqual(experiment.operating_conditions[-2]["electric"][1], "V")
self.assertEqual(experiment.operating_conditions[-2]["time"], time_1)
self.assertEqual(experiment.operating_conditions[-2]["period"], period_1)
np.testing.assert_array_equal(
- experiment.operating_conditions[-1]["electric"][0], drive_cycle_2
+ experiment.operating_conditions[-1]["dc_data"], drive_cycle_2
)
self.assertEqual(experiment.operating_conditions[-1]["electric"][1], "W")
self.assertEqual(experiment.operating_conditions[-1]["time"], time_2)
@@ -128,9 +153,24 @@ def test_read_strings_cccv_combined(self):
self.assertEqual(
experiment.operating_conditions,
[
- {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0},
- {"electric": (-0.5, "C", 1, "V"), "time": None, "period": 60.0},
- {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0},
+ {
+ "electric": (0.05, "C"),
+ "time": 1800.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {
+ "electric": (-0.5, "C", 1, "V"),
+ "time": None,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {
+ "electric": (0.05, "C"),
+ "time": 1800.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
],
)
self.assertEqual(experiment.events, [None, (0.02, "C"), None])
@@ -146,8 +186,13 @@ def test_read_strings_cccv_combined(self):
self.assertEqual(
experiment.operating_conditions,
[
- {"electric": (-0.5, "C"), "time": None, "period": 60.0},
- {"electric": (1, "V"), "time": None, "period": 60.0},
+ {
+ "electric": (-0.5, "C"),
+ "time": None,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {"electric": (1, "V"), "time": None, "period": 60.0, "dc_data": None},
],
)
experiment = pybamm.Experiment(
@@ -160,8 +205,13 @@ def test_read_strings_cccv_combined(self):
self.assertEqual(
experiment.operating_conditions,
[
- {"electric": (-0.5, "C"), "time": 120.0, "period": 60.0},
- {"electric": (1, "V"), "time": None, "period": 60.0},
+ {
+ "electric": (-0.5, "C"),
+ "time": 120.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {"electric": (1, "V"), "time": None, "period": 60.0, "dc_data": None},
],
)
@@ -173,11 +223,26 @@ def test_read_strings_repeat(self):
self.assertEqual(
experiment.operating_conditions,
[
- {"electric": (0.01, "A"), "time": 1800.0, "period": 60},
- {"electric": (-0.5, "C"), "time": 2700.0, "period": 60},
- {"electric": (1, "V"), "time": 20.0, "period": 60},
- {"electric": (-0.5, "C"), "time": 2700.0, "period": 60},
- {"electric": (1, "V"), "time": 20.0, "period": 60},
+ {
+ "electric": (0.01, "A"),
+ "time": 1800.0,
+ "period": 60,
+ "dc_data": None,
+ },
+ {
+ "electric": (-0.5, "C"),
+ "time": 2700.0,
+ "period": 60,
+ "dc_data": None,
+ },
+ {"electric": (1, "V"), "time": 20.0, "period": 60, "dc_data": None},
+ {
+ "electric": (-0.5, "C"),
+ "time": 2700.0,
+ "period": 60,
+ "dc_data": None,
+ },
+ {"electric": (1, "V"), "time": 20.0, "period": 60, "dc_data": None},
],
)
self.assertEqual(experiment.period, 60)
@@ -193,10 +258,30 @@ def test_cycle_unpacking(self):
self.assertEqual(
experiment.operating_conditions,
[
- {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0},
- {"electric": (-0.2, "C"), "time": 2700.0, "period": 60.0},
- {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0},
- {"electric": (-0.2, "C"), "time": 2700.0, "period": 60.0},
+ {
+ "electric": (0.05, "C"),
+ "time": 1800.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {
+ "electric": (-0.2, "C"),
+ "time": 2700.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {
+ "electric": (0.05, "C"),
+ "time": 1800.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
+ {
+ "electric": (-0.2, "C"),
+ "time": 2700.0,
+ "period": 60.0,
+ "dc_data": None,
+ },
],
)
self.assertEqual(experiment.cycle_lengths, [2, 1, 1])
diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py
index 31eb795df9..5b47b672b8 100644
--- a/tests/unit/test_experiments/test_simulation_with_experiment.py
+++ b/tests/unit/test_experiments/test_simulation_with_experiment.py
@@ -158,6 +158,24 @@ def test_run_experiment_cccv_ode(self):
)
self.assertEqual(solutions[1].termination, "final time")
+ def test_run_experiment_drive_cycle(self):
+ drive_cycle = np.array([np.arange(10), np.arange(10)]).T
+ experiment = pybamm.Experiment(
+ [
+ (
+ "Run drive_cycle (A)",
+ "Run drive_cycle (V)",
+ "Run drive_cycle (W)",
+ )
+ ],
+ drive_cycles={"drive_cycle": drive_cycle}
+ )
+ model = pybamm.lithium_ion.DFN()
+ sim = pybamm.Simulation(model, experiment=experiment)
+ self.assertIn(('drive_cycle', 'A'), sim.op_conds_to_model_and_param)
+ self.assertIn(('drive_cycle', 'V'), sim.op_conds_to_model_and_param)
+ self.assertIn(('drive_cycle', 'W'), sim.op_conds_to_model_and_param)
+
def test_run_experiment_old_setup_type(self):
experiment = pybamm.Experiment(
[
diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py
index 64547d1443..ee7ef06bfb 100644
--- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py
+++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py
@@ -8,7 +8,9 @@
import numpy as np
import scipy.sparse
from collections import OrderedDict
-from platform import system, version
+
+if pybamm.have_jax():
+ import jax
def test_function(arg):
@@ -457,10 +459,7 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
@@ -473,10 +472,7 @@ def test_find_symbols_jax(self):
list(constant_symbols.values())[0].toarray(), A.entries.toarray()
)
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
@@ -638,10 +634,7 @@ def test_evaluator_jax(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
@@ -656,10 +649,7 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac.evaluate(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a ** 2
@@ -667,10 +657,7 @@ def test_evaluator_jax_debug(self):
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a ** 2
@@ -678,13 +665,8 @@ def test_evaluator_jax_inputs(self):
result = evaluator.evaluate(inputs={"a": 2})
self.assertEqual(result, 4)
- @unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
- )
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_jax_coo_matrix(self):
- import jax
-
A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))
Adense = jax.numpy.array([[1.0, 0], [0, 2.0]])
v = jax.numpy.array([[2.0], [1.0]])
diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py
index 7967c96949..bd29be3f81 100644
--- a/tests/unit/test_solvers/test_base_solver.py
+++ b/tests/unit/test_solvers/test_base_solver.py
@@ -296,6 +296,21 @@ def test_timescale_input_fail(self):
with self.assertRaisesRegex(pybamm.SolverError, "The model timescale"):
sol = solver.step(old_solution=sol, model=model, dt=1.0, inputs={"a": 20})
+ def test_inputs_step(self):
+ # Make sure interpolant inputs are dropped
+ model = pybamm.BaseModel()
+ v = pybamm.Variable("v")
+ model.rhs = {v: -1}
+ model.initial_conditions = {v: 1}
+ x = np.array([0, 1])
+ interp = pybamm.Interpolant(x, x, pybamm.t)
+ solver = pybamm.CasadiSolver()
+ for input_key in ["Current input [A]", "Voltage input [V]", "Power input [W]"]:
+ sol = solver.step(
+ old_solution=None, model=model, dt=1.0, inputs={input_key: interp}
+ )
+ self.assertFalse(input_key in sol.all_inputs[0])
+
def test_extrapolation_warnings(self):
# Make sure the extrapolation warnings work
model = pybamm.BaseModel()
@@ -327,29 +342,26 @@ def test_extrapolation_warnings(self):
@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
def test_sensitivities(self):
-
def exact_diff_a(y, a, b):
- return np.array([
- [y[0]**2 + 2 * a],
- [y[0]]
- ])
+ return np.array([[y[0] ** 2 + 2 * a], [y[0]]])
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def exact_diff_b(y, a, b):
return np.array([[y[0]], [0]])
- for convert_to_format in ['', 'python', 'casadi', 'jax']:
+ for convert_to_format in ["", "python", "casadi", "jax"]:
model = pybamm.BaseModel()
v = pybamm.Variable("v")
u = pybamm.Variable("u")
a = pybamm.InputParameter("a")
b = pybamm.InputParameter("b")
- model.rhs = {v: a * v**2 + b * v + a**2}
+ model.rhs = {v: a * v ** 2 + b * v + a ** 2}
model.algebraic = {u: a * v - u}
model.initial_conditions = {v: 1, u: a * 1}
model.convert_to_format = convert_to_format
- solver = pybamm.IDAKLUSolver(root_method='lm')
- model.calculate_sensitivities = ['a', 'b']
- solver.set_up(model, inputs={'a': 0, 'b': 0})
+ solver = pybamm.IDAKLUSolver(root_method="lm")
+ model.calculate_sensitivities = ["a", "b"]
+ solver.set_up(model, inputs={"a": 0, "b": 0})
all_inputs = []
for v_value in [0.1, -0.2, 1.5, 8.4]:
for u_value in [0.13, -0.23, 1.3, 13.4]:
@@ -357,24 +369,20 @@ def exact_diff_b(y, a, b):
for b_value in [0.82, 1.9]:
y = np.array([v_value, u_value])
t = 0
- inputs = {'a': a_value, 'b': b_value}
+ inputs = {"a": a_value, "b": b_value}
all_inputs.append((t, y, inputs))
for t, y, inputs in all_inputs:
- if model.convert_to_format == 'casadi':
+ if model.convert_to_format == "casadi":
use_inputs = casadi.vertcat(*[x for x in inputs.values()])
else:
use_inputs = inputs
- sens = model.sensitivities_eval(
- t, y, use_inputs
- )
+ sens = model.sensitivities_eval(t, y, use_inputs)
np.testing.assert_allclose(
- sens['a'],
- exact_diff_a(y, inputs['a'], inputs['b'])
+ sens["a"], exact_diff_a(y, inputs["a"], inputs["b"])
)
np.testing.assert_allclose(
- sens['b'],
- exact_diff_b(y, inputs['a'], inputs['b'])
+ sens["b"], exact_diff_b(y, inputs["a"], inputs["b"])
)
diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py
index ca66e0c9c5..d17d07c63d 100644
--- a/tests/unit/test_solvers/test_idaklu_solver.py
+++ b/tests/unit/test_solvers/test_idaklu_solver.py
@@ -46,6 +46,7 @@ def test_ida_roberts_klu(self):
true_solution = 0.1 * solution.t
np.testing.assert_array_almost_equal(solution.y[0, :], true_solution)
+ @unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
def test_ida_roberts_klu_sensitivities(self):
# this test implements a python version of the ida Roberts
# example provided in sundials
diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py
index 772bc937d0..b6e7ab92f1 100644
--- a/tests/unit/test_solvers/test_jax_bdf_solver.py
+++ b/tests/unit/test_solvers/test_jax_bdf_solver.py
@@ -4,16 +4,12 @@
import sys
import time
import numpy as np
-from platform import system, version
-if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
+if pybamm.have_jax():
import jax
-@unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
-)
+@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
class TestJaxBDFSolver(unittest.TestCase):
def test_solver(self):
# Create model
diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py
index 74dccdaf99..e9956e4295 100644
--- a/tests/unit/test_solvers/test_jax_solver.py
+++ b/tests/unit/test_solvers/test_jax_solver.py
@@ -4,16 +4,12 @@
import sys
import time
import numpy as np
-from platform import system, version
-if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
+if pybamm.have_jax():
import jax
-@unittest.skipIf(
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
- "JAX not supported on windows or Mac M1",
-)
+@unittest.skipIf(not pybamm.have_jax(), "jax is not installed")
class TestJaxSolver(unittest.TestCase):
def test_model_solver(self):
# Create model
diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py
index b525d95eac..ceb4fcbab0 100644
--- a/tests/unit/test_solvers/test_scipy_solver.py
+++ b/tests/unit/test_solvers/test_scipy_solver.py
@@ -6,15 +6,12 @@
from tests import get_mesh_for_testing, get_discretisation_for_testing
import warnings
import sys
-from platform import system, version
class TestScipySolver(unittest.TestCase):
def test_model_solver_python_and_jax(self):
- if not (
- system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
- ):
+ if pybamm.have_jax():
formats = ["python", "jax"]
else:
formats = ["python"]
diff --git a/tox.ini b/tox.ini
index 1501988e90..45d0084e4a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,7 +3,7 @@ envlist = {windows}-{tests,quick,dev},tests,quick,dev
[testenv]
skipsdist = true
-skip_install = flake8: true
+skip_install = flake8: true
usedevelop = true
passenv = !windows-!mac: SUNDIALS_INST
whitelist_externals = !windows-!mac: sh
@@ -17,13 +17,15 @@ deps =
dev,doctests: sphinx>=1.5
dev,doctests: guzzle-sphinx-theme
!windows-!mac: scikits.odes
-
+
commands =
- tests: python run-tests.py --unit --folder all
- quick: python run-tests.py --unit
- examples: python run-tests.py --examples
- dev-!windows-!mac: sh -c "echo export LD_LIBRARY_PATH={env:LD_LIBRARY_PATH} >> {envbindir}/activate"
- doctests: python run-tests.py --doctest
+ tests-!windows-!mac: sh -c "pybamm_install_jax" # install jax, jaxlib for ubuntu
+ tests: python run-tests.py --unit --folder all
+ quick: python run-tests.py --unit
+ integration: python run-tests.py --unit --folder integration
+ examples: python run-tests.py --examples
+ dev-!windows-!mac: sh -c "echo export LD_LIBRARY_PATH={env:LD_LIBRARY_PATH} >> {envbindir}/activate"
+ doctests: python run-tests.py --doctest
[testenv:pybamm-requires]
platform = [linux,darwin]
@@ -35,7 +37,7 @@ deps =
cmake
commands =
python {toxinidir}/scripts/install_KLU_Sundials.py
- - git clone https://github.com/pybind/pybind11.git {toxinidir}/pybind11
+ - git clone https://github.com/pybind/pybind11.git {toxinidir}/pybind11
[testenv:flake8]
skip_install = true
@@ -43,10 +45,11 @@ deps = flake8>=3
commands = python -m flake8
[testenv:coverage]
-deps =
+deps =
coverage
scikits.odes
-commands =
+commands =
+ !windows-!mac: sh -c "pybamm_install_jax"
coverage run run-tests.py --nosub
# Some tests make use of multiple processes through
# multiprocessing. Coverage data is then generated for each