diff --git a/.github/release_checklist.md b/.github/release_checklist.md index b62824793f..11daa8ff70 100644 --- a/.github/release_checklist.md +++ b/.github/release_checklist.md @@ -4,4 +4,4 @@ - `CITATION.cff` - `vcpkg.json` - Update CHANGELOG.md with a summary of the release -- Update (and pin) jax and jaxlib to latest version and fix any bugs that arise +- Update jax and jaxlib to latest version in `pybamm.util.install_jax` and fix any bugs that arise diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 74f712a1b1..c663fdba2f 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -76,7 +76,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Download wheels - uses: actions/download-artifact@v1 + uses: actions/download-artifact@v2 with: name: wheels diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index cc137d93ad..d6d4b1509f 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -2,7 +2,7 @@ name: PyBaMM on: push: - + pull_request: # everyday at 3 am UTC @@ -10,13 +10,13 @@ on: - cron: '0 3 * * *' jobs: - + style: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Setup python - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.7 @@ -37,19 +37,19 @@ jobs: steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - + - name: Install Linux system dependencies if: matrix.os == 'ubuntu-latest' run: | sudo apt-get update sudo apt install gfortran gcc libopenblas-dev graphviz sudo apt install texlive-full - + # Added fixes to homebrew installs: - # rm -f /usr/local/bin/2to3 + # rm -f /usr/local/bin/2to3 # (see https://github.com/actions/virtual-environments/issues/2322) - name: Install MacOS system dependencies if: matrix.os == 'macos-latest' @@ -62,7 +62,7 @@ jobs: - name: Install Windows system dependencies if: matrix.os == 'windows-latest' run: choco install graphviz --version=2.38.0.20190211 - + - name: Install standard python dependencies run: | python -m pip install --upgrade pip wheel setuptools @@ -72,9 +72,17 @@ jobs: if: matrix.os == 'ubuntu-latest' run: tox -e pybamm-requires - - name: Run unit tests for GNU/Linux + - name: Run unit tests for GNU/Linux with Python 3.7 and 3.8 + if: matrix.os == 'ubuntu-latest' && matrix.python-version != 3.9 + run: python -m tox -e quick + + - name: Run unit tests for GNU/Linux with Python 3.9 and generate coverage report + if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9 + run: tox -e coverage + + - name: Run integration tests for GNU/Linux if: matrix.os == 'ubuntu-latest' - run: python -m tox -e tests + run: python -m tox -e integration - name: Run unit tests for Windows and MacOS if: matrix.os != 'ubuntu-latest' @@ -83,16 +91,11 @@ jobs: - name: Install docs dependencies and run doctests if: matrix.os == 'ubuntu-latest' run: tox -e doctests - + - name: Install dev dependencies and run example tests if: matrix.os == 'ubuntu-latest' run: tox -e examples - - - name: Install and run coverage - if: success() && (matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9) - run: tox -e coverage - name: Upload coverage report if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9 - uses: codecov/codecov-action@v1 - + uses: codecov/codecov-action@v2.1.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index d7194793bc..3d75b89be7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- `Experiment`s with drive cycles can be solved ([#1793](https://github.com/pybamm-team/PyBaMM/pull/1793)) - Added surface area to volume ratio as a factor to the SEI equations ([#1790](https://github.com/pybamm-team/PyBaMM/pull/1790)) - Half-cell SPM and SPMe have been implemented ([#1731](https://github.com/pybamm-team/PyBaMM/pull/1731)) @@ -14,6 +15,7 @@ - Raise error when trying to convert an `Interpolant` with the "pchip" interpolator to CasADI ([#1791](https://github.com/pybamm-team/PyBaMM/pull/1791)) - Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789)) +- Made jax and the PyBaMM JaxSolver optional ([#1767](https://github.com/pybamm-team/PyBaMM/pull/1767)) # [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e1ecfbd91c..78928d7d60 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -280,7 +280,7 @@ Major PyBaMM features are showcased in [Jupyter notebooks](https://jupyter.org/) All example notebooks should be listed in [examples/README.md](https://github.com/pybamm-team/PyBaMM/blob/develop/examples/notebooks/README.md). Please follow the (naming and writing) style of existing notebooks where possible. -Where possible, notebooks are tested daily. A list of slow notebooks (which time-out and fail tests) is maintained in `.slow-books`, these notebooks will be excluded from daily testing. +All the notebooks are tested daily. ## Citations diff --git a/README.md b/README.md index d2f77757c6..83a18d1c42 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,9 @@ conda install -c conda-forge pybamm ``` ### Optional solvers -On GNU/Linux and MacOS, an optional [scikits.odes](https://scikits-odes.readthedocs.io/en/latest/)-based solver is available, see [the documentation](https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#scikits-odes-label). +Following GNU/Linux and macOS solvers are optionally available: +- [scikits.odes](https://scikits-odes.readthedocs.io/en/latest/)-based solver, see [the documentation](https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-scikits-odes-solver). +- [jax](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)-based solver, see [the documentation](https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver). ## 📖 Citing PyBaMM diff --git a/docs/index.rst b/docs/index.rst index 0545c825da..a96bd47510 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,7 +46,10 @@ PyBaMM is available as a conda package through the conda-forge channel. Optional solvers ----------------- -On GNU/Linux and MacOS, an optional `scikits.odes `_ -based solver is available, see :ref:`scikits.odes-label`. +Following GNU/Linux and macOS solvers are optionally available: + +* `scikits.odes `_ -based solver, see `Optional - scikits.odes solver `_. +* `jax `_ -based solver, see `Optional - JaxSolver `_. Installation ============ diff --git a/docs/install/GNU-linux.rst b/docs/install/GNU-linux.rst index b008955c03..cbb06cab42 100644 --- a/docs/install/GNU-linux.rst +++ b/docs/install/GNU-linux.rst @@ -45,8 +45,8 @@ User install ------------ We recommend to install PyBaMM within a virtual environment, in order -not to alter any distribution python files. -First, make sure you are using python 3.7, 3.8, or 3.9. +not to alter any distribution python files. +First, make sure you are using python 3.7, 3.8, or 3.9. To create a virtual environment ``env`` within your current directory type: .. code:: bash @@ -116,9 +116,24 @@ macOS .. code:: bash pip install scikits.odes - + Assuming that the SUNDIALS were installed as described :ref:`above`. +Optional - JaxSolver +-------------------- + +Users can install ``jax`` and ``jaxlib`` to use the Jax solver. +Currently, only GNU/Linux and macOS are supported. + +GNU/Linux and macOS +~~~~~~~~~~~~~~~~~~~ + +.. code:: bash + + pybamm_install_jax + +The ``pybamm_install_jax`` command is installed with PyBaMM. It automatically downloads and installs jax and jaxlib on your system. + Developer install ----------------- diff --git a/docs/requirements.txt b/docs/requirements.txt index b2754b438d..2dab106555 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,8 +7,6 @@ autograd >= 1.2 scikit-fem >= 0.2.0 casadi >= 3.5.0 imageio>=2.9.0 -jax==0.2.12 -jaxlib==0.1.70 jupyter # For example notebooks pybtex # Note: Matplotlib is loaded for debug plots but to ensure pybamm runs @@ -19,4 +17,4 @@ matplotlib >= 2.0 # guzzle-sphinx-theme sphinx>4.0 -sympy==1.8 +sympy==1.9 diff --git a/examples/scripts/experiment_drive_cycle.py b/examples/scripts/experiment_drive_cycle.py new file mode 100644 index 0000000000..2ead334dad --- /dev/null +++ b/examples/scripts/experiment_drive_cycle.py @@ -0,0 +1,52 @@ +# +# Constant-current constant-voltage charge with US06 Drive Cycle using Experiment Class. +# +import pybamm +import pandas as pd +import os + +os.chdir(pybamm.__path__[0] + "/..") + +pybamm.set_logging_level("INFO") + +# import drive cycle from file +drive_cycle_current = pd.read_csv( + "pybamm/input/drive_cycles/US06.csv", comment="#", header=None +).to_numpy() + + +# Map Drive Cycle +def map_drive_cycle(x, min_op_value, max_op_value): + min_ip_value = x[:, 1].min() + max_ip_value = x[:, 1].max() + x[:, 1] = (x[:, 1] - min_ip_value) / (max_ip_value - min_ip_value) * ( + max_op_value - min_op_value + ) + min_op_value + return x + + +# Map current drive cycle to voltage and power +drive_cycle_power = map_drive_cycle(drive_cycle_current, 1.5, 3.5) + +experiment = pybamm.Experiment( + [ + "Charge at 1 A until 4.0 V", + "Hold at 4.0 V until 50 mA", + "Rest for 30 minutes", + "Run US06_A (A)", + "Rest for 30 minutes", + "Run US06_W (W)", + "Rest for 30 minutes", + ], + drive_cycles={ + "US06_A": drive_cycle_current, + "US06_W": drive_cycle_power, + }, +) + +model = pybamm.lithium_ion.DFN() +sim = pybamm.Simulation(model, experiment=experiment, solver=pybamm.CasadiSolver()) +sim.solve() + +# Show all plots +sim.plot() diff --git a/pybamm/CITATIONS.txt b/pybamm/CITATIONS.txt index f2e797b7d5..d9177d8bba 100644 --- a/pybamm/CITATIONS.txt +++ b/pybamm/CITATIONS.txt @@ -192,17 +192,17 @@ } @article{Kirk2021, - author = {Toby L. Kirk and Colin P. Please and S. Jon Chapman}, - title = {Physical Modelling of the Slow Voltage Relaxation Phenomenon in Lithium-Ion Batteries}, - journal = {Journal of The Electrochemical Society}, - year = 2021, - month = {jun}, - publisher = {The Electrochemical Society}, - volume = {168}, - number = {6}, - pages = {060554}, - doi = {10.1149/1945-7111/ac0bf7}, - url = {https://doi.org/10.1149/1945-7111/ac0bf7}, + author = {Toby L. Kirk and Colin P. Please and S. Jon Chapman}, + title = {Physical Modelling of the Slow Voltage Relaxation Phenomenon in Lithium-Ion Batteries}, + journal = {Journal of The Electrochemical Society}, + year = 2021, + month = {jun}, + publisher = {The Electrochemical Society}, + volume = {168}, + number = {6}, + pages = {060554}, + doi = {10.1149/1945-7111/ac0bf7}, + url = {https://doi.org/10.1149/1945-7111/ac0bf7}, } @article{Lain2019, @@ -283,17 +283,17 @@ } @article{OKane2020, - doi = {10.1149/1945-7111/ab90ac}, - url = {https://doi.org/10.1149/1945-7111/ab90ac}, - year = {2020}, - month = {may}, - publisher = {The Electrochemical Society}, - volume = {167}, - number = {9}, - pages = {090540}, - author = {Simon E. J. O'Kane and Ian D. Campbell and Mohamed W. J. Marzook and Gregory J. Offer and Monica Marinescu}, - title = {Physical Origin of the Differential Voltage Minimum Associated with Lithium Plating in Li-Ion Batteries}, - journal = {Journal of The Electrochemical Society} + doi = {10.1149/1945-7111/ab90ac}, + url = {https://doi.org/10.1149/1945-7111/ab90ac}, + year = {2020}, + month = {may}, + publisher = {The Electrochemical Society}, + volume = {167}, + number = {9}, + pages = {090540}, + author = {Simon E. J. O'Kane and Ian D. Campbell and Mohamed W. J. Marzook and Gregory J. Offer and Monica Marinescu}, + title = {Physical Origin of the Differential Voltage Minimum Associated with Lithium Plating in Li-Ion Batteries}, + journal = {Journal of The Electrochemical Society} } @article{ORegan2021, diff --git a/pybamm/__init__.py b/pybamm/__init__.py index aab8a54c4a..b5dfd69d98 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -7,7 +7,7 @@ # import sys import os -import platform + # # Version info @@ -66,7 +66,7 @@ def version(formatted=False): # from .util import Timer, TimerTime, FuzzyDict from .util import root_dir, load_function, rmse, get_infinite_nested_dict, load -from .util import get_parameters_filepath +from .util import get_parameters_filepath, have_jax, install_jax from .logger import logger, set_logging_level from .settings import settings from .citations import Citations, citations, print_citations @@ -102,12 +102,8 @@ def version(formatted=False): EvaluatorPython, ) -if not ( - platform.system() == "Windows" - or (platform.system() == "Darwin" and "ARM64" in platform.version()) -): - from .expression_tree.operations.evaluate_python import EvaluatorJax - from .expression_tree.operations.evaluate_python import JaxCooMatrix +from .expression_tree.operations.evaluate_python import EvaluatorJax +from .expression_tree.operations.evaluate_python import JaxCooMatrix from .expression_tree.operations.jacobian import Jacobian from .expression_tree.operations.convert_to_casadi import CasadiConverter @@ -226,13 +222,8 @@ def version(formatted=False): from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes from .solvers.scipy_solver import ScipySolver -# Jax not supported under windows -if not ( - platform.system() == "Windows" - or (platform.system() == "Darwin" and "ARM64" in platform.version()) -): - from .solvers.jax_solver import JaxSolver - from .solvers.jax_bdf_solver import jax_bdf_integrate +from .solvers.jax_solver import JaxSolver +from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu diff --git a/pybamm/experiments/experiment.py b/pybamm/experiments/experiment.py index ab2f3d709d..c175a94815 100644 --- a/pybamm/experiments/experiment.py +++ b/pybamm/experiments/experiment.py @@ -195,6 +195,7 @@ def read_string(self, cond, drive_cycles): "electric": op_CC["electric"] + op_CV["electric"], "time": op_CV["time"], "period": op_CV["period"], + "dc_data": None, }, event_CV # Read period if " period)" in cond: @@ -223,26 +224,29 @@ def read_string(self, cond, drive_cycles): drive_cycles[cond_list[1]], end_time ) # Drive cycle as numpy array + dc_name = cond_list[1] + "_ext_{}".format(end_time) dc_data = ext_drive_cycle # Find the type of drive cycle ("A", "V", or "W") typ = cond_list[2][1] - electric = (dc_data, typ) + electric = (dc_name, typ) time = ext_drive_cycle[:, 0][-1] period = np.min(np.diff(ext_drive_cycle[:, 0])) events = None else: # e.g. Run US06 # Drive cycle as numpy array + dc_name = cond_list[1] dc_data = drive_cycles[cond_list[1]] # Find the type of drive cycle ("A", "V", or "W") typ = cond_list[2][1] - electric = (dc_data, typ) + electric = (dc_name, typ) # Set time and period to 1 second for first step and # then calculate the difference in consecutive time steps time = drive_cycles[cond_list[1]][:, 0][-1] period = np.min(np.diff(drive_cycles[cond_list[1]][:, 0])) events = None else: + dc_data = None if "for" in cond and "or until" in cond: # e.g. for 3 hours or until 4.2 V cond_list = cond.split() @@ -273,7 +277,12 @@ def read_string(self, cond, drive_cycles): ) ) - return {"electric": electric, "time": time, "period": period}, events + return { + "electric": electric, + "time": time, + "period": period, + "dc_data": dc_data, + }, events def extend_drive_cycle(self, drive_cycle, end_time): "Extends the drive cycle to enable for event" diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index ff8bf92853..15c919858f 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -1,113 +1,110 @@ # # Write a symbol to python # -import pybamm +import numbers +from collections import OrderedDict import numpy as np import scipy.sparse -from collections import OrderedDict -import numbers -from platform import system, version +import pybamm -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): +if pybamm.have_jax(): import jax - from jax.config import config config.update("jax_enable_x64", True) - class JaxCooMatrix: - """ - A sparse matrix in COO format, with internal arrays using jax device arrays - This matrix only has two operations supported, a multiply with a scalar, and a - dot product with a dense vector. It can also be converted to a dense 2D jax - device array +class JaxCooMatrix: + """ + A sparse matrix in COO format, with internal arrays using jax device arrays - Parameters - ---------- + This matrix only has two operations supported, a multiply with a scalar, and a + dot product with a dense vector. It can also be converted to a dense 2D jax + device array - row: arraylike - 1D array holding row indices of non-zero entries - col: arraylike - 1D array holding col indices of non-zero entries - data: arraylike - 1D array holding non-zero entries - shape: 2-element tuple (x, y) - where x is the number of rows, and y the number of columns of the matrix - """ + Parameters + ---------- - def __init__(self, row, col, data, shape): - self.row = jax.numpy.array(row) - self.col = jax.numpy.array(col) - self.data = jax.numpy.array(data) - self.shape = shape - self.nnz = len(self.data) - - def toarray(self): - """convert sparse matrix to a dense 2D array""" - result = jax.numpy.zeros(self.shape, dtype=self.data.dtype) - return result.at[self.row, self.col].add(self.data) - - def dot_product(self, b): - """ - dot product of matrix with a dense column vector b - - Parameters - ---------- - b: jax device array - must have shape (n, 1) - """ - # assume b is a column vector - result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) - return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) - - def scalar_multiply(self, b): - """ - multiply of matrix with a scalar b - - Parameters - ---------- - b: Number or 1 element jax device array - scalar value to multiply - """ - # assume b is a scalar or ndarray with 1 element - return JaxCooMatrix( - self.row, self.col, (self.data * b).reshape(-1), self.shape + row: arraylike + 1D array holding row indices of non-zero entries + col: arraylike + 1D array holding col indices of non-zero entries + data: arraylike + 1D array holding non-zero entries + shape: 2-element tuple (x, y) + where x is the number of rows, and y the number of columns of the matrix + """ + + def __init__(self, row, col, data, shape): + if not pybamm.have_jax(): + raise ModuleNotFoundError( + "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 ) - def multiply(self, b): - """ - general matrix multiply not supported - """ - raise NotImplementedError + self.row = jax.numpy.array(row) + self.col = jax.numpy.array(col) + self.data = jax.numpy.array(data) + self.shape = shape + self.nnz = len(self.data) + + def toarray(self): + """convert sparse matrix to a dense 2D array""" + result = jax.numpy.zeros(self.shape, dtype=self.data.dtype) + return result.at[self.row, self.col].add(self.data) + + def dot_product(self, b): + """ + dot product of matrix with a dense column vector b - def __matmul__(self, b): - """see self.dot_product""" - return self.dot_product(b) + Parameters + ---------- + b: jax device array + must have shape (n, 1) + """ + # assume b is a column vector + result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype) + return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col]) - def create_jax_coo_matrix(value): + def scalar_multiply(self, b): """ - Creates a JaxCooMatrix from a scipy.sparse matrix + multiply of matrix with a scalar b Parameters ---------- + b: Number or 1 element jax device array + scalar value to multiply + """ + # assume b is a scalar or ndarray with 1 element + return JaxCooMatrix(self.row, self.col, (self.data * b).reshape(-1), self.shape) - value: scipy.sparse matrix - the sparse matrix to be converted + def multiply(self, b): """ - scipy_coo = value.tocoo() - row = jax.numpy.asarray(scipy_coo.row) - col = jax.numpy.asarray(scipy_coo.col) - data = jax.numpy.asarray(scipy_coo.data) - return JaxCooMatrix(row, col, data, value.shape) + general matrix multiply not supported + """ + raise NotImplementedError + def __matmul__(self, b): + """see self.dot_product""" + return self.dot_product(b) -else: - def create_jax_coo_matrix(value): # pragma: no cover - raise NotImplementedError("Jax is not available on Windows") +def create_jax_coo_matrix(value): + """ + Creates a JaxCooMatrix from a scipy.sparse matrix + + Parameters + ---------- + + value: scipy.sparse matrix + the sparse matrix to be converted + """ + scipy_coo = value.tocoo() + row = jax.numpy.asarray(scipy_coo.row) + col = jax.numpy.asarray(scipy_coo.col) + data = jax.numpy.asarray(scipy_coo.data) + return JaxCooMatrix(row, col, data, value.shape) def id_to_python_variable(symbol_id, constant=False): @@ -548,6 +545,11 @@ class EvaluatorJax: """ def __init__(self, symbol): + if not pybamm.have_jax(): + raise ModuleNotFoundError( + "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + ) + constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True) # replace numpy function calls to jax numpy calls @@ -582,9 +584,7 @@ def __init__(self, symbol): args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None" if self._arg_list: args = ",".join(self._arg_list) + ", " + args - python_str = ( - "def evaluate_jax({}):\n".format(args) + python_str - ) + python_str = "def evaluate_jax({}):\n".format(args) + python_str # calculate the final variable that will output the result of calling `evaluate` # on `symbol` @@ -609,8 +609,9 @@ def __init__(self, symbol): exec(compiled_function) self._static_argnums = tuple(static_argnums) - self._jit_evaluate = jax.jit(self._evaluate_jax, - static_argnums=self._static_argnums) + self._jit_evaluate = jax.jit( + self._evaluate_jax, static_argnums=self._static_argnums + ) def get_jacobian(self): n = len(self._arg_list) @@ -618,8 +619,9 @@ def get_jacobian(self): # forward mode autodiff wrt y, which is argument 1 after arg_list jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n) - self._jac_evaluate = jax.jit(jacobian_evaluate, - static_argnums=self._static_argnums) + self._jac_evaluate = jax.jit( + jacobian_evaluate, static_argnums=self._static_argnums + ) return EvaluatorJaxJacobian(self._jac_evaluate, self._constants) @@ -629,8 +631,9 @@ def get_sensitivities(self): # forward mode autodiff wrt inputs, which is argument 3 after arg_list jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n) - self._sens_evaluate = jax.jit(jacobian_evaluate, - static_argnums=self._static_argnums) + self._sens_evaluate = jax.jit( + jacobian_evaluate, static_argnums=self._static_argnums + ) return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 58f435868e..589b762cd9 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -21,7 +21,7 @@ def is_notebook(): return False # Terminal running IPython elif shell == "Shell": # pragma: no cover return True # Google Colab notebook - else: + else: # pragma: no cover return False # Other type (?) except NameError: return False # Probably standard Python interpreter @@ -170,45 +170,76 @@ def set_up_experiment(self, model, experiment): "Power switch": 0, "CCCV switch": 0, "Current input [A]": 0, - "Voltage input [V]": 0, # doesn't matter - "Power input [W]": 0, # doesn't matter + "Voltage input [V]": 0, + "Power input [W]": 0, } op_control = op["electric"][1] - if op_control in ["A", "C"]: - capacity = self._parameter_values["Nominal cell capacity [A.h]"] + if op["dc_data"] is not None: + # If operating condition includes a drive cycle, define the interpolant + timescale = self._parameter_values.evaluate(model.timescale) + drive_cycle_interpolant = pybamm.Interpolant( + op["dc_data"][:, 0], + op["dc_data"][:, 1], + timescale * (pybamm.t - pybamm.InputParameter("start time")), + ) if op_control == "A": - I = op["electric"][0] - Crate = I / capacity - else: - # Scale C-rate with capacity to obtain current - Crate = op["electric"][0] - I = Crate * capacity - if len(op["electric"]) == 4: - # Update inputs for CCCV - op_control = "CCCV" # change to CCCV - V = op["electric"][2] operating_inputs.update( { - "CCCV switch": 1, - "Current input [A]": I, - "Voltage input [V]": V, + "Current switch": 1, + "Current input [A]": drive_cycle_interpolant, } ) - else: - # Update inputs for constant current + if op_control == "V": + operating_inputs.update( + { + "Voltage switch": 1, + "Voltage input [V]": drive_cycle_interpolant, + } + ) + if op_control == "W": + operating_inputs.update( + {"Power switch": 1, "Power input [W]": drive_cycle_interpolant} + ) + else: + if op_control in ["A", "C"]: + capacity = self._parameter_values["Nominal cell capacity [A.h]"] + if op_control == "A": + I = op["electric"][0] + Crate = I / capacity + else: + # Scale C-rate with capacity to obtain current + Crate = op["electric"][0] + I = Crate * capacity + if len(op["electric"]) == 4: + # Update inputs for CCCV + op_control = "CCCV" # change to CCCV + V = op["electric"][2] + operating_inputs.update( + { + "CCCV switch": 1, + "Current input [A]": I, + "Voltage input [V]": V, + } + ) + else: + # Update inputs for constant current + operating_inputs.update( + {"Current switch": 1, "Current input [A]": I} + ) + elif op_control == "V": + # Update inputs for constant voltage + V = op["electric"][0] operating_inputs.update( - {"Current switch": 1, "Current input [A]": I} + {"Voltage switch": 1, "Voltage input [V]": V} ) - elif op_control == "V": - # Update inputs for constant voltage - V = op["electric"][0] - operating_inputs.update({"Voltage switch": 1, "Voltage input [V]": V}) - elif op_control == "W": - # Update inputs for constant power - P = op["electric"][0] - operating_inputs.update({"Power switch": 1, "Power input [W]": P}) + elif op_control == "W": + # Update inputs for constant power + P = op["electric"][0] + operating_inputs.update({"Power switch": 1, "Power input [W]": P}) + # Update period operating_inputs["period"] = op["period"] + # Update events if events is None: # make current and voltage values that won't be hit @@ -851,6 +882,11 @@ def solve( f"step {step_num}/{cycle_length}: {op_conds_str}" ) inputs.update(exp_inputs) + if current_solution is None: + start_time = 0 + else: + start_time = current_solution.t[-1] + inputs.update({"start time": start_time}) kwargs["inputs"] = inputs # Make sure we take at least 2 timesteps npts = max(int(round(dt / exp_inputs["period"])) + 1, 2) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index ca2b43dc1a..98fa02dd59 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1160,7 +1160,7 @@ def step( external_variables : dict A dictionary of external variables and their corresponding values at the current time - inputs_dict : dict, optional + inputs : dict, optional Any input parameters to pass to the model when solving save : bool Turn on to store the solution of all previous timesteps @@ -1201,6 +1201,14 @@ def step( # Set up external variables and inputs external_variables = external_variables or {} inputs = inputs or {} + + # Remove interpolant inputs as Casadi can't handle them + if isinstance(inputs.get("Current input [A]"), pybamm.Interpolant): + del inputs["Current input [A]"] + elif isinstance(inputs.get("Voltage input [V]"), pybamm.Interpolant): + del inputs["Voltage input [V]"] + elif isinstance(inputs.get("Power input [W]"), pybamm.Interpolant): + del inputs["Power input [W]"] ext_and_inputs = {**external_variables, **inputs} # Check that any inputs that may affect the scaling have not changed diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 1dd765e16a..4d4c40a76f 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -1,792 +1,992 @@ -import operator as op -import numpy as onp import collections - -import jax -import jax.numpy as jnp -from jax import core -from jax import dtypes -from jax.util import safe_map, cache, split_list -from jax.api_util import flatten_fun_nokwargs -from jax.flatten_util import ravel_pytree -from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_multimap -from jax.interpreters import partial_eval as pe +import operator as op from functools import partial -from jax import linear_util as lu -from jax.config import config -from absl import logging - -logging.set_verbosity(logging.ERROR) - -config.update("jax_enable_x64", True) - -MAX_ORDER = 5 -NEWTON_MAXITER = 4 -ROOT_SOLVE_MAXITER = 15 -MIN_FACTOR = 0.2 -MAX_FACTOR = 10 - - -# https://github.com/google/jax/issues/4572#issuecomment-709809897 -def some_hash_function(x): - return hash(x.tobytes()) - - -class HashableArrayWrapper: - """wrapper for a numpy array to make it hashable""" - def __init__(self, val): - self.val = val - - def __hash__(self): - return some_hash_function(self.val) - - def __eq__(self, other): - return (isinstance(other, HashableArrayWrapper) and - onp.all(onp.equal(self.val, other.val))) - -def gnool_jit(fun, static_array_argnums=(), static_argnums=()): - """redefinition of jax jit to allow static array args""" - @partial( - jax.jit, - static_argnums=static_array_argnums + static_argnums - ) - def callee(*args): - args = list(args) - for i in static_array_argnums: - args[i] = args[i].val - return fun(*args) - - def caller(*args): - args = list(args) - for i in static_array_argnums: - args[i] = HashableArrayWrapper(args[i]) - return callee(*args) - - return caller - - -@jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3)) -def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args): - """ - This implements a Backward Difference formula (BDF) implicit multistep integrator. - The basic algorithm is derived in [2]_. This particular implementation follows that - implemented in the Matlab routine ode15s described in [1]_ and the SciPy - implementation [3]_, which features the NDF formulas for improved stability, with - associated differences in the error constants, and calculates the jacobian at - J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in the - scipy library [3]_, which also mainly follows [1]_ but uses the more standard - jacobian update. - - Parameters - ---------- - - func: callable - function to evaluate the time derivative of the solution `y` at time - `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`. - mass: ndarray - diagonal of the mass matrix with shape (n,) - y0: ndarray - initial state vector, has shape (n,) - t_eval: ndarray - time points to evaluate the solution, has shape (m,) - args: (optional) - tuple of additional arguments for `fun`, which must be arrays - scalars, or (nested) standard Python containers (tuples, lists, dicts, - namedtuples, i.e. pytrees) of those types. - rtol: (optional) float - relative tolerance for the solver - atol: (optional) float - absolute tolerance for the solver +import numpy as onp - Returns - ------- - y: ndarray with shape (n, m) - calculated state vector at each of the m time points +import pybamm - References - ---------- - .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI. - COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997. - .. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical - Solution of Ordinary Differential Equations", ACM Transactions on - Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975. - .. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, - T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0: - fundamental algorithms for scientific computing in Python. - Nature methods, 17(3), 261-272. - """ +if pybamm.have_jax(): + import jax + import jax.numpy as jnp + from absl import logging + from jax import core, dtypes + from jax import linear_util as lu + from jax.api_util import flatten_fun_nokwargs + from jax.config import config + from jax.flatten_util import ravel_pytree + from jax.interpreters import partial_eval as pe + from jax.tree_util import tree_flatten, tree_map, tree_multimap, tree_unflatten + from jax.util import cache, safe_map, split_list - def fun_bind_inputs(y, t): - return fun(y, t, *args) + config.update("jax_enable_x64", True) - jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0) + logging.set_verbosity(logging.ERROR) - t0 = t_eval[0] - h0 = t_eval[1] - t0 + MAX_ORDER = 5 + NEWTON_MAXITER = 4 + ROOT_SOLVE_MAXITER = 15 + MIN_FACTOR = 0.2 + MAX_FACTOR = 10 - stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol) - i = 0 - y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype) + # https://github.com/google/jax/issues/4572#issuecomment-709809897 + def some_hash_function(x): + return hash(x.tobytes()) - init_state = [stepper, t_eval, i, y_out] + class HashableArrayWrapper: + """wrapper for a numpy array to make it hashable""" - def cond_fun(state): - _, t_eval, i, _ = state - return i < len(t_eval) + def __init__(self, val): + self.val = val - def body_fun(state): - stepper, t_eval, i, y_out = state - stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs) - index = jnp.searchsorted(t_eval, stepper.t) + def __hash__(self): + return some_hash_function(self.val) - def for_body(j, y_out): - t = t_eval[j] - y_out = jax.ops.index_update( - y_out, jax.ops.index[j, :], _bdf_interpolate(stepper, t) + def __eq__(self, other): + return isinstance(other, HashableArrayWrapper) and onp.all( + onp.equal(self.val, other.val) ) - return y_out - - y_out = jax.lax.fori_loop(i, index, for_body, y_out) - return [stepper, t_eval, index, y_out] - - stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state) - return y_out - - -BDFInternalStates = [ - "t", - "atol", - "rtol", - "M", - "newton_tol", - "order", - "h", - "n_equal_steps", - "D", - "y0", - "scale_y0", - "kappa", - "gamma", - "alpha", - "c", - "error_const", - "J", - "LU", - "U", - "psi", - "n_function_evals", - "n_jacobian_evals", - "n_lu_decompositions", - "n_steps", - "consistent_y0_failed", -] -BDFState = collections.namedtuple("BDFState", BDFInternalStates) - -jax.tree_util.register_pytree_node( - BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs) -) - - -def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): - """ - Initiation routine for Backward Difference formula (BDF) implicit multistep - integrator. - - See _bdf_odeint function above for details, this function returns a dict with the - initial state of the solver - - Parameters - ---------- - - fun: callable - function with signature (y, t), where t is a scalar time and y is a ndarray with - shape (n,), returns the rhs of the system of ODE equations as an nd array with - shape (n,) - jac: callable - function with signature (y, t), where t is a scalar time and y is a ndarray with - shape (n,), returns the jacobian matrix of fun as an ndarray with shape (n,n) - mass: ndarray - diagonal of the mass matrix with shape (n,) - t0: float - initial time - y0: ndarray - initial state vector with shape (n,) - h0: float - initial step size - rtol: (optional) float - relative tolerance for the solver - atol: (optional) float - absolute tolerance for the solver - """ - - state = {} - state["t"] = t0 - state["atol"] = atol - state["rtol"] = rtol - state["M"] = mass - EPS = jnp.finfo(y0.dtype).eps - state["newton_tol"] = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol ** 0.5)) - - scale_y0 = atol + rtol * jnp.abs(y0) - y0, not_converged = _select_initial_conditions( - fun, mass, t0, y0, state["newton_tol"], scale_y0 - ) - state["consistent_y0_failed"] = not_converged - - f0 = fun(y0, t0) - order = 1 - state["order"] = order - state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) - state["n_equal_steps"] = 0 - D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) - D = jax.ops.index_update(D, jax.ops.index[0, :], y0) - D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state["h"]) - state["D"] = D - state["y0"] = y0 - state["scale_y0"] = scale_y0 - - # kappa values for difference orders, taken from Table 1 of [1] - kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) - gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1)))) - alpha = 1.0 / ((1 - kappa) * gamma) - c = state["h"] * alpha[order] - error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2) - - state["kappa"] = kappa - state["gamma"] = gamma - state["alpha"] = alpha - state["c"] = c - state["error_const"] = error_const - - J = jac(y0, t0) - state["J"] = J - - state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J) - - state["U"] = _compute_R(order, 1) - state["psi"] = None - - state["n_function_evals"] = 2 - state["n_jacobian_evals"] = 1 - state["n_lu_decompositions"] = 1 - state["n_steps"] = 0 - - tuple_state = BDFState(*[state[k] for k in BDFInternalStates]) - y0, scale_y0 = _predict(tuple_state, D) - psi = _update_psi(tuple_state, D) - return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi) - - -def _compute_R(order, factor): - """ - computes the R matrix with entries - given by the first equation on page 8 of [1] - - This is used to update the differences matrix when step size h is varied according - to factor = h_{n+1} / h_n - - Note that the U matrix also defined in the same section can be also be - found using factor = 1, which corresponds to R with a constant step size - """ - I = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1) - J = jnp.arange(1, MAX_ORDER + 1) - M = jnp.empty((MAX_ORDER + 1, MAX_ORDER + 1)) - M = jax.ops.index_update(M, jax.ops.index[1:, 1:], (I - 1 - factor * J) / I) - M = jax.ops.index_update(M, jax.ops.index[0], 1) - R = jnp.cumprod(M, axis=0) - - return R + def gnool_jit(fun, static_array_argnums=(), static_argnums=()): + """redefinition of jax jit to allow static array args""" -def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0): - # identify algebraic variables as zeros on diagonal - algebraic_variables = onp.diag(M) == 0.0 + @partial(jax.jit, static_argnums=static_array_argnums + static_argnums) + def callee(*args): + args = list(args) + for i in static_array_argnums: + args[i] = args[i].val + return fun(*args) - # if all differentiable variables then return y0 (can use normal python if since M - # is static) - if not onp.any(algebraic_variables): - return y0, False + def caller(*args): + args = list(args) + for i in static_array_argnums: + args[i] = HashableArrayWrapper(args[i]) + return callee(*args) - # calculate consistent initial conditions via a newton on -J_a @ delta = f_a This - # follows this reference: - # - # Shampine, L. F., Reichelt, M. W., & Kierzenka, J. A. (1999). Solving index-1 DAEs - # in MATLAB and Simulink. SIAM review, 41(3), 538-552. - - # calculate fun_a, function of algebraic variables - def fun_a(y_a): - y_full = jax.ops.index_update(y0, algebraic_variables, y_a) - return fun(y_full, t0)[algebraic_variables] - - y0_a = y0[algebraic_variables] - scale_y0_a = scale_y0[algebraic_variables] - - d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype) - y_a = jnp.array(y0_a, copy=True) + return caller - # calculate neg jacobian of fun_a - J_a = jax.jacfwd(fun_a)(y_a) - LU = jax.scipy.linalg.lu_factor(-J_a) + @jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3)) + def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args): + """ + Implements a Backward Difference formula (BDF) implicit multistep integrator. + The basic algorithm is derived in [2]_. This particular implementation follows + that implemented in the Matlab routine ode15s described in [1]_ and the SciPy + implementation [3]_, which features the NDF formulas for improved stability, + with associated differences in the error constants, and calculates the jacobian + at J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in + the scipy library [3]_, which also mainly follows [1]_ but uses the more + standard jacobian update. + + Parameters + ---------- + + func: callable + function to evaluate the time derivative of the solution `y` at time + `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`. + mass: ndarray + diagonal of the mass matrix with shape (n,) + y0: ndarray + initial state vector, has shape (n,) + t_eval: ndarray + time points to evaluate the solution, has shape (m,) + args: (optional) + tuple of additional arguments for `fun`, which must be arrays + scalars, or (nested) standard Python containers (tuples, lists, dicts, + namedtuples, i.e. pytrees) of those types. + rtol: (optional) float + relative tolerance for the solver + atol: (optional) float + absolute tolerance for the solver + + Returns + ------- + y: ndarray with shape (n, m) + calculated state vector at each of the m time points + + References + ---------- + .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI. + COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997. + .. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical + Solution of Ordinary Differential Equations", ACM Transactions on + Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975. + .. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, + T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0: + fundamental algorithms for scientific computing in Python. + Nature methods, 17(3), 261-272. + """ - converged = False - dy_norm_old = -1.0 - k = 0 - while_state = [k, converged, dy_norm_old, d, y_a] + def fun_bind_inputs(y, t): + return fun(y, t, *args) - def while_cond(while_state): - k, converged, _, _, _ = while_state - return (converged == False) * (k < ROOT_SOLVE_MAXITER) # noqa: E712 + jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0) - def while_body(while_state): - k, converged, dy_norm_old, d, y_a = while_state - f_eval = fun_a(y_a) - dy = jax.scipy.linalg.lu_solve(LU, f_eval) - dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0_a) ** 2)) - rate = dy_norm / dy_norm_old + t0 = t_eval[0] + h0 = t_eval[1] - t0 - d += dy - y_a = y0_a + d + stepper = _bdf_init( + fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol + ) + i = 0 + y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype) + + init_state = [stepper, t_eval, i, y_out] + + def cond_fun(state): + _, t_eval, i, _ = state + return i < len(t_eval) + + def body_fun(state): + stepper, t_eval, i, y_out = state + stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs) + index = jnp.searchsorted(t_eval, stepper.t) + + def for_body(j, y_out): + t = t_eval[j] + y_out = jax.ops.index_update( + y_out, jax.ops.index[j, :], _bdf_interpolate(stepper, t) + ) + return y_out + + y_out = jax.lax.fori_loop(i, index, for_body, y_out) + return [stepper, t_eval, index, y_out] + + stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state) + return y_out + + BDFInternalStates = [ + "t", + "atol", + "rtol", + "M", + "newton_tol", + "order", + "h", + "n_equal_steps", + "D", + "y0", + "scale_y0", + "kappa", + "gamma", + "alpha", + "c", + "error_const", + "J", + "LU", + "U", + "psi", + "n_function_evals", + "n_jacobian_evals", + "n_lu_decompositions", + "n_steps", + "consistent_y0_failed", + ] + BDFState = collections.namedtuple("BDFState", BDFInternalStates) - # if converged then break out of iteration early - pred = dy_norm_old >= 0.0 - pred *= rate / (1 - rate) * dy_norm < tol - converged = (dy_norm == 0.0) + pred + jax.tree_util.register_pytree_node( + BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs) + ) - dy_norm_old = dy_norm + def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): + """ + Initiation routine for Backward Difference formula (BDF) implicit multistep + integrator. + + See _bdf_odeint function above for details, this function returns a dict with + the initial state of the solver + + Parameters + ---------- + + fun: callable + function with signature (y, t), where t is a scalar time and y is a ndarray + with shape (n,), returns the rhs of the system of ODE equations as an nd + array with shape (n,) + jac: callable + function with signature (y, t), where t is a scalar time and y is a ndarray + with shape (n,), returns the jacobian matrix of fun as an ndarray with + shape (n,n) + mass: ndarray + diagonal of the mass matrix with shape (n,) + t0: float + initial time + y0: ndarray + initial state vector with shape (n,) + h0: float + initial step size + rtol: (optional) float + relative tolerance for the solver + atol: (optional) float + absolute tolerance for the solver + """ - return [k + 1, converged, dy_norm_old, d, y_a] + state = {} + state["t"] = t0 + state["atol"] = atol + state["rtol"] = rtol + state["M"] = mass + EPS = jnp.finfo(y0.dtype).eps + state["newton_tol"] = jnp.maximum( + 10 * EPS / rtol, jnp.minimum(0.03, rtol ** 0.5) + ) - k, converged, dy_norm_old, d, y_a = jax.lax.while_loop( - while_cond, while_body, while_state - ) - y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a) + scale_y0 = atol + rtol * jnp.abs(y0) + y0, not_converged = _select_initial_conditions( + fun, mass, t0, y0, state["newton_tol"], scale_y0 + ) + state["consistent_y0_failed"] = not_converged + + f0 = fun(y0, t0) + order = 1 + state["order"] = order + state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) + state["n_equal_steps"] = 0 + D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) + D = jax.ops.index_update(D, jax.ops.index[0, :], y0) + D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state["h"]) + state["D"] = D + state["y0"] = y0 + state["scale_y0"] = scale_y0 + + # kappa values for difference orders, taken from Table 1 of [1] + kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) + gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1)))) + alpha = 1.0 / ((1 - kappa) * gamma) + c = state["h"] * alpha[order] + error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2) + + state["kappa"] = kappa + state["gamma"] = gamma + state["alpha"] = alpha + state["c"] = c + state["error_const"] = error_const + + J = jac(y0, t0) + state["J"] = J + + state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J) + + state["U"] = _compute_R(order, 1) + state["psi"] = None + + state["n_function_evals"] = 2 + state["n_jacobian_evals"] = 1 + state["n_lu_decompositions"] = 1 + state["n_steps"] = 0 + + tuple_state = BDFState(*[state[k] for k in BDFInternalStates]) + y0, scale_y0 = _predict(tuple_state, D) + psi = _update_psi(tuple_state, D) + return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi) + + def _compute_R(order, factor): + """ + computes the R matrix with entries + given by the first equation on page 8 of [1] - return y_tilde, converged + This is used to update the differences matrix when step size h is varied + according to factor = h_{n+1} / h_n + Note that the U matrix also defined in the same section can be also be + found using factor = 1, which corresponds to R with a constant step size + """ + I = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1) + J = jnp.arange(1, MAX_ORDER + 1) + M = jnp.empty((MAX_ORDER + 1, MAX_ORDER + 1)) + M = jax.ops.index_update(M, jax.ops.index[1:, 1:], (I - 1 - factor * J) / I) + M = jax.ops.index_update(M, jax.ops.index[0], 1) + R = jnp.cumprod(M, axis=0) + + return R + + def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0): + # identify algebraic variables as zeros on diagonal + algebraic_variables = onp.diag(M) == 0.0 + + # if all differentiable variables then return y0 (can use normal python if + # since M is static) + if not onp.any(algebraic_variables): + return y0, False + + # calculate consistent initial conditions via a newton on -J_a @ delta = f_a + # This follows this reference: + # + # Shampine, L. F., Reichelt, M. W., & Kierzenka, J. A. (1999). + # Solving index-1 DAEs in MATLAB and Simulink. SIAM review, 41(3), 538-552. -def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0): - """ - Select a good initial step by stepping forward one step of forward euler, and - comparing the predicted state against that using the provided function. + # calculate fun_a, function of algebraic variables + def fun_a(y_a): + y_full = jax.ops.index_update(y0, algebraic_variables, y_a) + return fun(y_full, t0)[algebraic_variables] - Optimal step size based on the selected order is obtained using formula (4.12) - in [1] + y0_a = y0[algebraic_variables] + scale_y0_a = scale_y0[algebraic_variables] - References - ---------- - .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential - Equations I: Nonstiff Problems", Sec. II.4. - """ - scale = atol + jnp.abs(y0) * rtol - y1 = y0 + h0 * f0 - f1 = fun(y1, t0 + h0) - d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale) ** 2)) - order = 1 - h1 = h0 * d2 ** (-1 / (order + 1)) - return jnp.minimum(100 * h0, h1) + d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype) + y_a = jnp.array(y0_a, copy=True) + # calculate neg jacobian of fun_a + J_a = jax.jacfwd(fun_a)(y_a) + LU = jax.scipy.linalg.lu_factor(-J_a) -def _predict(state, D): - """ - predict forward to new step (eq 2 in [1]) - """ - n = len(state.y0) - order = state.order - orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1) - subD = jnp.where(orders <= order, D, 0) - y0 = jnp.sum(subD, axis=0) - scale_y0 = state.atol + state.rtol * jnp.abs(state.y0) - return y0, scale_y0 + converged = False + dy_norm_old = -1.0 + k = 0 + while_state = [k, converged, dy_norm_old, d, y_a] + def while_cond(while_state): + k, converged, _, _, _ = while_state + return (converged == False) * (k < ROOT_SOLVE_MAXITER) # noqa: E712 -def _update_psi(state, D): - """ - update psi term as defined in second equation on page 9 of [1] - """ - order = state.order - n = len(state.y0) - orders = jnp.arange(MAX_ORDER + 1) - subGamma = jnp.where(orders > 0, jnp.where(orders <= order, state.gamma, 0), 0) - orders = jnp.repeat(orders.reshape(-1, 1), n, axis=1) - subD = jnp.where(orders > 0, jnp.where(orders <= order, D, 0), 0) - psi = jnp.dot(subD.T, subGamma) * state.alpha[order] - return psi + def while_body(while_state): + k, converged, dy_norm_old, d, y_a = while_state + f_eval = fun_a(y_a) + dy = jax.scipy.linalg.lu_solve(LU, f_eval) + dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0_a) ** 2)) + rate = dy_norm / dy_norm_old + d += dy + y_a = y0_a + d -def _update_difference_for_next_step(state, d): - """ - update of difference equations can be done efficiently - by reusing d and D. + # if converged then break out of iteration early + pred = dy_norm_old >= 0.0 + pred *= rate / (1 - rate) * dy_norm < tol + converged = (dy_norm == 0.0) + pred - From first equation on page 4 of [1]: - d = y_n - y^0_n = D^{k + 1} y_n + dy_norm_old = dy_norm - Standard backwards difference gives - D^{j + 1} y_n = D^{j} y_n - D^{j} y_{n - 1} + return [k + 1, converged, dy_norm_old, d, y_a] - Combining these gives the following algorithm - """ - order = state.order - D = state.D - D = jax.ops.index_update(D, jax.ops.index[order + 2], d - D[order + 1]) - D = jax.ops.index_update(D, jax.ops.index[order + 1], d) - i = order - while_state = [i, D] + k, converged, dy_norm_old, d, y_a = jax.lax.while_loop( + while_cond, while_body, while_state + ) + y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a) - def while_cond(while_state): - i, _ = while_state - return i >= 0 + return y_tilde, converged - def while_body(while_state): - i, D = while_state - D = jax.ops.index_add(D, jax.ops.index[i], D[i + 1]) - i -= 1 - return [i, D] + def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0): + """ + Select a good initial step by stepping forward one step of forward euler, and + comparing the predicted state against that using the provided function. - i, D = jax.lax.while_loop(while_cond, while_body, while_state) + Optimal step size based on the selected order is obtained using formula (4.12) + in [1] - return D + References + ---------- + .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential + Equations I: Nonstiff Problems", Sec. II.4. + """ + scale = atol + jnp.abs(y0) * rtol + y1 = y0 + h0 * f0 + f1 = fun(y1, t0 + h0) + d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale) ** 2)) + order = 1 + h1 = h0 * d2 ** (-1 / (order + 1)) + return jnp.minimum(100 * h0, h1) + + def _predict(state, D): + """ + predict forward to new step (eq 2 in [1]) + """ + n = len(state.y0) + order = state.order + orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1) + subD = jnp.where(orders <= order, D, 0) + y0 = jnp.sum(subD, axis=0) + scale_y0 = state.atol + state.rtol * jnp.abs(state.y0) + return y0, scale_y0 + + def _update_psi(state, D): + """ + update psi term as defined in second equation on page 9 of [1] + """ + order = state.order + n = len(state.y0) + orders = jnp.arange(MAX_ORDER + 1) + subGamma = jnp.where(orders > 0, jnp.where(orders <= order, state.gamma, 0), 0) + orders = jnp.repeat(orders.reshape(-1, 1), n, axis=1) + subD = jnp.where(orders > 0, jnp.where(orders <= order, D, 0), 0) + psi = jnp.dot(subD.T, subGamma) * state.alpha[order] + return psi + + def _update_difference_for_next_step(state, d): + """ + update of difference equations can be done efficiently + by reusing d and D. + From first equation on page 4 of [1]: + d = y_n - y^0_n = D^{k + 1} y_n -def _update_step_size_and_lu(state, factor): - state = _update_step_size(state, factor) + Standard backwards difference gives + D^{j + 1} y_n = D^{j} y_n - D^{j} y_{n - 1} - # redo lu (c has changed) - LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J) - n_lu_decompositions = state.n_lu_decompositions + 1 + Combining these gives the following algorithm + """ + order = state.order + D = state.D + D = jax.ops.index_update(D, jax.ops.index[order + 2], d - D[order + 1]) + D = jax.ops.index_update(D, jax.ops.index[order + 1], d) + i = order + while_state = [i, D] - return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions) + def while_cond(while_state): + i, _ = while_state + return i >= 0 + def while_body(while_state): + i, D = while_state + D = jax.ops.index_add(D, jax.ops.index[i], D[i + 1]) + i -= 1 + return [i, D] -def _update_step_size(state, factor): - """ - If step size h is changed then also need to update the terms in - the first equation of page 9 of [1]: + i, D = jax.lax.while_loop(while_cond, while_body, while_state) - - constant c = h / (1-kappa) gamma_k term - - lu factorisation of (M - c * J) used in newton iteration (same equation) - - psi term - """ - order = state.order - h = state.h * factor - n_equal_steps = 0 - c = h * state.alpha[order] - - # update D using equations in section 3.2 of [1] - RU = _compute_R(order, factor).dot(state.U) - I = jnp.arange(0, MAX_ORDER + 1).reshape(-1, 1) - J = jnp.arange(0, MAX_ORDER + 1) - - # only update order+1, order+1 entries of D - RU = jnp.where( - jnp.logical_and(I <= order, J <= order), RU, jnp.identity(MAX_ORDER + 1) - ) - D = state.D - D = jnp.dot(RU.T, D) - # D = jax.ops.index_update(D, jax.ops.index[:order + 1], - # jnp.dot(RU.T, D[:order + 1])) + return D - # update psi (D has changed) - psi = _update_psi(state, D) + def _update_step_size_and_lu(state, factor): + state = _update_step_size(state, factor) - # update y0 (D has changed) - y0, scale_y0 = _predict(state, D) + # redo lu (c has changed) + LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J) + n_lu_decompositions = state.n_lu_decompositions + 1 - return state._replace( - n_equal_steps=n_equal_steps, h=h, c=c, D=D, psi=psi, y0=y0, scale_y0=scale_y0 - ) + return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions) + def _update_step_size(state, factor): + """ + If step size h is changed then also need to update the terms in + the first equation of page 9 of [1]: -def _update_jacobian(state, jac): - """ - we update the jacobian using J(t_{n+1}, y^0_{n+1}) - following the scipy bdf implementation rather than J(t_n, y_n) as per [1] - """ - J = jac(state.y0, state.t + state.h) - n_jacobian_evals = state.n_jacobian_evals + 1 - LU = jax.scipy.linalg.lu_factor(state.M - state.c * J) - n_lu_decompositions = state.n_lu_decompositions + 1 - return state._replace( - J=J, - n_jacobian_evals=n_jacobian_evals, - LU=LU, - n_lu_decompositions=n_lu_decompositions, - ) + - constant c = h / (1-kappa) gamma_k term + - lu factorisation of (M - c * J) used in newton iteration (same equation) + - psi term + """ + order = state.order + h = state.h * factor + n_equal_steps = 0 + c = h * state.alpha[order] + + # update D using equations in section 3.2 of [1] + RU = _compute_R(order, factor).dot(state.U) + I = jnp.arange(0, MAX_ORDER + 1).reshape(-1, 1) + J = jnp.arange(0, MAX_ORDER + 1) + + # only update order+1, order+1 entries of D + RU = jnp.where( + jnp.logical_and(I <= order, J <= order), RU, jnp.identity(MAX_ORDER + 1) + ) + D = state.D + D = jnp.dot(RU.T, D) + # D = jax.ops.index_update(D, jax.ops.index[:order + 1], + # jnp.dot(RU.T, D[:order + 1])) + + # update psi (D has changed) + psi = _update_psi(state, D) + + # update y0 (D has changed) + y0, scale_y0 = _predict(state, D) + + return state._replace( + n_equal_steps=n_equal_steps, + h=h, + c=c, + D=D, + psi=psi, + y0=y0, + scale_y0=scale_y0, + ) + def _update_jacobian(state, jac): + """ + we update the jacobian using J(t_{n+1}, y^0_{n+1}) + following the scipy bdf implementation rather than J(t_n, y_n) as per [1] + """ + J = jac(state.y0, state.t + state.h) + n_jacobian_evals = state.n_jacobian_evals + 1 + LU = jax.scipy.linalg.lu_factor(state.M - state.c * J) + n_lu_decompositions = state.n_lu_decompositions + 1 + return state._replace( + J=J, + n_jacobian_evals=n_jacobian_evals, + LU=LU, + n_lu_decompositions=n_lu_decompositions, + ) -def _newton_iteration(state, fun): - tol = state.newton_tol - c = state.c - psi = state.psi - y0 = state.y0 - LU = state.LU - M = state.M - scale_y0 = state.scale_y0 - t = state.t + state.h - d = jnp.zeros(y0.shape, dtype=y0.dtype) - y = jnp.array(y0, copy=True) - n_function_evals = state.n_function_evals - - converged = False - dy_norm_old = -1.0 - k = 0 - while_state = [k, converged, dy_norm_old, d, y, n_function_evals] - - def while_cond(while_state): - k, converged, _, _, _, _ = while_state - return (converged == False) * (k < NEWTON_MAXITER) # noqa: E712 - - def while_body(while_state): - k, converged, dy_norm_old, d, y, n_function_evals = while_state - f_eval = fun(y, t) - n_function_evals += 1 - b = c * f_eval - M @ (psi + d) - dy = jax.scipy.linalg.lu_solve(LU, b) - dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2)) - rate = dy_norm / dy_norm_old - - # if iteration is not going to converge in NEWTON_MAXITER - # (assuming the current rate), then abort - pred = rate >= 1 - pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol - pred *= dy_norm_old >= 0 - k += pred * (NEWTON_MAXITER - k - 1) - - d += dy - y = y0 + d - - # if converged then break out of iteration early - pred = dy_norm_old >= 0.0 - pred *= rate / (1 - rate) * dy_norm < tol - converged = (dy_norm == 0.0) + pred - - dy_norm_old = dy_norm - - return [k + 1, converged, dy_norm_old, d, y, n_function_evals] - - k, converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop( - while_cond, while_body, while_state - ) - return converged, k, y, d, state._replace(n_function_evals=n_function_evals) + def _newton_iteration(state, fun): + tol = state.newton_tol + c = state.c + psi = state.psi + y0 = state.y0 + LU = state.LU + M = state.M + scale_y0 = state.scale_y0 + t = state.t + state.h + d = jnp.zeros(y0.shape, dtype=y0.dtype) + y = jnp.array(y0, copy=True) + n_function_evals = state.n_function_evals + + converged = False + dy_norm_old = -1.0 + k = 0 + while_state = [k, converged, dy_norm_old, d, y, n_function_evals] + + def while_cond(while_state): + k, converged, _, _, _, _ = while_state + return (converged == False) * (k < NEWTON_MAXITER) # noqa: E712 + + def while_body(while_state): + k, converged, dy_norm_old, d, y, n_function_evals = while_state + f_eval = fun(y, t) + n_function_evals += 1 + b = c * f_eval - M @ (psi + d) + dy = jax.scipy.linalg.lu_solve(LU, b) + dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2)) + rate = dy_norm / dy_norm_old + + # if iteration is not going to converge in NEWTON_MAXITER + # (assuming the current rate), then abort + pred = rate >= 1 + pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol + pred *= dy_norm_old >= 0 + k += pred * (NEWTON_MAXITER - k - 1) + + d += dy + y = y0 + d + + # if converged then break out of iteration early + pred = dy_norm_old >= 0.0 + pred *= rate / (1 - rate) * dy_norm < tol + converged = (dy_norm == 0.0) + pred + + dy_norm_old = dy_norm + + return [k + 1, converged, dy_norm_old, d, y, n_function_evals] + + k, converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop( + while_cond, while_body, while_state + ) + return converged, k, y, d, state._replace(n_function_evals=n_function_evals) + def rms_norm(arg): + return jnp.sqrt(jnp.mean(arg ** 2)) -def rms_norm(arg): - return jnp.sqrt(jnp.mean(arg ** 2)) + def _prepare_next_step(state, d): + D = _update_difference_for_next_step(state, d) + psi = _update_psi(state, D) + y0, scale_y0 = _predict(state, D) + return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0) + def _prepare_next_step_order_change(state, d, y, n_iter): + order = state.order -def _prepare_next_step(state, d): - D = _update_difference_for_next_step(state, d) - psi = _update_psi(state, D) - y0, scale_y0 = _predict(state, D) - return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0) + D = _update_difference_for_next_step(state, d) + # Note: we are recalculating these from the while loop above, could re-use? + scale_y = state.atol + state.rtol * jnp.abs(y) + error = state.error_const[order] * d + error_norm = rms_norm(error / scale_y) + safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) -def _prepare_next_step_order_change(state, d, y, n_iter): - order = state.order + # similar to the optimal step size factor we calculated above for the current + # order k, we need to calculate the optimal step size factors for orders + # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n + error_m_norm = jnp.where( + order > 1, + rms_norm(state.error_const[order - 1] * D[order] / scale_y), + jnp.inf, + ) + error_p_norm = jnp.where( + order < MAX_ORDER, + rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y), + jnp.inf, + ) - D = _update_difference_for_next_step(state, d) + error_norms = jnp.array([error_m_norm, error_norm, error_p_norm]) + factors = error_norms ** (-1 / (jnp.arange(3) + order)) + + # now we have the three factors for orders k-1, k and k+1, pick the maximum in + # order to maximise the resultant step size + max_index = jnp.argmax(factors) + order += max_index - 1 + + factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index]) + + new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor) + return new_state + + def _bdf_step(state, fun, jac): + # print('bdf_step', state.t, state.h) + # we will try and use the old jacobian unless convergence of newton iteration + # fails + updated_jacobian = False + # initialise step size and try to make the step, + # iterate, reducing step size until error is in bounds + step_accepted = False + y = jnp.empty_like(state.y0) + d = jnp.empty_like(state.y0) + n_iter = -1 + + # loop until step is accepted + while_state = [state, step_accepted, updated_jacobian, y, d, n_iter] + + def while_cond(while_state): + _, step_accepted, _, _, _, _ = while_state + return step_accepted == False # noqa: E712 + + def while_body(while_state): + state, step_accepted, updated_jacobian, y, d, n_iter = while_state + + # solve BDF equation using y0 as starting point + converged, n_iter, y, d, state = _newton_iteration(state, fun) + not_converged = converged == False # noqa: E712 + + # newton iteration did not converge, but jacobian has already been + # evaluated so reduce step size by 0.3 (as per [1]) and try again + state = tree_multimap( + partial(jnp.where, not_converged * updated_jacobian), + _update_step_size_and_lu(state, 0.3), + state, + ) - # Note: we are recalculating these from the while loop above, could re-use? - scale_y = state.atol + state.rtol * jnp.abs(y) - error = state.error_const[order] * d - error_norm = rms_norm(error / scale_y) - safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) + # if not_converged * updated_jacobian: + # print('not converged, update step size by 0.3') + # if not_converged * (updated_jacobian == False): + # print('not converged, update jacobian') - # similar to the optimal step size factor we calculated above for the current - # order k, we need to calculate the optimal step size factors for orders - # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n - error_m_norm = jnp.where( - order > 1, rms_norm(state.error_const[order - 1] * D[order] / scale_y), jnp.inf - ) - error_p_norm = jnp.where( - order < MAX_ORDER, - rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y), - jnp.inf, - ) + # if not converged and jacobian not updated, then update the jacobian and + # try again + (state, updated_jacobian) = tree_multimap( + partial( + jnp.where, not_converged * (updated_jacobian == False) # noqa: E712 + ), + (_update_jacobian(state, jac), True), + (state, False + updated_jacobian), + ) - error_norms = jnp.array([error_m_norm, error_norm, error_p_norm]) - factors = error_norms ** (-1 / (jnp.arange(3) + order)) + safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) + scale_y = state.atol + state.rtol * jnp.abs(y) - # now we have the three factors for orders k-1, k and k+1, pick the maximum in - # order to maximise the resultant step size - max_index = jnp.argmax(factors) - order += max_index - 1 + # combine eq 3, 4 and 6 from [1] to obtain error + # Note that error = C_k * h^{k+1} y^{k+1} + # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} + error = state.error_const[state.order] * d - factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index]) + error_norm = rms_norm(error / scale_y) - new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor) - return new_state + # calculate optimal step size factor as per eq 2.46 of [2] + factor = jnp.maximum( + MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1)) + ) + # if converged * (error_norm > 1): + # print( + # "converged, but error is too large", + # error_norm, + # factor, + # d, + # scale_y, + # ) + + (state, step_accepted) = tree_multimap( + partial(jnp.where, converged * (error_norm > 1)), # noqa: E712 + (_update_step_size_and_lu(state, factor), False), + (state, converged), + ) -def _bdf_step(state, fun, jac): - # print('bdf_step', state.t, state.h) - # we will try and use the old jacobian unless convergence of newton iteration - # fails - updated_jacobian = False - # initialise step size and try to make the step, - # iterate, reducing step size until error is in bounds - step_accepted = False - y = jnp.empty_like(state.y0) - d = jnp.empty_like(state.y0) - n_iter = -1 + return [state, step_accepted, updated_jacobian, y, d, n_iter] - # loop until step is accepted - while_state = [state, step_accepted, updated_jacobian, y, d, n_iter] + state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop( + while_cond, while_body, while_state + ) - def while_cond(while_state): - _, step_accepted, _, _, _, _ = while_state - return step_accepted == False # noqa: E712 + # take the accepted step + n_steps = state.n_steps + 1 + t = state.t + state.h - def while_body(while_state): - state, step_accepted, updated_jacobian, y, d, n_iter = while_state + # a change in order is only done after running at order k for k + 1 steps + # (see page 83 of [2]) + n_equal_steps = state.n_equal_steps + 1 - # solve BDF equation using y0 as starting point - converged, n_iter, y, d, state = _newton_iteration(state, fun) - not_converged = converged == False # noqa: E712 + state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps) - # newton iteration did not converge, but jacobian has already been - # evaluated so reduce step size by 0.3 (as per [1]) and try again state = tree_multimap( - partial(jnp.where, not_converged * updated_jacobian), - _update_step_size_and_lu(state, 0.3), - state, - ) - - # if not_converged * updated_jacobian: - # print('not converged, update step size by 0.3') - # if not_converged * (updated_jacobian == False): - # print('not converged, update jacobian') - - # if not converged and jacobian not updated, then update the jacobian and try - # again - (state, updated_jacobian) = tree_multimap( - partial( - jnp.where, not_converged * (updated_jacobian == False) # noqa: E712 - ), - (_update_jacobian(state, jac), True), - (state, False + updated_jacobian), + partial(jnp.where, n_equal_steps < state.order + 1), + _prepare_next_step(state, d), + _prepare_next_step_order_change(state, d, y, n_iter), ) - safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) - scale_y = state.atol + state.rtol * jnp.abs(y) - - # combine eq 3, 4 and 6 from [1] to obtain error - # Note that error = C_k * h^{k+1} y^{k+1} - # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} - error = state.error_const[state.order] * d + return state - error_norm = rms_norm(error / scale_y) + def _bdf_interpolate(state, t_eval): + """ + interpolate solution at time values t* where t-h < t* < t - # calculate optimal step size factor as per eq 2.46 of [2] - factor = jnp.maximum( - MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1)) + definition of the interpolating polynomial can be found on page 7 of [1] + """ + order = state.order + t = state.t + h = state.h + D = state.D + j = 0 + time_factor = 1.0 + order_summation = D[0] + while_state = [j, time_factor, order_summation] + + def while_cond(while_state): + j, _, _ = while_state + return j < order + + def while_body(while_state): + j, time_factor, order_summation = while_state + time_factor *= (t_eval - (t - h * j)) / (h * (1 + j)) + order_summation += D[j + 1] * time_factor + j += 1 + return [j, time_factor, order_summation] + + j, time_factor, order_summation = jax.lax.while_loop( + while_cond, while_body, while_state ) + return order_summation + + def block_diag(lst): + def block_fun(i, j, Ai, Aj): + if i == j: + return Ai + else: + return onp.zeros( + ( + Ai.shape[0] if Ai.ndim > 1 else 1, + Aj.shape[1] if Aj.ndim > 1 else 1, + ), + dtype=Ai.dtype, + ) + + blocks = [ + [block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)] + for i, Ai in enumerate(lst) + ] + + return onp.block(blocks) + + # NOTE: the code below (except the docstring on jax_bdf_integrate and other minor + # edits), has been modified from the JAX library at https://github.com/google/jax. + # The main difference is the addition of support for semi-explicit dae index 1 + # problems via the addition of a mass matrix. + # This is under an Apache license, a short form of which is given here: + # + # Copyright 2018 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); you may not use + # this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software distributed + # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + # CONDITIONS OF ANY KIND, either express or implied. See the License for the + # specific language governing permissions and limitations under the License. - # if converged * (error_norm > 1): - # print('converged, but error is too large',error_norm, factor, d, scale_y) + def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover + """ + for debugging purposes, use this instead of jax.lax.while_loop + """ + val = init_val + while cond_fun(val): + val = body_fun(val) + return val - (state, step_accepted) = tree_multimap( - partial(jnp.where, converged * (error_norm > 1)), # noqa: E712 - (_update_step_size_and_lu(state, factor), False), - (state, converged), - ) + def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover + """ + for debugging purposes, use this instead of jax.lax.fori_loop + """ + val = init_val + for i in range(start, stop): + val = body_fun(i, val) + return val - return [state, step_accepted, updated_jacobian, y, d, n_iter] + def flax_scan(f, init, xs, length=None): # pragma: no cover + """ + for debugging purposes, use this instead of jax.lax.scan + """ + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, onp.stack(ys) + + @jax.partial(gnool_jit, static_array_argnums=(1,), static_argnums=(0, 2, 3)) + def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args): + y0, unravel = ravel_pytree(y0) + func = ravel_first_arg(func, unravel) + out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) + return jax.vmap(unravel)(out) + + def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args): + ys = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) + return ys, (ys, ts, args) + + def _bdf_odeint_rev(func, mass, rtol, atol, res, g): + ys, ts, args = res + + def aug_dynamics(augmented_state, t, *args): + """Original system augmented with vjp_y, vjp_t and vjp_args.""" + y, y_bar, *_ = augmented_state + # `t` here is negative time, so we need to negate again to get back to + # normal time. See the `odeint` invocation in `scan_fun` below. + y_dot, vjpfun = jax.vjp(func, y, -t, *args) + + # Adjoint equations for semi-explicit dae index 1 system from + # + # [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity + # analysis for differential-algebraic equations: The adjoint DAE system and + # its numerical solution. + # SIAM journal on scientific computing, 24(3), 1076-1089. + # + # y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a + # 0 = J_da^T y_bar_d + J_aa^T y_bar_d + + y_bar_dot, *rest = vjpfun(y_bar) + + return (-y_dot, y_bar_dot, *rest) + + algebraic_variables = onp.diag(mass) == 0.0 + differentiable_variables = algebraic_variables == False # noqa: E712 + mass_is_I = onp.array_equal(mass, onp.eye(mass.shape[0])) + is_dae = onp.any(algebraic_variables) + + if not mass_is_I: + M_dd = mass[onp.ix_(differentiable_variables, differentiable_variables)] + LU_invM_dd = jax.scipy.linalg.lu_factor(M_dd) + + def initialise(g0, y0, t0): + # [1] gives init conditions for y_bar_a = g_d - J_ad^T (J_aa^T)^-1 g_a + if mass_is_I: + y_bar = g0 + elif is_dae: + J = jax.jacfwd(func)(y0, t0, *args) + + # boolean arguments not implemented in jnp.ix_ + J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)] + J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)] + LU = jax.scipy.linalg.lu_factor(J_aa) + g0_a = g0[algebraic_variables] + invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a) + y_bar = jax.ops.index_update( + g0, + differentiable_variables, + jax.scipy.linalg.lu_solve(LU_invM_dd, g0_a - J_ad @ invJ_aa), + ) + else: + y_bar = jax.scipy.linalg.lu_solve(LU_invM_dd, g0) + return y_bar + + y_bar = initialise(g[-1], ys[-1], ts[-1]) + ts_bar = [] + t0_bar = 0.0 + + def arg_to_identity(arg): + return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype) + + def arg_dicts_to_values(args): + """ + Note:JAX puts in empty arrays into args for some reason, we remove them here + """ + return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ()) + + aug_mass = (mass, mass, onp.array(1.0)) + arg_dicts_to_values( + tree_map(arg_to_identity, args) + ) - state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop( - while_cond, while_body, while_state - ) + def scan_fun(carry, i): + y_bar, t0_bar, args_bar = carry + # Compute effect of moving measurement time + t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i]) + t0_bar = t0_bar - t_bar + # Run augmented system backwards to previous observation + _, y_bar, t0_bar, args_bar = jax_bdf_integrate( + aug_dynamics, + (ys[i], y_bar, t0_bar, args_bar), + jnp.array([-ts[i], -ts[i - 1]]), + *args, + mass=aug_mass, + rtol=rtol, + atol=atol, + ) + y_bar, t0_bar, args_bar = tree_map( + op.itemgetter(1), (y_bar, t0_bar, args_bar) + ) + # Add gradient from current output + y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1]) + return (y_bar, t0_bar, args_bar), t_bar - # take the accepted step - n_steps = state.n_steps + 1 - t = state.t + state.h + init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args)) + (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan( + scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1) + ) + ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]]) + return (y_bar, ts_bar, *args_bar) - # a change in order is only done after running at order k for k + 1 steps - # (see page 83 of [2]) - n_equal_steps = state.n_equal_steps + 1 + _bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev) - state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps) + @cache() + def closure_convert(fun, in_tree, in_avals): + if config.omnistaging_enabled: + wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) + else: + in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] + wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + with core.initial_style_staging(): # type: ignore + jaxpr, _, consts = pe.trace_to_jaxpr( + wrapped_fun, in_pvals, instantiate=True, stage_out=False + ) # type: ignore + out_tree = out_tree() - state = tree_multimap( - partial(jnp.where, n_equal_steps < state.order + 1), - _prepare_next_step(state, d), - _prepare_next_step_order_change(state, d, y, n_iter), - ) + # We only want to closure convert for constants with respect to which we're + # differentiating. As a proxy for that, we hoist consts with float dtype. + # TODO(mattjj): revise this approach + def is_float(c): + return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact) - return state + (closure_consts, hoisted_consts), merge = partition_list(is_float, consts) + num_consts = len(hoisted_consts) + def converted_fun(y, t, *hconsts_args): + hoisted_consts, args = split_list(hconsts_args, [num_consts]) + consts = merge(closure_consts, hoisted_consts) + all_args, _ = tree_flatten((y, t, *args)) + out_flat = core.eval_jaxpr(jaxpr, consts, *all_args) + return tree_unflatten(out_tree, out_flat) -def _bdf_interpolate(state, t_eval): - """ - interpolate solution at time values t* where t-h < t* < t + return converted_fun, hoisted_consts - definition of the interpolating polynomial can be found on page 7 of [1] - """ - order = state.order - t = state.t - h = state.h - D = state.D - j = 0 - time_factor = 1.0 - order_summation = D[0] - while_state = [j, time_factor, order_summation] - - def while_cond(while_state): - j, _, _ = while_state - return j < order - - def while_body(while_state): - j, time_factor, order_summation = while_state - time_factor *= (t_eval - (t - h * j)) / (h * (1 + j)) - order_summation += D[j + 1] * time_factor - j += 1 - return [j, time_factor, order_summation] - - j, time_factor, order_summation = jax.lax.while_loop( - while_cond, while_body, while_state - ) - return order_summation + def partition_list(choice, lst): + out = [], [] + which = [out[choice(elt)].append(elt) or choice(elt) for elt in lst] + def merge(l1, l2): + i1, i2 = iter(l1), iter(l2) + return [next(i2 if snd else i1) for snd in which] -def block_diag(lst): - def block_fun(i, j, Ai, Aj): - if i == j: - return Ai - else: - return onp.zeros( - ( - Ai.shape[0] if Ai.ndim > 1 else 1, - Aj.shape[1] if Aj.ndim > 1 else 1, - ), - dtype=Ai.dtype, - ) + return out, merge - blocks = [ - [block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)] - for i, Ai in enumerate(lst) - ] + def abstractify(x): + return core.raise_to_shaped(core.get_aval(x)) - return onp.block(blocks) + def ravel_first_arg(f, unravel): + return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped - -# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor -# edits), has been modified from the JAX library at https://github.com/google/jax. -# The main difference is the addition of support for semi-explicit dae index 1 problems -# via the addition of a mass matrix. -# This is under an Apache license, a short form of which is given here: -# -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this -# file except in compliance with the License. You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under -# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific language -# governing permissions and limitations under the License. + @lu.transformation + def ravel_first_arg_(unravel, y_flat, *args): + y = unravel(y_flat) + ans = yield (y,) + args, {} + ans_flat, _ = ravel_pytree(ans) + yield ans_flat def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): @@ -828,15 +1028,19 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): References ---------- .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI. - COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997. + COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997. .. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical - Solution of Ordinary Differential Equations", ACM Transactions on - Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975. + Solution of Ordinary Differential Equations", ACM Transactions on + Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975. .. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, - T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0: - fundamental algorithms for scientific computing in Python. - Nature methods, 17(3), 261-272. + T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0: + fundamental algorithms for scientific computing in Python. + Nature methods, 17(3), 261-272. """ + if not pybamm.have_jax(): + raise ModuleNotFoundError( + "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + ) def _check_arg(arg): if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg): @@ -854,213 +1058,3 @@ def _check_arg(arg): else: mass = block_diag(tree_flatten(mass)[0]) return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args) - - -def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover - """ - for debugging purposes, use this instead of jax.lax.while_loop - """ - val = init_val - while cond_fun(val): - val = body_fun(val) - return val - - -def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover - """ - for debugging purposes, use this instead of jax.lax.fori_loop - """ - val = init_val - for i in range(start, stop): - val = body_fun(i, val) - return val - - -def flax_scan(f, init, xs, length=None): # pragma: no cover - """ - for debugging purposes, use this instead of jax.lax.scan - """ - if xs is None: - xs = [None] * length - carry = init - ys = [] - for x in xs: - carry, y = f(carry, x) - ys.append(y) - return carry, onp.stack(ys) - - -@jax.partial(gnool_jit, - static_array_argnums=(1,), static_argnums=(0, 2, 3)) -def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args): - y0, unravel = ravel_pytree(y0) - func = ravel_first_arg(func, unravel) - out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) - return jax.vmap(unravel)(out) - - -def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args): - ys = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) - return ys, (ys, ts, args) - - -def _bdf_odeint_rev(func, mass, rtol, atol, res, g): - ys, ts, args = res - - def aug_dynamics(augmented_state, t, *args): - """Original system augmented with vjp_y, vjp_t and vjp_args.""" - y, y_bar, *_ = augmented_state - # `t` here is negative time, so we need to negate again to get back to - # normal time. See the `odeint` invocation in `scan_fun` below. - y_dot, vjpfun = jax.vjp(func, y, -t, *args) - - # Adjoint equations for semi-explicit dae index 1 system from - # - # [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity - # analysis for differential-algebraic equations: The adjoint DAE system and its - # numerical solution. SIAM journal on scientific computing, 24(3), 1076-1089. - # - # y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a - # 0 = J_da^T y_bar_d + J_aa^T y_bar_d - - y_bar_dot, *rest = vjpfun(y_bar) - - return (-y_dot, y_bar_dot, *rest) - - algebraic_variables = onp.diag(mass) == 0.0 - differentiable_variables = algebraic_variables == False # noqa: E712 - mass_is_I = onp.array_equal(mass, onp.eye(mass.shape[0])) - is_dae = onp.any(algebraic_variables) - - if not mass_is_I: - M_dd = mass[onp.ix_(differentiable_variables, differentiable_variables)] - LU_invM_dd = jax.scipy.linalg.lu_factor(M_dd) - - def initialise(g0, y0, t0): - # [1] gives init conditions for y_bar_a = g_d - J_ad^T (J_aa^T)^-1 g_a - if mass_is_I: - y_bar = g0 - elif is_dae: - J = jax.jacfwd(func)(y0, t0, *args) - - # boolean arguments not implemented in jnp.ix_ - J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)] - J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)] - LU = jax.scipy.linalg.lu_factor(J_aa) - g0_a = g0[algebraic_variables] - invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a) - y_bar = jax.ops.index_update( - g0, - differentiable_variables, - jax.scipy.linalg.lu_solve(LU_invM_dd, g0_a - J_ad @ invJ_aa), - ) - else: - y_bar = jax.scipy.linalg.lu_solve(LU_invM_dd, g0) - return y_bar - - y_bar = initialise(g[-1], ys[-1], ts[-1]) - ts_bar = [] - t0_bar = 0.0 - - def arg_to_identity(arg): - return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype) - - def arg_dicts_to_values(args): - """ - Note: JAX puts in empty arrays into args for some reason, we remove them here - """ - return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ()) - - aug_mass = (mass, mass, onp.array(1.0)) + arg_dicts_to_values( - tree_map(arg_to_identity, args) - ) - - def scan_fun(carry, i): - y_bar, t0_bar, args_bar = carry - # Compute effect of moving measurement time - t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i]) - t0_bar = t0_bar - t_bar - # Run augmented system backwards to previous observation - _, y_bar, t0_bar, args_bar = jax_bdf_integrate( - aug_dynamics, - (ys[i], y_bar, t0_bar, args_bar), - jnp.array([-ts[i], -ts[i - 1]]), - *args, - mass=aug_mass, - rtol=rtol, - atol=atol, - ) - y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar)) - # Add gradient from current output - y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1]) - return (y_bar, t0_bar, args_bar), t_bar - - init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args)) - (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan( - scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1) - ) - ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]]) - return (y_bar, ts_bar, *args_bar) - - -_bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev) - - -@cache() -def closure_convert(fun, in_tree, in_avals): - if config.omnistaging_enabled: - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) - else: - in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - with core.initial_style_staging(): # type: ignore - jaxpr, _, consts = pe.trace_to_jaxpr( - wrapped_fun, in_pvals, instantiate=True, stage_out=False - ) # type: ignore - out_tree = out_tree() - - # We only want to closure convert for constants with respect to which we're - # differentiating. As a proxy for that, we hoist consts with float dtype. - # TODO(mattjj): revise this approach - def is_float(c): - return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact) - - (closure_consts, hoisted_consts), merge = partition_list(is_float, consts) - num_consts = len(hoisted_consts) - - def converted_fun(y, t, *hconsts_args): - hoisted_consts, args = split_list(hconsts_args, [num_consts]) - consts = merge(closure_consts, hoisted_consts) - all_args, _ = tree_flatten((y, t, *args)) - out_flat = core.eval_jaxpr(jaxpr, consts, *all_args) - return tree_unflatten(out_tree, out_flat) - - return converted_fun, hoisted_consts - - -def partition_list(choice, lst): - out = [], [] - which = [out[choice(elt)].append(elt) or choice(elt) for elt in lst] - - def merge(l1, l2): - i1, i2 = iter(l1), iter(l2) - return [next(i2 if snd else i1) for snd in which] - - return out, merge - - -def abstractify(x): - return core.raise_to_shaped(core.get_aval(x)) - - -def ravel_first_arg(f, unravel): - return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped - - -@lu.transformation -def ravel_first_arg_(unravel, y_flat, *args): - y = unravel(y_flat) - ans = yield (y,) + args, {} - ans_flat, _ = ravel_pytree(ans) - yield ans_flat diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index e0124294e9..a0d3008207 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -1,12 +1,14 @@ # # Solver class using Scipy's adaptive time stepper # +import numpy as onp + import pybamm -import jax -from jax.experimental.ode import odeint -import jax.numpy as jnp -import numpy as onp +if pybamm.have_jax(): + import jax + import jax.numpy as jnp + from jax.experimental.ode import odeint class JaxSolver(pybamm.BaseSolver): @@ -56,6 +58,11 @@ def __init__( extrap_tol=0, extra_options=None, ): + if not pybamm.have_jax(): + raise ModuleNotFoundError( + "Jax is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501 + ) + # note: bdf solver itself calculates consistent initial conditions so can set # root_method to none, allow user to override this behavior super().__init__( diff --git a/pybamm/util.py b/pybamm/util.py index cab5d995b3..a9a15ea1ea 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -4,16 +4,21 @@ # The code in this file is adapted from Pints # (see https://github.com/pints-team/pints) # -import importlib -import numpy as np +import importlib.util +import numbers import os -import timeit import pathlib import pickle -import pybamm -import numbers +import subprocess +import sys +import timeit import warnings from collections import defaultdict +from platform import system + +import numpy as np + +import pybamm def root_dir(): @@ -336,3 +341,21 @@ def get_parameters_filepath(path): return path else: return os.path.join(pybamm.__path__[0], path) + + +def have_jax(): + """Check if jax is installed""" + return importlib.util.find_spec("jax") is not None + + +def install_jax(): + """Install jax, jaxlib""" + jax_version = "jax==0.2.12" + jaxlib_version = "jaxlib==0.1.70" + + if system() == "Windows": + raise NotImplementedError("Jax is not available on Windows") + else: + subprocess.check_call( + [sys.executable, "-m", "pip", "install", jax_version, jaxlib_version] + ) diff --git a/requirements.txt b/requirements.txt index 7753d6ab6f..b23c8f2cf9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -numpy >= 1.16 +numpy >= 1.16 scipy >= 1.3 pandas >= 0.24 anytree >= 2.4.3 @@ -6,11 +6,9 @@ autograd >= 1.2 scikit-fem >= 0.2.0 casadi >= 3.5.0 imageio>=2.9.0 -jax==0.2.12 -jaxlib==0.1.70 jupyter # For example notebooks pybtex -sympy==1.8 +sympy==1.9 # Note: Matplotlib is loaded for debug plots but to ensure pybamm runs # on systems without an attached display it should never be imported # outside of plot() methods. diff --git a/run-tests.py b/run-tests.py index d2be5ffee0..6b220fd3ea 100755 --- a/run-tests.py +++ b/run-tests.py @@ -104,35 +104,20 @@ def run_doc_tests(): sys.exit(ret) -def run_notebook_and_scripts(skip_slow_books=False, executable="python"): +def run_notebook_and_scripts(executable="python"): """ Runs Jupyter notebook tests. Exits if they fail. """ - # Ignore slow books? - ignore_list = [] - if skip_slow_books and os.path.isfile(".slow-books"): - with open(".slow-books", "r") as f: - for line in f.readlines(): - line = line.strip() - if not line or line[:1] == "#": - continue - if not line.startswith("examples/"): - line = "examples/" + line - if not line.endswith(".ipynb"): - line = line + ".ipynb" - if not os.path.isfile(line): - raise Exception("Slow notebook note found: " + line) - ignore_list.append(line) # Scan and run print("Testing notebooks and scripts with executable `" + str(executable) + "`") - if not scan_for_nb_and_scripts("examples", True, executable, ignore_list): + if not scan_for_nb_and_scripts("examples", True, executable): print("\nErrors encountered in notebooks") sys.exit(1) print("\nOK") -def scan_for_nb_and_scripts(root, recursive=True, executable="python", ignore_list=[]): +def scan_for_nb_and_scripts(root, recursive=True, executable="python"): """ Scans for, and tests, all notebooks and scripts in a directory. """ @@ -142,9 +127,6 @@ def scan_for_nb_and_scripts(root, recursive=True, executable="python", ignore_li # Scan path for filename in os.listdir(root): path = os.path.join(root, filename) - if path in ignore_list: - print("Skipping slow book: " + path) - continue # Recurse into subdirectories if recursive and os.path.isdir(path): @@ -354,11 +336,6 @@ def export_notebook(ipath, opath): parser.add_argument( "--examples", action="store_true", - help="Test only the fast Jupyter notebooks and scripts in `examples`.", - ) - parser.add_argument( - "--allexamples", - action="store_true", help="Test all Jupyter notebooks and scripts in `examples`.", ) parser.add_argument( @@ -416,12 +393,9 @@ def export_notebook(ipath, opath): has_run = True run_doc_tests() # Notebook tests - if args.allexamples: - has_run = True - run_notebook_and_scripts(executable=interpreter) elif args.examples: has_run = True - run_notebook_and_scripts(True, interpreter) + run_notebook_and_scripts(interpreter) if args.debook: has_run = True export_notebook(*args.debook) diff --git a/setup.py b/setup.py index 90ba21485b..55fd14a2bc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import logging import subprocess from pathlib import Path -from platform import system, version +from platform import system import wheel.bdist_wheel as orig import site import shutil @@ -162,11 +162,6 @@ def compile_KLU(): idaklu_ext = Extension("pybamm.solvers.idaklu", ["pybamm/solvers/c_solvers/idaklu.cpp"]) ext_modules = [idaklu_ext] if compile_KLU() else [] -jax_dependencies = [] -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): - jax_dependencies = ["jax==0.2.12", "jaxlib==0.1.70"] - - # Load text for description and license with open("README.md", encoding="utf-8") as f: readme = f.read() @@ -198,10 +193,9 @@ def compile_KLU(): "scikit-fem>=0.2.0", "casadi>=3.5.0", "imageio>=2.9.0", - *jax_dependencies, "jupyter", # For example notebooks "pybtex", - "sympy==1.8", + "sympy==1.9", # Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs # on systems without an attached display, it should never be imported # outside of plot() methods. @@ -221,6 +215,7 @@ def compile_KLU(): "pybamm_add_parameter = pybamm.parameters_cli:add_parameter", "pybamm_rm_parameter = pybamm.parameters_cli:remove_parameter", "pybamm_install_odes = pybamm.install_odes:main", + "pybamm_install_jax = pybamm.util:install_jax", ] }, ) diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py index f538965b52..98d596dd00 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py @@ -5,7 +5,6 @@ import tests import numpy as np import unittest -from platform import system class TestMPM(unittest.TestCase): @@ -29,7 +28,7 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if system() != "Windows": + if pybamm.have_jax(): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index a8bb0bd82b..22d4cfa6b4 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -5,7 +5,6 @@ import tests import numpy as np import unittest -from platform import system, version class TestSPM(unittest.TestCase): @@ -71,9 +70,7 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if not ( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) - ): + if pybamm.have_jax(): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py index d1c88be1c1..32093dd957 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py @@ -3,10 +3,8 @@ # import pybamm import tests - import numpy as np import unittest -from platform import system, version class TestSPMe(unittest.TestCase): @@ -79,9 +77,7 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if not ( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) - ): + if pybamm.have_jax(): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py index 17c2ad2d5b..34401aa42a 100644 --- a/tests/unit/test_citations.py +++ b/tests/unit/test_citations.py @@ -3,7 +3,6 @@ # import pybamm import unittest -from platform import system, version class TestCitations(unittest.TestCase): @@ -255,10 +254,7 @@ def test_solver_citations(self): pybamm.IDAKLUSolver() self.assertIn("Hindmarsh2005", citations._papers_to_cite) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_jax_citations(self): citations = pybamm.citations citations._reset() diff --git a/tests/unit/test_experiments/test_experiment.py b/tests/unit/test_experiments/test_experiment.py index db6f6196e2..b91645460e 100644 --- a/tests/unit/test_experiments/test_experiment.py +++ b/tests/unit/test_experiments/test_experiment.py @@ -46,19 +46,44 @@ def test_read_strings(self): self.assertEqual( experiment.operating_conditions[:-3], [ - {"electric": (1, "C"), "time": 1800.0, "period": 20.0}, - {"electric": (0.05, "C"), "time": 1800.0, "period": 20.0}, - {"electric": (-0.5, "C"), "time": 2700.0, "period": 20.0}, - {"electric": (1, "A"), "time": 1800.0, "period": 20.0}, - {"electric": (-0.2, "A"), "time": 2700.0, "period": 60.0}, - {"electric": (1, "W"), "time": 1800.0, "period": 20.0}, - {"electric": (-0.2, "W"), "time": 2700.0, "period": 20.0}, - {"electric": (0, "A"), "time": 600.0, "period": 300.0}, - {"electric": (1, "V"), "time": 20.0, "period": 20.0}, - {"electric": (-1, "C"), "time": None, "period": 20.0}, - {"electric": (4.1, "V"), "time": None, "period": 20.0}, - {"electric": (3, "V"), "time": None, "period": 20.0}, - {"electric": (1 / 3, "C"), "time": 7200.0, "period": 20.0}, + {"electric": (1, "C"), "time": 1800.0, "period": 20.0, "dc_data": None}, + { + "electric": (0.05, "C"), + "time": 1800.0, + "period": 20.0, + "dc_data": None, + }, + { + "electric": (-0.5, "C"), + "time": 2700.0, + "period": 20.0, + "dc_data": None, + }, + {"electric": (1, "A"), "time": 1800.0, "period": 20.0, "dc_data": None}, + { + "electric": (-0.2, "A"), + "time": 2700.0, + "period": 60.0, + "dc_data": None, + }, + {"electric": (1, "W"), "time": 1800.0, "period": 20.0, "dc_data": None}, + { + "electric": (-0.2, "W"), + "time": 2700.0, + "period": 20.0, + "dc_data": None, + }, + {"electric": (0, "A"), "time": 600.0, "period": 300.0, "dc_data": None}, + {"electric": (1, "V"), "time": 20.0, "period": 20.0, "dc_data": None}, + {"electric": (-1, "C"), "time": None, "period": 20.0, "dc_data": None}, + {"electric": (4.1, "V"), "time": None, "period": 20.0, "dc_data": None}, + {"electric": (3, "V"), "time": None, "period": 20.0, "dc_data": None}, + { + "electric": (1 / 3, "C"), + "time": 7200.0, + "period": 20.0, + "dc_data": None, + }, ], ) # Calculation for operating conditions of drive cycle @@ -72,19 +97,19 @@ def test_read_strings(self): period_2 = np.min(np.diff(drive_cycle_2[:, 0])) # Check drive cycle operating conditions np.testing.assert_array_equal( - experiment.operating_conditions[-3]["electric"][0], drive_cycle + experiment.operating_conditions[-3]["dc_data"], drive_cycle ) self.assertEqual(experiment.operating_conditions[-3]["electric"][1], "A") self.assertEqual(experiment.operating_conditions[-3]["time"], time_0) self.assertEqual(experiment.operating_conditions[-3]["period"], period_0) np.testing.assert_array_equal( - experiment.operating_conditions[-2]["electric"][0], drive_cycle_1 + experiment.operating_conditions[-2]["dc_data"], drive_cycle_1 ) self.assertEqual(experiment.operating_conditions[-2]["electric"][1], "V") self.assertEqual(experiment.operating_conditions[-2]["time"], time_1) self.assertEqual(experiment.operating_conditions[-2]["period"], period_1) np.testing.assert_array_equal( - experiment.operating_conditions[-1]["electric"][0], drive_cycle_2 + experiment.operating_conditions[-1]["dc_data"], drive_cycle_2 ) self.assertEqual(experiment.operating_conditions[-1]["electric"][1], "W") self.assertEqual(experiment.operating_conditions[-1]["time"], time_2) @@ -128,9 +153,24 @@ def test_read_strings_cccv_combined(self): self.assertEqual( experiment.operating_conditions, [ - {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0}, - {"electric": (-0.5, "C", 1, "V"), "time": None, "period": 60.0}, - {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0}, + { + "electric": (0.05, "C"), + "time": 1800.0, + "period": 60.0, + "dc_data": None, + }, + { + "electric": (-0.5, "C", 1, "V"), + "time": None, + "period": 60.0, + "dc_data": None, + }, + { + "electric": (0.05, "C"), + "time": 1800.0, + "period": 60.0, + "dc_data": None, + }, ], ) self.assertEqual(experiment.events, [None, (0.02, "C"), None]) @@ -146,8 +186,13 @@ def test_read_strings_cccv_combined(self): self.assertEqual( experiment.operating_conditions, [ - {"electric": (-0.5, "C"), "time": None, "period": 60.0}, - {"electric": (1, "V"), "time": None, "period": 60.0}, + { + "electric": (-0.5, "C"), + "time": None, + "period": 60.0, + "dc_data": None, + }, + {"electric": (1, "V"), "time": None, "period": 60.0, "dc_data": None}, ], ) experiment = pybamm.Experiment( @@ -160,8 +205,13 @@ def test_read_strings_cccv_combined(self): self.assertEqual( experiment.operating_conditions, [ - {"electric": (-0.5, "C"), "time": 120.0, "period": 60.0}, - {"electric": (1, "V"), "time": None, "period": 60.0}, + { + "electric": (-0.5, "C"), + "time": 120.0, + "period": 60.0, + "dc_data": None, + }, + {"electric": (1, "V"), "time": None, "period": 60.0, "dc_data": None}, ], ) @@ -173,11 +223,26 @@ def test_read_strings_repeat(self): self.assertEqual( experiment.operating_conditions, [ - {"electric": (0.01, "A"), "time": 1800.0, "period": 60}, - {"electric": (-0.5, "C"), "time": 2700.0, "period": 60}, - {"electric": (1, "V"), "time": 20.0, "period": 60}, - {"electric": (-0.5, "C"), "time": 2700.0, "period": 60}, - {"electric": (1, "V"), "time": 20.0, "period": 60}, + { + "electric": (0.01, "A"), + "time": 1800.0, + "period": 60, + "dc_data": None, + }, + { + "electric": (-0.5, "C"), + "time": 2700.0, + "period": 60, + "dc_data": None, + }, + {"electric": (1, "V"), "time": 20.0, "period": 60, "dc_data": None}, + { + "electric": (-0.5, "C"), + "time": 2700.0, + "period": 60, + "dc_data": None, + }, + {"electric": (1, "V"), "time": 20.0, "period": 60, "dc_data": None}, ], ) self.assertEqual(experiment.period, 60) @@ -193,10 +258,30 @@ def test_cycle_unpacking(self): self.assertEqual( experiment.operating_conditions, [ - {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0}, - {"electric": (-0.2, "C"), "time": 2700.0, "period": 60.0}, - {"electric": (0.05, "C"), "time": 1800.0, "period": 60.0}, - {"electric": (-0.2, "C"), "time": 2700.0, "period": 60.0}, + { + "electric": (0.05, "C"), + "time": 1800.0, + "period": 60.0, + "dc_data": None, + }, + { + "electric": (-0.2, "C"), + "time": 2700.0, + "period": 60.0, + "dc_data": None, + }, + { + "electric": (0.05, "C"), + "time": 1800.0, + "period": 60.0, + "dc_data": None, + }, + { + "electric": (-0.2, "C"), + "time": 2700.0, + "period": 60.0, + "dc_data": None, + }, ], ) self.assertEqual(experiment.cycle_lengths, [2, 1, 1]) diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index 31eb795df9..5b47b672b8 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -158,6 +158,24 @@ def test_run_experiment_cccv_ode(self): ) self.assertEqual(solutions[1].termination, "final time") + def test_run_experiment_drive_cycle(self): + drive_cycle = np.array([np.arange(10), np.arange(10)]).T + experiment = pybamm.Experiment( + [ + ( + "Run drive_cycle (A)", + "Run drive_cycle (V)", + "Run drive_cycle (W)", + ) + ], + drive_cycles={"drive_cycle": drive_cycle} + ) + model = pybamm.lithium_ion.DFN() + sim = pybamm.Simulation(model, experiment=experiment) + self.assertIn(('drive_cycle', 'A'), sim.op_conds_to_model_and_param) + self.assertIn(('drive_cycle', 'V'), sim.op_conds_to_model_and_param) + self.assertIn(('drive_cycle', 'W'), sim.op_conds_to_model_and_param) + def test_run_experiment_old_setup_type(self): experiment = pybamm.Experiment( [ diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 64547d1443..ee7ef06bfb 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -8,7 +8,9 @@ import numpy as np import scipy.sparse from collections import OrderedDict -from platform import system, version + +if pybamm.have_jax(): + import jax def test_function(arg): @@ -457,10 +459,7 @@ def test_evaluator_python(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_find_symbols_jax(self): # test sparse conversion constant_symbols = OrderedDict() @@ -473,10 +472,7 @@ def test_find_symbols_jax(self): list(constant_symbols.values())[0].toarray(), A.entries.toarray() ) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_evaluator_jax(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.StateVector(slice(1, 2)) @@ -638,10 +634,7 @@ def test_evaluator_jax(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_evaluator_jax_jacobian(self): a = pybamm.StateVector(slice(0, 1)) y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])] @@ -656,10 +649,7 @@ def test_evaluator_jax_jacobian(self): result_true = evaluator_jac.evaluate(t=None, y=y) np.testing.assert_allclose(result_test, result_true) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_evaluator_jax_debug(self): a = pybamm.StateVector(slice(0, 1)) expr = a ** 2 @@ -667,10 +657,7 @@ def test_evaluator_jax_debug(self): evaluator = pybamm.EvaluatorJax(expr) evaluator.debug(y=y_test) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_evaluator_jax_inputs(self): a = pybamm.InputParameter("a") expr = a ** 2 @@ -678,13 +665,8 @@ def test_evaluator_jax_inputs(self): result = evaluator.evaluate(inputs={"a": 2}) self.assertEqual(result, 4) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_jax_coo_matrix(self): - import jax - A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) Adense = jax.numpy.array([[1.0, 0], [0, 2.0]]) v = jax.numpy.array([[2.0], [1.0]]) diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 7967c96949..bd29be3f81 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -296,6 +296,21 @@ def test_timescale_input_fail(self): with self.assertRaisesRegex(pybamm.SolverError, "The model timescale"): sol = solver.step(old_solution=sol, model=model, dt=1.0, inputs={"a": 20}) + def test_inputs_step(self): + # Make sure interpolant inputs are dropped + model = pybamm.BaseModel() + v = pybamm.Variable("v") + model.rhs = {v: -1} + model.initial_conditions = {v: 1} + x = np.array([0, 1]) + interp = pybamm.Interpolant(x, x, pybamm.t) + solver = pybamm.CasadiSolver() + for input_key in ["Current input [A]", "Voltage input [V]", "Power input [W]"]: + sol = solver.step( + old_solution=None, model=model, dt=1.0, inputs={input_key: interp} + ) + self.assertFalse(input_key in sol.all_inputs[0]) + def test_extrapolation_warnings(self): # Make sure the extrapolation warnings work model = pybamm.BaseModel() @@ -327,29 +342,26 @@ def test_extrapolation_warnings(self): @unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed") def test_sensitivities(self): - def exact_diff_a(y, a, b): - return np.array([ - [y[0]**2 + 2 * a], - [y[0]] - ]) + return np.array([[y[0] ** 2 + 2 * a], [y[0]]]) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def exact_diff_b(y, a, b): return np.array([[y[0]], [0]]) - for convert_to_format in ['', 'python', 'casadi', 'jax']: + for convert_to_format in ["", "python", "casadi", "jax"]: model = pybamm.BaseModel() v = pybamm.Variable("v") u = pybamm.Variable("u") a = pybamm.InputParameter("a") b = pybamm.InputParameter("b") - model.rhs = {v: a * v**2 + b * v + a**2} + model.rhs = {v: a * v ** 2 + b * v + a ** 2} model.algebraic = {u: a * v - u} model.initial_conditions = {v: 1, u: a * 1} model.convert_to_format = convert_to_format - solver = pybamm.IDAKLUSolver(root_method='lm') - model.calculate_sensitivities = ['a', 'b'] - solver.set_up(model, inputs={'a': 0, 'b': 0}) + solver = pybamm.IDAKLUSolver(root_method="lm") + model.calculate_sensitivities = ["a", "b"] + solver.set_up(model, inputs={"a": 0, "b": 0}) all_inputs = [] for v_value in [0.1, -0.2, 1.5, 8.4]: for u_value in [0.13, -0.23, 1.3, 13.4]: @@ -357,24 +369,20 @@ def exact_diff_b(y, a, b): for b_value in [0.82, 1.9]: y = np.array([v_value, u_value]) t = 0 - inputs = {'a': a_value, 'b': b_value} + inputs = {"a": a_value, "b": b_value} all_inputs.append((t, y, inputs)) for t, y, inputs in all_inputs: - if model.convert_to_format == 'casadi': + if model.convert_to_format == "casadi": use_inputs = casadi.vertcat(*[x for x in inputs.values()]) else: use_inputs = inputs - sens = model.sensitivities_eval( - t, y, use_inputs - ) + sens = model.sensitivities_eval(t, y, use_inputs) np.testing.assert_allclose( - sens['a'], - exact_diff_a(y, inputs['a'], inputs['b']) + sens["a"], exact_diff_a(y, inputs["a"], inputs["b"]) ) np.testing.assert_allclose( - sens['b'], - exact_diff_b(y, inputs['a'], inputs['b']) + sens["b"], exact_diff_b(y, inputs["a"], inputs["b"]) ) diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index ca66e0c9c5..d17d07c63d 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -46,6 +46,7 @@ def test_ida_roberts_klu(self): true_solution = 0.1 * solution.t np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) + @unittest.skipIf(not pybamm.have_jax(), "jax is not installed") def test_ida_roberts_klu_sensitivities(self): # this test implements a python version of the ida Roberts # example provided in sundials diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 772bc937d0..b6e7ab92f1 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -4,16 +4,12 @@ import sys import time import numpy as np -from platform import system, version -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): +if pybamm.have_jax(): import jax -@unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", -) +@unittest.skipIf(not pybamm.have_jax(), "jax is not installed") class TestJaxBDFSolver(unittest.TestCase): def test_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 74dccdaf99..e9956e4295 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -4,16 +4,12 @@ import sys import time import numpy as np -from platform import system, version -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): +if pybamm.have_jax(): import jax -@unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", -) +@unittest.skipIf(not pybamm.have_jax(), "jax is not installed") class TestJaxSolver(unittest.TestCase): def test_model_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index b525d95eac..ceb4fcbab0 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -6,15 +6,12 @@ from tests import get_mesh_for_testing, get_discretisation_for_testing import warnings import sys -from platform import system, version class TestScipySolver(unittest.TestCase): def test_model_solver_python_and_jax(self): - if not ( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) - ): + if pybamm.have_jax(): formats = ["python", "jax"] else: formats = ["python"] diff --git a/tox.ini b/tox.ini index 1501988e90..45d0084e4a 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,7 @@ envlist = {windows}-{tests,quick,dev},tests,quick,dev [testenv] skipsdist = true -skip_install = flake8: true +skip_install = flake8: true usedevelop = true passenv = !windows-!mac: SUNDIALS_INST whitelist_externals = !windows-!mac: sh @@ -17,13 +17,15 @@ deps = dev,doctests: sphinx>=1.5 dev,doctests: guzzle-sphinx-theme !windows-!mac: scikits.odes - + commands = - tests: python run-tests.py --unit --folder all - quick: python run-tests.py --unit - examples: python run-tests.py --examples - dev-!windows-!mac: sh -c "echo export LD_LIBRARY_PATH={env:LD_LIBRARY_PATH} >> {envbindir}/activate" - doctests: python run-tests.py --doctest + tests-!windows-!mac: sh -c "pybamm_install_jax" # install jax, jaxlib for ubuntu + tests: python run-tests.py --unit --folder all + quick: python run-tests.py --unit + integration: python run-tests.py --unit --folder integration + examples: python run-tests.py --examples + dev-!windows-!mac: sh -c "echo export LD_LIBRARY_PATH={env:LD_LIBRARY_PATH} >> {envbindir}/activate" + doctests: python run-tests.py --doctest [testenv:pybamm-requires] platform = [linux,darwin] @@ -35,7 +37,7 @@ deps = cmake commands = python {toxinidir}/scripts/install_KLU_Sundials.py - - git clone https://github.com/pybind/pybind11.git {toxinidir}/pybind11 + - git clone https://github.com/pybind/pybind11.git {toxinidir}/pybind11 [testenv:flake8] skip_install = true @@ -43,10 +45,11 @@ deps = flake8>=3 commands = python -m flake8 [testenv:coverage] -deps = +deps = coverage scikits.odes -commands = +commands = + !windows-!mac: sh -c "pybamm_install_jax" coverage run run-tests.py --nosub # Some tests make use of multiple processes through # multiprocessing. Coverage data is then generated for each