From 60ba076005f08aa3fae5083ccfc88058e8679708 Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Fri, 29 Mar 2024 13:34:58 -0400 Subject: [PATCH] Remove ODES solver (#3932) * Remove ODES * Remove some additional ODES files * More ODES removals * Update .github/workflows/run_periodic_tests.yml * Change versions and fix comments * Remove some unneeded comments * Change log and docker * Apply suggestions from code review Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> * Docker fix and bumping version in a test * Revert test version * Replace skipped test * Update change log --------- Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> --- .github/codecov.yml | 2 - .github/workflows/run_periodic_tests.yml | 64 -- .github/workflows/test_on_push.yml | 32 +- .gitignore | 3 - CHANGELOG.md | 1 + README.md | 3 +- docs/source/api/solvers/index.rst | 1 - docs/source/api/solvers/scikits_solvers.rst | 8 - .../user_guide/installation/gnu-linux-mac.rst | 56 -- docs/source/user_guide/installation/index.rst | 20 +- .../installation/install-from-docker.rst | 13 - examples/scripts/compare_dae_solver.py | 10 - noxfile.py | 95 +- pybamm/CITATIONS.bib | 13 - pybamm/__init__.py | 2 - pybamm/install_odes.py | 207 ---- pybamm/solvers/scikits_dae_solver.py | 183 ---- pybamm/solvers/scikits_ode_solver.py | 187 ---- pyproject.toml | 7 +- scripts/Dockerfile | 12 +- tests/unit/test_citations.py | 13 - .../unit/test_solvers/test_scikits_solvers.py | 920 ------------------ 22 files changed, 27 insertions(+), 1825 deletions(-) delete mode 100644 docs/source/api/solvers/scikits_solvers.rst delete mode 100644 pybamm/install_odes.py delete mode 100644 pybamm/solvers/scikits_dae_solver.py delete mode 100644 pybamm/solvers/scikits_ode_solver.py delete mode 100644 tests/unit/test_solvers/test_scikits_solvers.py diff --git a/.github/codecov.yml b/.github/codecov.yml index 1f0452076a..e69de29bb2 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -1,2 +0,0 @@ -ignore: - - pybamm/install_odes.py diff --git a/.github/workflows/run_periodic_tests.yml b/.github/workflows/run_periodic_tests.yml index 26758a5ded..4878710ba9 100644 --- a/.github/workflows/run_periodic_tests.yml +++ b/.github/workflows/run_periodic_tests.yml @@ -150,19 +150,12 @@ jobs: - name: Install build-time dependencies & run unit tests for M-series macOS runner shell: bash env: - # Point scikits.odes to the correct SUNDIALS installation - SUNDIALS_INST: $HOME/.local/lib - # Homebrew environment variables HOMEBREW_NO_INSTALL_CLEANUP: 1 NONINTERACTIVE: 1 run: | eval "$(pyenv init -)" pyenv activate pybamm-${{ matrix.python-version }} python -m pip install --upgrade pip nox - # Don't use Homebrew to install SUNDIALS because scikits.odes looks for - # in Homebrew folders instead, which we don't want - brew uninstall sundials --force - pip cache remove scikits.odes python -m nox -s pybamm-requires -- --force python -m nox -s unit @@ -179,60 +172,3 @@ jobs: eval "$(pyenv init -)" pyenv activate pybamm-${{ matrix.python-version }} pyenv uninstall -f $( python --version ) - - test_install_odes: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, macos-12] - python-version: ["3.8", "3.9", "3.10", "3.11"] - # Include macOS M1 runners - include: - - os: macos-14 - python-version: "3.10" - - os: macos-14 - python-version: "3.11" - # scikits.odes is not available on Python 3.12 yet - # See https://github.com/bmcage/odes/issues/162 - # - os: macos-14 - # python-version: "3.12" - fail-fast: false - name: Test pybamm_install_odes (${{ matrix.os }} / Python ${{ matrix.python-version }}) - - steps: - - name: Check out PyBaMM repository - uses: actions/checkout@v4 - - - name: Install Linux system dependencies - if: matrix.os == 'ubuntu-latest' - run: | - sudo apt-get update - sudo apt-get install gfortran gcc libopenblas-dev - - name: Install macOS system dependencies - if: matrix.os == 'macos-12' || matrix.os == 'macos-14' - env: - # Homebrew environment variables - HOMEBREW_NO_INSTALL_CLEANUP: 1 - HOMEBREW_NO_AUTO_UPDATE: 1 - HOMEBREW_NO_COLOR: 1 - # Speed up CI - NONINTERACTIVE: 1 - run: | - brew analytics off - brew install openblas - brew reinstall gcc gfortran - - - name: Set up Python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install PyBaMM - run: python -m pip install -e . - - - name: Test pybamm_install_odes on ${{ matrix.os }} - run: | - python -m pip cache purge - python -m pip install wget cmake - pybamm_install_odes diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index 8175625f96..bbe10d6256 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -38,11 +38,10 @@ jobs: matrix: os: [ubuntu-latest, macos-12, windows-latest] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - # We check coverage on Ubuntu with Python 3.11, so we skip unit tests for it here - # TODO: check coverage with Python 3.12 when [odes] supports it + # We check coverage on Ubuntu with Python 3.12, so we skip unit tests for it here exclude: - os: ubuntu-latest - python-version: "3.11" + python-version: "3.12" # Include macOS M1 runners include: - os: macos-14 @@ -57,7 +56,6 @@ jobs: - name: Check out PyBaMM repository uses: actions/checkout@v4 - # Install and cache apt packages - name: Install Linux system dependencies uses: awalsh128/cache-apt-pkgs-action@v1.4.2 if: matrix.os == 'ubuntu-latest' @@ -76,7 +74,6 @@ jobs: - name: Install macOS system dependencies if: matrix.os == 'macos-12' || matrix.os == 'macos-14' env: - # Homebrew environment variables HOMEBREW_NO_INSTALL_CLEANUP: 1 HOMEBREW_NO_AUTO_UPDATE: 1 HOMEBREW_NO_COLOR: 1 @@ -123,20 +120,17 @@ jobs: - name: Run unit tests for ${{ matrix.os }} with Python ${{ matrix.python-version }} run: python -m nox -s unit - # Runs only on Ubuntu with Python 3.11 - # TODO: check coverage with Python 3.12 when [odes] supports it check_coverage: needs: style runs-on: ubuntu-latest strategy: fail-fast: false - name: Coverage tests (ubuntu-latest / Python 3.11) + name: Coverage tests (ubuntu-latest / Python 3.12) steps: - name: Check out PyBaMM repository uses: actions/checkout@v4 - # Install and cache apt packages - name: Install Linux system dependencies uses: awalsh128/cache-apt-pkgs-action@v1.4.2 with: @@ -150,11 +144,11 @@ jobs: sudo dot -c sudo apt-get install libopenblas-dev texlive-latex-extra dvipng - - name: Set up Python 3.11 + - name: Set up Python 3.12 id: setup-python uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12 cache: 'pip' - name: Install nox @@ -206,7 +200,6 @@ jobs: - name: Check out PyBaMM repository uses: actions/checkout@v4 - # Install and cache apt packages - name: Install Linux system dependencies uses: awalsh128/cache-apt-pkgs-action@v1.4.2 if: matrix.os == 'ubuntu-latest' @@ -225,7 +218,6 @@ jobs: - name: Install macOS system dependencies if: matrix.os == 'macos-12' || matrix.os == 'macos-14' env: - # Homebrew environment variables HOMEBREW_NO_INSTALL_CLEANUP: 1 HOMEBREW_NO_AUTO_UPDATE: 1 HOMEBREW_NO_COLOR: 1 @@ -272,8 +264,7 @@ jobs: - name: Run integration tests for ${{ matrix.os }} with Python ${{ matrix.python-version }} run: python -m nox -s integration -# Runs only on Ubuntu with Python 3.12. Skips IDAKLU module compilation -# for speedups, which is already tested in other jobs. + # Skips IDAKLU module compilation for speedups, which is already tested in other jobs. run_doctests: needs: style runs-on: ubuntu-latest @@ -287,7 +278,6 @@ jobs: with: fetch-depth: 0 - # Install and cache apt packages - name: Install Linux system dependencies uses: awalsh128/cache-apt-pkgs-action@v1.4.2 with: @@ -301,7 +291,7 @@ jobs: sudo dot -c sudo apt-get install texlive-latex-extra dvipng - - name: Set up Python 3.11 + - name: Set up Python id: setup-python uses: actions/setup-python@v5 with: @@ -311,13 +301,12 @@ jobs: - name: Install nox run: python -m pip install nox - - name: Install docs dependencies and run doctests for GNU/Linux with Python 3.11 + - name: Install docs dependencies and run doctests for GNU/Linux run: python -m nox -s doctests - - name: Check if the documentation can be built for GNU/Linux with Python 3.11 + - name: Check if the documentation can be built for GNU/Linux run: python -m nox -s docs - # Runs only on Ubuntu with Python 3.12 run_example_tests: needs: style runs-on: ubuntu-latest @@ -329,7 +318,6 @@ jobs: - name: Check out PyBaMM repository uses: actions/checkout@v4 - # Install and cache apt packages - name: Install Linux system dependencies uses: awalsh128/cache-apt-pkgs-action@v1.4.2 with: @@ -372,7 +360,6 @@ jobs: - name: Run example notebooks tests for GNU/Linux with Python 3.12 run: python -m nox -s examples - # Runs only on Ubuntu with Python 3.12 run_scripts_tests: needs: style runs-on: ubuntu-latest @@ -384,7 +371,6 @@ jobs: - name: Check out PyBaMM repository uses: actions/checkout@v4 - # Install and cache apt packages - name: Install Linux system dependencies uses: awalsh128/cache-apt-pkgs-action@v1.4.2 with: diff --git a/.gitignore b/.gitignore index 46c7e02b9f..7e28f2da7d 100644 --- a/.gitignore +++ b/.gitignore @@ -107,9 +107,6 @@ KLU_module_deps # setup setup.log -# odes setup -scikits_odes_setup.log - # test test.c test.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 68f9caa6a5..fda5a18756 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ ## Breaking changes +- The ODES solver was removed due to compatability issues. Users should use IDAKULU, Casadi, or JAX instead. ([#3932](https://github.com/pybamm-team/PyBaMM/pull/3932)) - Integrated the `[pandas]` extra into the core PyBaMM package, deprecating the `pybamm[pandas]` optional dependency. Pandas is now a required dependency and will be installed upon installing PyBaMM ([#3892](https://github.com/pybamm-team/PyBaMM/pull/3892)) - Renamed "have_optional_dependency" to "import_optional_dependency" ([#3866](https://github.com/pybamm-team/PyBaMM/pull/3866)) - Integrated the `[latexify]` extra into the core PyBaMM package, deprecating the `pybamm[latexify]` set of optional dependencies. SymPy is now a required dependency and will be installed upon installing PyBaMM ([#3848](https://github.com/pybamm-team/PyBaMM/pull/3848)) diff --git a/README.md b/README.md index fbb0914de7..6e1c2bac21 100644 --- a/README.md +++ b/README.md @@ -128,9 +128,8 @@ conda install -c conda-forge pybamm ### Optional solvers -Following GNU/Linux and macOS solvers are optionally available: +The following solvers are optionally available: -- [scikits.odes](https://scikits-odes.readthedocs.io/en/latest/)-based solver, see [the documentation](https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-scikits-odes-solver). - [jax](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)-based solver, see [the documentation](https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver). ## 📖 Citing PyBaMM diff --git a/docs/source/api/solvers/index.rst b/docs/source/api/solvers/index.rst index b843ec0ff3..a9aa8ac1dd 100644 --- a/docs/source/api/solvers/index.rst +++ b/docs/source/api/solvers/index.rst @@ -9,7 +9,6 @@ Solvers jax_solver idaklu_solver idaklu_jax - scikits_solvers casadi_solver algebraic_solvers solution diff --git a/docs/source/api/solvers/scikits_solvers.rst b/docs/source/api/solvers/scikits_solvers.rst deleted file mode 100644 index d440793632..0000000000 --- a/docs/source/api/solvers/scikits_solvers.rst +++ /dev/null @@ -1,8 +0,0 @@ -Scikits.odes Solvers -==================== - -.. autoclass:: pybamm.ScikitsOdeSolver - :members: - -.. autoclass:: pybamm.ScikitsDaeSolver - :members: diff --git a/docs/source/user_guide/installation/gnu-linux-mac.rst b/docs/source/user_guide/installation/gnu-linux-mac.rst index 6c2c6f182f..806576938f 100644 --- a/docs/source/user_guide/installation/gnu-linux-mac.rst +++ b/docs/source/user_guide/installation/gnu-linux-mac.rst @@ -89,62 +89,6 @@ installed automatically when you install PyBaMM using ``pip``. For an introduction to virtual environments, see (https://realpython.com/python-virtual-environments-a-primer/). -.. _scikits.odes-label: - -Optional - scikits.odes solver -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Users can install `scikits.odes `__ to utilize its interfaced SUNDIALS ODE and DAE `solvers `__ wrapped in PyBaMM. - -.. note:: - - Currently, only GNU/Linux and macOS are supported. - -.. note:: - - The ``scikits.odes`` solver is not supported on Python 3.12 yet. Please refer to https://github.com/bmcage/odes/issues/162. - There is support for Python 3.8, 3.9, 3.10, and 3.11. - -.. tab:: Debian/Ubuntu - - In a terminal, run the following commands: - - .. code:: bash - - sudo apt-get install libopenblas-dev cmake - pybamm_install_odes - - This will compile and install SUNDIALS for the system (under ``~/.local``), before installing ``scikits.odes``. (Alternatively, one can install SUNDIALS without this script and run ``pip install pybamm[odes]`` to install ``pybamm`` with ``scikits.odes``.) - -.. tab:: macOS - - In a terminal, run the following command: - - .. code:: bash - - brew install openblas gcc gfortran cmake - pybamm_install_odes - -The ``pybamm_install_odes`` command, installed with PyBaMM, automatically downloads and installs the SUNDIALS library on your -system (under ``~/.local``), before installing `scikits.odes `__ . (Alternatively, one can install SUNDIALS without this script and run ``pip install pybamm[odes]`` to install ``pybamm`` with `scikits.odes `__) - -To avoid installation failures when using ``pip install pybamm[odes]``, make sure to set the ``SUNDIALS_INST`` environment variable. If you have installed SUNDIALS using Homebrew, set the variable to the appropriate location. For example: - -.. code:: bash - - export SUNDIALS_INST=$(brew --prefix sundials) - -Ensure that the path matches the installation location on your system. You can verify the installation location by running: - -.. code:: bash - - brew info sundials - -Look for the installation path, and use that path to set the ``SUNDIALS_INST`` variable. - -Note: The location where Homebrew installs SUNDIALS might vary based on the system architecture (ARM or Intel). Adjust the path in the ``export SUNDIALS_INST`` command accordingly. - -To avoid manual setup of path the ``pybamm_install_odes`` is recommended for a smoother installation process, as it takes care of automatically downloading and installing the SUNDIALS library on your system. Optional - JaxSolver ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 1f0e29968b..0686419d4e 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -45,9 +45,8 @@ It can be installed using ``pip`` or ``conda``, or from source. Optional solvers ---------------- -Following GNU/Linux and macOS solvers are optionally available: +The following solvers are optionally available: -* `scikits.odes `_ -based solver, see `Optional - scikits.odes solver `_. * `jax `_ -based solver, see `Optional - JaxSolver `_. Dependencies @@ -204,23 +203,6 @@ Dependency Minimu `jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= -.. _install.odes_dependencies: - -odes dependencies -^^^^^^^^^^^^^^^^^ - -Installable with ``pip install "pybamm[odes]"`` - -======================================================================================================================================= ================== ================== ============================= -Dependency Minimum Version pip extra Notes -======================================================================================================================================= ================== ================== ============================= -`scikits.odes `__ \- odes For scikits ODE & DAE solvers -======================================================================================================================================= ================== ================== ============================= - -.. note:: - - Before running ``pip install "pybamm[odes]"``, make sure to install ``scikits.odes`` build-time requirements as described `here `_ . - Full installation guide ----------------------- diff --git a/docs/source/user_guide/installation/install-from-docker.rst b/docs/source/user_guide/installation/install-from-docker.rst index 61f99817c7..82f75bebf4 100644 --- a/docs/source/user_guide/installation/install-from-docker.rst +++ b/docs/source/user_guide/installation/install-from-docker.rst @@ -26,12 +26,6 @@ Use the following command to pull the PyBaMM Docker image from Docker Hub: docker pull pybamm/pybamm:latest -.. tab:: Scikits.odes solver - - .. code:: bash - - docker pull pybamm/pybamm:odes - .. tab:: JAX solver .. code:: bash @@ -143,18 +137,11 @@ Building Docker images with optional arguments When building the PyBaMM Docker images locally, you have the option to include specific solvers by using optional arguments. These solvers include: - ``IDAKLU``: For IDA solver provided by the SUNDIALS plus KLU. -- ``ODES``: For scikits.odes solver for ODE & DAE problems. - ``JAX``: For Jax solver. - ``ALL``: For all the above solvers. To build the Docker images with optional arguments, you can follow these steps for each solver: -.. tab:: Scikits.odes solver - - .. code-block:: bash - - docker build -t pybamm:odes -f scripts/Dockerfile --build-arg ODES=true . - .. tab:: JAX solver .. code-block:: bash diff --git a/examples/scripts/compare_dae_solver.py b/examples/scripts/compare_dae_solver.py index c98309a2ed..815b458f1a 100644 --- a/examples/scripts/compare_dae_solver.py +++ b/examples/scripts/compare_dae_solver.py @@ -38,16 +38,6 @@ Please consult installation instructions on GitHub. """ ) -if pybamm.have_scikits_odes(): - scikits_sol = pybamm.ScikitsDaeSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval) - solutions.append(scikits_sol) -else: - pybamm.logger.error( - """ - Cannot solve model with Scikits DAE solver as solver is not installed. - Please consult installation instructions on GitHub. - """ - ) # plot plot = pybamm.QuickPlot(solutions) diff --git a/noxfile.py b/noxfile.py index 0fb73606d8..f93c88b954 100644 --- a/noxfile.py +++ b/noxfile.py @@ -17,7 +17,6 @@ PYBAMM_ENV = { "SUNDIALS_INST": f"{homedir}/.local", "LD_LIBRARY_PATH": f"{homedir}/.local/lib", - "PIP_NO_BINARY": "scikits.odes", } VENV_DIR = Path("./venv").resolve() @@ -63,19 +62,7 @@ def run_coverage(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("coverage", silent=False) if sys.platform != "win32": - if sys.version_info > (3, 12): - session.install("-e", ".[all,dev,jax]", silent=False) - else: - session.run_always( - sys.executable, - "-m", - "pip", - "cache", - "remove", - "scikits.odes", - external=True, - ) - session.install("-e", ".[all,dev,jax,odes]", silent=False) + session.install("-e", ".[all,dev,jax]", silent=False) else: if sys.version_info < (3, 9): session.install("-e", ".[all,dev]", silent=False) @@ -89,19 +76,7 @@ def run_integration(session): """Run the integration tests.""" set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": - if sys.version_info > (3, 12): - session.install("-e", ".[all,dev,jax]", silent=False) - else: - session.run_always( - sys.executable, - "-m", - "pip", - "cache", - "remove", - "scikits.odes", - external=True, - ) - session.install("-e", ".[all,dev,jax,odes]", silent=False) + session.install("-e", ".[all,dev,jax]", silent=False) else: if sys.version_info < (3, 9): session.install("-e", ".[all,dev]", silent=False) @@ -122,19 +97,7 @@ def run_unit(session): """Run the unit tests.""" set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": - if sys.version_info > (3, 12): - session.install("-e", ".[all,dev,jax]", silent=False) - else: - session.run_always( - sys.executable, - "-m", - "pip", - "cache", - "remove", - "scikits.odes", - external=True, - ) - session.install("-e", ".[all,dev,jax,odes]", silent=False) + session.install("-e", ".[all,dev,jax]", silent=False) else: if sys.version_info < (3, 9): session.install("-e", ".[all,dev]", silent=False) @@ -176,35 +139,15 @@ def set_dev(session): # is fixed session.run(python, "-m", "pip", "install", "setuptools", external=True) if sys.platform == "linux": - if sys.version_info > (3, 12): - session.run( - python, - "-m", - "pip", - "install", - "-e", - ".[all,dev,jax]", - external=True, - ) - else: - session.run_always( - sys.executable, - "-m", - "pip", - "cache", - "remove", - "scikits.odes", - external=True, - ) - session.run( - python, - "-m", - "pip", - "install", - "-e", - ".[all,dev,jax,odes]", - external=True, - ) + session.run( + python, + "-m", + "pip", + "install", + "-e", + ".[all,dev,jax]", + external=True, + ) else: if sys.version_info < (3, 9): session.run( @@ -233,19 +176,7 @@ def run_tests(session): """Run the unit tests and integration tests sequentially.""" set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": - if sys.version_info > (3, 12): - session.install("-e", ".[all,dev,jax]", silent=False) - else: - session.run_always( - sys.executable, - "-m", - "pip", - "cache", - "remove", - "scikits.odes", - external=True, - ) - session.install("-e", ".[all,dev,jax,odes]", silent=False) + session.install("-e", ".[all,dev,jax]", silent=False) else: if sys.version_info < (3, 9): session.install("-e", ".[all,dev]", silent=False) diff --git a/pybamm/CITATIONS.bib b/pybamm/CITATIONS.bib index 21740584b5..d2a6643f69 100644 --- a/pybamm/CITATIONS.bib +++ b/pybamm/CITATIONS.bib @@ -252,19 +252,6 @@ @article{Lain2019 doi = {10.3390/batteries5040064}, } -@article{Malengier2018, - year = {2018}, - month = {feb}, - publisher = {The Open Journal}, - volume = {3}, - number = {22}, - pages = {165}, - author = {Malengier, Benny and Ki{\v{s}}on, Pavol and Tocknell, James and Abert, Claas and Bruckner, Florian and Bisotti, Marc-Antonio}, - title = {{ODES: a high level interface to ODE and DAE solvers}}, - journal = {The Journal of Open Source Software}, - doi = {10.21105/joss.00165}, -} - @article{Marquis2019, title = {{An asymptotic derivation of a single particle model with electrolyte}}, author = {Marquis, Scott G. and Sulzer, Valentin and Timms, Robert and Please, Colin P. and Chapman, S. Jon}, diff --git a/pybamm/__init__.py b/pybamm/__init__.py index c2654ea9cf..6e48d17e52 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -213,8 +213,6 @@ from .solvers.algebraic_solver import AlgebraicSolver from .solvers.casadi_solver import CasadiSolver from .solvers.casadi_algebraic_solver import CasadiAlgebraicSolver -from .solvers.scikits_dae_solver import ScikitsDaeSolver -from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes from .solvers.scipy_solver import ScipySolver from .solvers.jax_solver import JaxSolver diff --git a/pybamm/install_odes.py b/pybamm/install_odes.py deleted file mode 100644 index a832bdb491..0000000000 --- a/pybamm/install_odes.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -import tarfile -from os.path import join, isfile -import argparse -import sys -import logging -import subprocess -from multiprocessing import cpu_count - -from pybamm.util import root_dir - -if sys.platform == "win32": - raise Exception("pybamm_install_odes is not supported on Windows.") - -SUNDIALS_VERSION = "6.5.0" - -# Build in parallel wherever possible -os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = str(cpu_count()) - -try: - # wget module is required to download SUNDIALS or SuiteSparse. - import wget - - NO_WGET = False -except ModuleNotFoundError: - NO_WGET = True - -# Build in parallel wherever possible -os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = str(cpu_count()) - - -def download_extract_library(url, directory): - # Download and extract archive at url - if NO_WGET: - error_msg = ( - "Could not find wget module." - " Please install wget module (pip install wget)." - ) - raise ModuleNotFoundError(error_msg) - archive = wget.download(url, out=directory) - tar = tarfile.open(archive) - tar.extractall(directory) - - -def install_sundials(download_dir, install_dir): - # Download the SUNDIALS library and compile it. - logger = logging.getLogger("scikits.odes setup") - - try: - subprocess.run(["cmake", "--version"]) - except OSError as error: - raise RuntimeError("CMake must be installed to build SUNDIALS.") from error - - url = f"https://github.com/LLNL/sundials/releases/download/v{SUNDIALS_VERSION}/sundials-{SUNDIALS_VERSION}.tar.gz" - logger.info("Downloading sundials") - download_extract_library(url, download_dir) - - cmake_args = [ - "-DLAPACK_ENABLE=ON", - "-DSUNDIALS_INDEX_SIZE=32", - "-DBUILD_ARKODE:BOOL=OFF", - "-DEXAMPLES_ENABLE:BOOL=OFF", - f"-DCMAKE_INSTALL_PREFIX={install_dir}", - ] - - # SUNDIALS are built within directory 'build_sundials' in the PyBaMM root - # directory - build_directory = os.path.abspath(join(download_dir, "build_sundials")) - if not os.path.exists(build_directory): - print("\n-" * 10, "Creating build dir", "-" * 40) - os.makedirs(build_directory) - - print("-" * 10, "Running CMake prepare", "-" * 40) - subprocess.run( - ["cmake", f"../sundials-{SUNDIALS_VERSION}", *cmake_args], - cwd=build_directory, - check=True, - ) - - print("-" * 10, "Building the sundials", "-" * 40) - make_cmd = ["make", "install"] - subprocess.run(make_cmd, cwd=build_directory, check=True) - - -def update_LD_LIBRARY_PATH(install_dir): - # Look for the current python virtual env and add an export statement - # for LD_LIBRARY_PATH in the activate script. If no virtual env is found, - # the current user's .bashrc file is modified instead. - - export_statement = f"export LD_LIBRARY_PATH={install_dir}/lib:$LD_LIBRARY_PATH" - - home_dir = os.environ.get("HOME") - bashrc_path = os.path.join(home_dir, ".bashrc") - zshrc_path = os.path.join(home_dir, ".zshrc") - venv_path = os.environ.get("VIRTUAL_ENV") - - if venv_path: - script_path = os.path.join(venv_path, "bin/activate") - else: - if os.path.exists(bashrc_path): - script_path = os.path.join(os.environ.get("HOME"), ".bashrc") - elif os.path.exists(zshrc_path): - script_path = os.path.join(os.environ.get("HOME"), ".zshrc") - elif os.path.exists(bashrc_path) and os.path.exists(zshrc_path): - print( - "Both .bashrc and .zshrc found in the home directory. Setting .bashrc as path" - ) - script_path = os.path.join(os.environ.get("HOME"), ".bashrc") - else: - print("Neither .bashrc nor .zshrc found in the home directory.") - - if os.getenv("LD_LIBRARY_PATH") and f"{install_dir}/lib" in os.getenv( - "LD_LIBRARY_PATH" - ): - print(f"{install_dir}/lib was found in LD_LIBRARY_PATH.") - if os.path.exists(bashrc_path): - print("--> Not updating venv activate or .bashrc scripts") - if os.path.exists(zshrc_path): - print("--> Not updating venv activate or .zshrc scripts") - else: - with open(script_path, "a+") as fh: - # Just check that export statement is not already there. - if export_statement not in fh.read(): - fh.write(export_statement) - print( - f"Adding {install_dir}/lib to LD_LIBRARY_PATH" f" in {script_path}" - ) - - -def main(arguments=None): - log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - logger = logging.getLogger("scikits.odes setup") - - # To override the default severity of logging - logger.setLevel("INFO") - - # Use FileHandler() to log to a file - logfile = join(os.path.dirname(os.path.abspath(__file__)), "scikits_odes_setup.log") - print(logfile) - file_handler = logging.FileHandler(logfile) - formatter = logging.Formatter(log_format) - file_handler.setFormatter(formatter) - - # Add the file handler - logger.addHandler(file_handler) - logger.info("Starting scikits.odes setup") - - desc = "Install scikits.odes." - parser = argparse.ArgumentParser(description=desc) - parser.add_argument("--sundials-libs", type=str, help="path to sundials libraries.") - default_install_dir = os.path.join(os.getenv("HOME"), ".local") - parser.add_argument("--install-dir", type=str, default=default_install_dir) - args = parser.parse_args(arguments) - - pybamm_dir = root_dir() - install_dir = ( - args.install_dir - if os.path.isabs(args.install_dir) - else os.path.join(pybamm_dir, args.install_dir) - ) - - # Check if sundials is already installed - SUNDIALS_LIB_DIRS = [join(os.getenv("HOME"), ".local"), "/usr/local", "/usr"] - - if args.sundials_libs: - SUNDIALS_LIB_DIRS.insert(0, args.sundials_libs) - for DIR in SUNDIALS_LIB_DIRS: - logger.info(f"Looking for sundials at {DIR}") - SUNDIALS_FOUND = isfile(join(DIR, "lib", "libsundials_ida.so")) or isfile( - join(DIR, "lib", "libsundials_ida.dylib") - ) - if SUNDIALS_FOUND: - SUNDIALS_LIB_DIR = DIR - logger.info(f"Found sundials at {SUNDIALS_LIB_DIR}") - break - - if not SUNDIALS_FOUND: - logger.info("Could not find sundials libraries.") - logger.info(f"Installing sundials in {install_dir}") - download_dir = os.path.join(pybamm_dir, "sundials") - if not os.path.exists(download_dir): - os.makedirs(download_dir) - install_sundials(download_dir, install_dir) - SUNDIALS_LIB_DIR = install_dir - - update_LD_LIBRARY_PATH(SUNDIALS_LIB_DIR) - - # At the time scikits.odes is pip installed, the path to the sundials - # library must be contained in an env variable SUNDIALS_INST - # see https://scikits-odes.readthedocs.io/en/latest/installation.html#id1 - os.environ["SUNDIALS_INST"] = SUNDIALS_LIB_DIR - env = os.environ.copy() - logger.info("Installing scikits.odes via pip") - logger.info("Purging scikits.odes whels from pip cache if present") - subprocess.run( - [f"{sys.executable}", "-m", "pip", "cache", "remove", "scikits.odes"], - check=True, - ) - subprocess.run( - [f"{sys.executable}", "-m", "pip", "install", "scikits.odes", "--verbose"], - env=env, - check=True, - ) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py deleted file mode 100644 index c942e8ccd7..0000000000 --- a/pybamm/solvers/scikits_dae_solver.py +++ /dev/null @@ -1,183 +0,0 @@ -# -# Solver class using Scipy's adaptive time stepper -# -# mypy: ignore-errors -import casadi -import pybamm - -import numpy as np -import importlib -import scipy.sparse as sparse - -scikits_odes_spec = importlib.util.find_spec("scikits") -if scikits_odes_spec is not None: - scikits_odes_spec = importlib.util.find_spec("scikits.odes") - if scikits_odes_spec is not None: - scikits_odes = importlib.util.module_from_spec(scikits_odes_spec) - scikits_odes_spec.loader.exec_module(scikits_odes) - - -class ScikitsDaeSolver(pybamm.BaseSolver): - """Solve a discretised model, using scikits.odes. - - Parameters - ---------- - method : str, optional - The method to use in solve_ivp (default is "BDF") - rtol : float, optional - The relative tolerance for the solver (default is 1e-6). - atol : float, optional - The absolute tolerance for the solver (default is 1e-6). - root_method : str or pybamm algebraic solver class, optional - The method to use to find initial conditions (for DAE solvers). - If a solver class, must be an algebraic solver class. - If "casadi", - the solver uses casadi's Newton rootfinding algorithm to find initial - conditions. Otherwise, the solver uses 'scipy.optimize.root' with method - specified by 'root_method' (e.g. "lm", "hybr", ...) - root_tol : float, optional - The tolerance for the initial-condition solver (default is 1e-6). - extrap_tol : float, optional - The tolerance to assert whether extrapolation occurs or not (default is 0). - extra_options : dict, optional - Any options to pass to the solver. - Please consult `scikits.odes documentation - `_ for details. - Some common keys: - - - 'max_steps': maximum (int) number of steps the solver can take - """ - - def __init__( - self, - method="ida", - rtol=1e-6, - atol=1e-6, - root_method="casadi", - root_tol=1e-6, - extrap_tol=None, - extra_options=None, - ): - if scikits_odes_spec is None: - raise ImportError("scikits.odes is not installed") - - super().__init__(method, rtol, atol, root_method, root_tol, extrap_tol) - self.name = f"Scikits DAE solver ({method})" - - self.extra_options = extra_options or {} - - pybamm.citations.register("Malengier2018") - pybamm.citations.register("Hindmarsh2000") - pybamm.citations.register("Hindmarsh2005") - - def _integrate(self, model, t_eval, inputs_dict=None): - """ - Solve a model defined by dydt with initial conditions y0. - - Parameters - ---------- - model : :class:`pybamm.BaseModel` - The model whose solution to calculate. - t_eval : numeric type - The times at which to compute the solution - inputs_dict : dict, optional - Any input parameters to pass to the model when solving - - """ - inputs_dict = inputs_dict or {} - if model.convert_to_format == "casadi": - inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) - else: - inputs = inputs_dict - - y0 = model.y0 - if isinstance(y0, casadi.DM): - y0 = y0.full() - y0 = y0.flatten() - - rhs_algebraic_eval = model.rhs_algebraic_eval - events = model.terminate_events_eval - jacobian = model.jac_rhs_algebraic_eval - if model.convert_to_format == "jax": - mass_matrix = model.mass_matrix.entries.toarray() - else: - mass_matrix = model.mass_matrix.entries - - if model.convert_to_format == "casadi": - - def eqsres(t, y, ydot, return_residuals): - return_residuals[:] = ( - rhs_algebraic_eval(t, y, inputs).full().flatten() - - mass_matrix @ ydot - ) - - else: - - def eqsres(t, y, ydot, return_residuals): - return_residuals[:] = ( - rhs_algebraic_eval(t, y, inputs).flatten() - mass_matrix @ ydot - ) - - def rootfn(t, y, ydot, return_root): - return_root[:] = [float(event(t, y, inputs)) for event in events] - - extra_options = { - **self.extra_options, - "old_api": False, - "rtol": self.rtol, - "atol": self.atol, - } - - if jacobian: - jac_y0_t0 = jacobian(t_eval[0], y0, inputs) - if sparse.issparse(jac_y0_t0): - - def jacfn(t, y, ydot, residuals, cj, J): - jac_eval = jacobian(t, y, inputs) - cj * mass_matrix - J[:][:] = jac_eval.toarray() - - else: - - def jacfn(t, y, ydot, residuals, cj, J): - jac_eval = jacobian(t, y, inputs) - cj * mass_matrix - J[:][:] = jac_eval - - extra_options.update({"jacfn": jacfn}) - - if events: - extra_options.update({"rootfn": rootfn, "nr_rootfns": len(events)}) - - # solver works with ydot0 set to zero - ydot0 = np.zeros_like(y0) - - # set up and solve - dae_solver = scikits_odes.dae(self.method, eqsres, **extra_options) - timer = pybamm.Timer() - sol = dae_solver.solve(t_eval, y0, ydot0) - integration_time = timer.time() - - # return solution, we need to tranpose y to match scipy's interface - if sol.flag in [0, 2]: - # 0 = solved for all t_eval - if sol.flag == 0: - termination = "final time" - # 2 = found root(s) - elif sol.flag == 2: - termination = "event" - if sol.roots.t is None: - t_root = None - else: - t_root = sol.roots.t - sol = pybamm.Solution( - sol.values.t, - np.transpose(sol.values.y), - model, - inputs_dict, - t_root, - np.transpose(sol.roots.y), - termination, - ) - sol.integration_time = integration_time - return sol - else: - raise pybamm.SolverError(sol.message) diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py deleted file mode 100644 index f3a4232da9..0000000000 --- a/pybamm/solvers/scikits_ode_solver.py +++ /dev/null @@ -1,187 +0,0 @@ -# -# Solver class using Scipy's adaptive time stepper -# -# mypy: ignore-errors -import casadi -import pybamm - -import numpy as np -import importlib -import scipy.sparse as sparse - -scikits_odes_spec = importlib.util.find_spec("scikits") -if scikits_odes_spec is not None: - scikits_odes_spec = importlib.util.find_spec("scikits.odes") - if scikits_odes_spec is not None: - scikits_odes = importlib.util.module_from_spec(scikits_odes_spec) - scikits_odes_spec.loader.exec_module(scikits_odes) - - -def have_scikits_odes(): - return scikits_odes_spec is not None - - -class ScikitsOdeSolver(pybamm.BaseSolver): - """Solve a discretised model, using scikits.odes. - - Parameters - ---------- - method : str, optional - The method to use in solve_ivp (default is "BDF") - rtol : float, optional - The relative tolerance for the solver (default is 1e-6). - atol : float, optional - The absolute tolerance for the solver (default is 1e-6). - extrap_tol : float, optional - The tolerance to assert whether extrapolation occurs or not (default is 0). - extra_options : dict, optional - Any options to pass to the solver. - Please consult `scikits.odes documentation - `_ for details. - Some common keys: - - - 'linsolver': can be 'dense' (= default), 'lapackdense', 'spgmr', 'spbcgs', \ - 'sptfqmr' - """ - - def __init__( - self, - method="cvode", - rtol=1e-6, - atol=1e-6, - extrap_tol=None, - extra_options=None, - ): - if scikits_odes_spec is None: # pragma: no cover - raise ImportError("scikits.odes is not installed") - - super().__init__(method, rtol, atol, extrap_tol=extrap_tol) - self.extra_options = extra_options or {} - self.ode_solver = True - self.name = f"Scikits ODE solver ({method})" - - pybamm.citations.register("Malengier2018") - pybamm.citations.register("Hindmarsh2000") - pybamm.citations.register("Hindmarsh2005") - - def _integrate(self, model, t_eval, inputs_dict=None): - """ - Solve a model defined by dydt with initial conditions y0. - - Parameters - ---------- - model : :class:`pybamm.BaseModel` - The model whose solution to calculate. - t_eval : numeric type - The times at which to compute the solution - inputs_dict : dict, optional - Any input parameters to pass to the model when solving - - """ - inputs_dict = inputs_dict or {} - if model.convert_to_format == "casadi": - inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) - else: - inputs = inputs_dict - - y0 = model.y0 - if isinstance(y0, casadi.DM): - y0 = y0.full() - y0 = y0.flatten() - - derivs = model.rhs_eval - events = model.terminate_events_eval - jacobian = model.jac_rhs_eval - - if model.convert_to_format == "casadi": - - def eqsydot(t, y, return_ydot): - return_ydot[:] = derivs(t, y, inputs).full().flatten() - - else: - - def eqsydot(t, y, return_ydot): - return_ydot[:] = derivs(t, y, inputs).flatten() - - def rootfn(t, y, return_root): - return_root[:] = [float(event(t, y, inputs)) for event in events] - - if jacobian: - jac_y0_t0 = jacobian(t_eval[0], y0, inputs) - if sparse.issparse(jac_y0_t0): - - def jacfn(t, y, fy, J): - J[:][:] = jacobian(t, y, inputs).toarray() - - def jac_times_vecfn(v, Jv, t, y, userdata): - Jv[:] = userdata._jac_eval * v - return 0 - - else: - - def jacfn(t, y, fy, J): - J[:][:] = jacobian(t, y, inputs) - - def jac_times_vecfn(v, Jv, t, y, userdata): - Jv[:] = np.matmul(userdata._jac_eval, v) - return 0 - - def jac_times_setupfn(t, y, fy, userdata): - userdata._jac_eval = jacobian(t, y, inputs) - return 0 - - extra_options = { - **self.extra_options, - "old_api": False, - "rtol": self.rtol, - "atol": self.atol, - } - - # Read linsolver (defaults to dense) - linsolver = extra_options.get("linsolver", "dense") - - if jacobian: - if linsolver in ("dense", "lapackdense"): - extra_options.update({"jacfn": jacfn}) - elif linsolver in ("spgmr", "spbcgs", "sptfqmr"): - extra_options.update( - { - "jac_times_setupfn": jac_times_setupfn, - "jac_times_vecfn": jac_times_vecfn, - "user_data": self, - } - ) - - if events: - extra_options.update({"rootfn": rootfn, "nr_rootfns": len(events)}) - - ode_solver = scikits_odes.ode(self.method, eqsydot, **extra_options) - timer = pybamm.Timer() - sol = ode_solver.solve(t_eval, y0) - integration_time = timer.time() - - # return solution, we need to tranpose y to match scipy's ivp interface - if sol.flag in [0, 2]: - # 0 = solved for all t_eval - if sol.flag == 0: - termination = "final time" - # 2 = found root(s) - elif sol.flag == 2: - termination = "event" - if sol.roots.t is None: - t_root = None - else: - t_root = sol.roots.t - sol = pybamm.Solution( - sol.values.t, - np.transpose(sol.values.y), - model, - inputs_dict, - t_root, - np.transpose(sol.roots.y), - termination, - ) - sol.integration_time = integration_time - return sol - else: - raise pybamm.SolverError(sol.message) diff --git a/pyproject.toml b/pyproject.toml index f8fd0e33e7..30ed16a17a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,11 +121,7 @@ jax = [ "jax==0.4.20; python_version >= '3.9'", "jaxlib==0.4.20; python_version >= '3.9'", ] -# For the scikits.odes solver -odes = [ - "scikits.odes" -] -# Contains all optional dependencies, except for odes, jax, and dev dependencies +# Contains all optional dependencies, except for jax and dev dependencies all = [ "autograd>=1.6.2", "scikit-fem>=8.1.0", @@ -133,7 +129,6 @@ all = [ ] [project.scripts] -pybamm_install_odes = "pybamm.install_odes:main" pybamm_install_jax = "pybamm.util:install_jax" [project.entry-points."pybamm_parameter_sets"] diff --git a/scripts/Dockerfile b/scripts/Dockerfile index 8def7ced9e..888f5f197d 100644 --- a/scripts/Dockerfile +++ b/scripts/Dockerfile @@ -2,7 +2,6 @@ FROM continuumio/miniconda3:latest WORKDIR / -# Install the necessary dependencies RUN apt-get update && apt-get -y upgrade RUN apt-get install -y libopenblas-dev gcc gfortran graphviz git make g++ build-essential cmake pandoc texlive-latex-extra dvipng RUN rm -rf /var/lib/apt/lists/* @@ -12,7 +11,6 @@ USER pybamm WORKDIR /home/pybamm/ -# Clone project files from Git repository RUN git clone https://github.com/pybamm-team/PyBaMM.git WORKDIR /home/pybamm/PyBaMM @@ -20,7 +18,6 @@ WORKDIR /home/pybamm/PyBaMM ENV CMAKE_C_COMPILER=/usr/bin/gcc ENV CMAKE_CXX_COMPILER=/usr/bin/g++ ENV CMAKE_MAKE_PROGRAM=/usr/bin/make -ENV SUNDIALS_INST=/home/pybamm/.local ENV LD_LIBRARY_PATH=/home/pybamm/.local/lib RUN conda create -n pybamm python=3.11 @@ -29,7 +26,6 @@ SHELL ["conda", "run", "-n", "pybamm", "/bin/bash", "-c"] RUN conda install -y pip ARG IDAKLU -ARG ODES ARG JAX ARG ALL @@ -43,11 +39,6 @@ RUN if [ "$IDAKLU" = "true" ]; then \ pip install --user -e ".[all,dev,docs]"; \ fi -RUN if [ "$ODES" = "true" ]; then \ - python scripts/install_KLU_Sundials.py && \ - pip install --user -e ".[all,dev,docs,odes]"; \ - fi - RUN if [ "$JAX" = "true" ]; then \ pip install --user -e ".[all,dev,docs,jax]"; \ fi @@ -56,11 +47,10 @@ RUN if [ "$ALL" = "true" ]; then \ python scripts/install_KLU_Sundials.py && \ rm -rf pybind11 && \ git clone https://github.com/pybind/pybind11.git && \ - pip install --user -e ".[all,dev,docs,jax,odes]"; \ + pip install --user -e ".[all,dev,docs,jax]"; \ fi RUN if [ -z "$IDAKLU" ] \ - && [ -z "$ODES" ] \ && [ -z "$JAX" ] \ && [ -z "$ALL" ]; then \ pip install --user -e ".[all,dev,docs]"; \ diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py index 61c93d3efe..ba216e62ff 100644 --- a/tests/unit/test_citations.py +++ b/tests/unit/test_citations.py @@ -409,19 +409,6 @@ def test_solver_citations(self): self.assertIn("Virtanen2020", citations._papers_to_cite) self.assertIn("Virtanen2020", citations._citation_tags.keys()) - if pybamm.have_scikits_odes(): - citations._reset() - self.assertNotIn("Malengier2018", citations._papers_to_cite) - pybamm.ScikitsOdeSolver() - self.assertIn("Malengier2018", citations._papers_to_cite) - self.assertIn("Malengier2018", citations._citation_tags.keys()) - - citations._reset() - self.assertNotIn("Malengier2018", citations._papers_to_cite) - pybamm.ScikitsDaeSolver() - self.assertIn("Malengier2018", citations._papers_to_cite) - self.assertIn("Malengier2018", citations._citation_tags.keys()) - if pybamm.have_idaklu(): citations._reset() self.assertNotIn("Hindmarsh2005", citations._papers_to_cite) diff --git a/tests/unit/test_solvers/test_scikits_solvers.py b/tests/unit/test_solvers/test_scikits_solvers.py deleted file mode 100644 index db07678a3d..0000000000 --- a/tests/unit/test_solvers/test_scikits_solvers.py +++ /dev/null @@ -1,920 +0,0 @@ -# -# Tests for the Scikits Solver classes -# -from tests import TestCase -import pybamm -import numpy as np -import unittest -import warnings -from tests import get_mesh_for_testing, get_discretisation_for_testing -import sys - - -@unittest.skipIf(not pybamm.have_scikits_odes(), "scikits.odes not installed") -class TestScikitsSolvers(TestCase): - def test_model_ode_integrate_failure(self): - # Turn off warnings to ignore sqrt error - warnings.simplefilter("ignore") - - model = pybamm.BaseModel() - var = pybamm.Variable("var") - model.rhs = {var: -pybamm.sqrt(var)} - model.initial_conditions = {var: 1} - disc = pybamm.Discretisation() - disc.process_model(model) - - t_eval = np.linspace(0, 3, 100) - solver = pybamm.ScikitsOdeSolver() - # Expect solver to fail when y goes negative - with self.assertRaises(pybamm.SolverError): - solver.solve(model, t_eval) - - # Turn warnings back on - warnings.simplefilter("default") - - def test_model_dae_integrate_failure_bad_ics(self): - # Force model to fail by providing bad ics - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - - # Create custom model so that custom ics - class Model: - mass_matrix = pybamm.Matrix([[1.0, 0.0], [0.0, 0.0]]) - y0 = np.array([0.0, 1.0]) - terminate_events_eval = [] - convert_to_format = "python" - - def rhs_algebraic_eval(self, t, y, inputs): - return np.array([0.5 * np.ones_like(y[0]), 2 * y[0] - y[1]]) - - def jac_rhs_algebraic_eval(self, t, y, inputs): - return np.array([[0.0, 0.0], [2.0, -1.0]]) - - model = Model() - t_eval = np.linspace(0, 1, 100) - - with self.assertRaises(pybamm.SolverError): - solver._integrate(model, t_eval) - - def test_dae_integrate_bad_ics(self): - # Make sure that dae solver can fix bad ics automatically - # Constant - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - - model = pybamm.BaseModel() - var = pybamm.Variable("var") - var2 = pybamm.Variable("var2") - model.rhs = {var: 0.5} - model.algebraic = {var2: 2 * var - var2} - model.initial_conditions = {var: 0, var2: 1} - disc = pybamm.Discretisation() - disc.process_model(model) - - t_eval = np.linspace(0, 1, 100) - solver.set_up(model) - solver._set_initial_conditions(model, 0, {}, True) - # check y0 - np.testing.assert_array_equal(model.y0.full().flatten(), [0, 0]) - # check dae solutions - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(0.5 * solution.t, solution.y[0]) - np.testing.assert_allclose(1.0 * solution.t, solution.y[1]) - - def test_dae_integrate_with_non_unity_mass(self): - # Constant - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - - # Create custom model so that custom mass matrix can be used - class Model: - mass_matrix = pybamm.Matrix([[4.0, 0.0], [0.0, 0.0]]) - y0 = np.array([0.0, 0.0]) - terminate_events_eval = [] - convert_to_format = "python" - len_rhs_and_alg = 2 - - def rhs_algebraic_eval(self, t, y, inputs): - return np.array([0.5 * np.ones_like(y[0]), 2.0 * y[0] - y[1]]) - - def jac_rhs_algebraic_eval(self, t, y, inputs): - return np.array([[0.0, 0.0], [2.0, -1.0]]) - - model = Model() - t_eval = np.linspace(0, 1, 100) - solution = solver._integrate(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(0.125 * solution.t, solution.y[0]) - np.testing.assert_allclose(0.25 * solution.t, solution.y[1]) - - def test_model_solver_ode_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=whole_cell) - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsOdeSolver(rtol=1e-9, atol=1e-9) - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - - def test_model_solver_ode_events_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=whole_cell) - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} - model.events = [ - pybamm.Event("2 * var = 2.5", pybamm.min(2.5 - 2 * var)), - pybamm.Event("var = 1.5", pybamm.min(1.5 - var)), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsOdeSolver(rtol=1e-9, atol=1e-9) - t_eval = np.linspace(0, 10, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[0, :-1], 1.25) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - - def test_model_solver_ode_jacobian_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: var1, var2: 1 - var1} - model.initial_conditions = {var1: 1.0, var2: -1.0} - model.variables = {"var1": var1, "var2": var2} - - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Add user-supplied Jacobian to model - mesh = get_mesh_for_testing() - submesh = mesh[("negative electrode", "separator", "positive electrode")] - N = submesh.npts - - # Solve testing various linear solvers - linsolvers = [ - "dense", - # "lapackdense", - "spgmr", - "spbcgs", - "sptfqmr", - ] - - for linsolver in linsolvers: - solver = pybamm.ScikitsOdeSolver( - rtol=1e-9, atol=1e-9, extra_options={"linsolver": linsolver} - ) - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - - T, Y = solution.t, solution.y - np.testing.assert_array_almost_equal( - model.variables["var1"].evaluate(T, Y), - np.ones((N, T.size)) * np.exp(T[np.newaxis, :]), - ) - np.testing.assert_array_almost_equal( - model.variables["var2"].evaluate(T, Y), - np.ones((N, T.size)) * (T[np.newaxis, :] - np.exp(T[np.newaxis, :])), - ) - - def test_model_solver_dae_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.use_jacobian = False - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") - def test_model_solver_dae_jax(self): - model = pybamm.BaseModel() - model.convert_to_format = "jax" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.use_jacobian = False - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - def test_model_solver_dae_bad_ics_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 3} - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - def test_model_solver_dae_events_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - t_eval = np.linspace(0, 5, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - - def test_model_solver_dae_nonsmooth_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - discontinuity = 0.6 - - def nonsmooth_rate(t): - return 0.1 * (t < discontinuity) + 0.1 - - def nonsmooth_mult(t): - return (t < discontinuity) + 1.0 - - rate = nonsmooth_rate(pybamm.t) - mult = nonsmooth_mult(pybamm.t) - # put in an extra heaviside with no time dependence, this should be ignored by - # the solver i.e. no extra discontinuities added - model.rhs = {var1: rate * var1 + (var1 < 0)} - model.algebraic = {var2: mult * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - pybamm.Event( - "nonsmooth rate", - pybamm.Scalar(discontinuity), - pybamm.EventType.DISCONTINUITY, - ), - pybamm.Event( - "nonsmooth mult", - pybamm.Scalar(discontinuity), - pybamm.EventType.DISCONTINUITY, - ), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - - # create two time series, one without a time point on the discontinuity, - # and one with - t_eval1 = np.linspace(0, 5, 10) - t_eval2 = np.insert( - t_eval1, np.searchsorted(t_eval1, discontinuity), discontinuity - ) - solution1 = solver.solve(model, t_eval1) - solution2 = solver.solve(model, t_eval2) - - # check time vectors - for solution in [solution1, solution2]: - # time vectors are ordered - self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:])) - - # time value before and after discontinuity is an epsilon away - dindex = np.searchsorted(solution.t, discontinuity) - value_before = solution.t[dindex - 1] - value_after = solution.t[dindex] - self.assertEqual(value_before / (1 - sys.float_info.epsilon), discontinuity) - self.assertEqual(value_after / (1 + sys.float_info.epsilon), discontinuity) - - # both solution time vectors should have same number of points - self.assertEqual(len(solution1.t), len(solution2.t)) - - # check solution - for solution in [solution1, solution2]: - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - var1_soln = np.exp(0.2 * solution.t) - y0 = np.exp(0.2 * discontinuity) - var1_soln[solution.t > discontinuity] = y0 * np.exp( - 0.1 * (solution.t[solution.t > discontinuity] - discontinuity) - ) - var2_soln = 2 * var1_soln - var2_soln[solution.t > discontinuity] = var1_soln[ - solution.t > discontinuity - ] - np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06) - np.testing.assert_allclose(solution.y[-1], var2_soln, rtol=1e-06) - - def test_model_solver_dae_multiple_nonsmooth_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - a = 0.6 - discontinuities = (np.arange(3) + 1) * a - - model.rhs = {var1: pybamm.Modulo(pybamm.t, a)} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 0, var2: 0} - model.events = [ - pybamm.Event("var1 = 0.55", pybamm.min(0.55 - var1)), - pybamm.Event("var2 = 1.2", pybamm.min(1.2 - var2)), - ] - for discontinuity in discontinuities: - model.events.append( - pybamm.Event("nonsmooth rate", pybamm.Scalar(discontinuity)) - ) - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - - # create two time series, one without a time point on the discontinuity, - # and one with - t_eval1 = np.linspace(0, 2, 10) - t_eval2 = np.insert( - t_eval1, np.searchsorted(t_eval1, discontinuities), discontinuities - ) - solution1 = solver.solve(model, t_eval1) - solution2 = solver.solve(model, t_eval2) - - # check time vectors - for solution in [solution1, solution2]: - # time vectors are ordered - self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:])) - - # time value before and after discontinuity is an epsilon away - for discontinuity in discontinuities: - dindex = np.searchsorted(solution.t, discontinuity) - value_before = solution.t[dindex - 1] - value_after = solution.t[dindex] - self.assertEqual( - value_before / (1 - sys.float_info.epsilon), discontinuity - ) - self.assertEqual( - value_after / (1 + sys.float_info.epsilon), discontinuity - ) - - # both solution time vectors should have same number of points - self.assertEqual(len(solution1.t), len(solution2.t)) - - # check solution - for solution in [solution1, solution2]: - np.testing.assert_array_less(solution.y[0, :-1], 0.55) - np.testing.assert_array_less(solution.y[-1, :-1], 1.2) - var1_soln = (solution.t % a) ** 2 / 2 + a**2 / 2 * (solution.t // a) - var2_soln = 2 * var1_soln - np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06) - np.testing.assert_allclose(solution.y[-1], var2_soln, rtol=1e-06) - - def test_model_solver_dae_no_nonsmooth_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - discontinuity = 5.6 - - def nonsmooth_rate(t): - return 0.1 * int(t < discontinuity) + 0.1 - - def nonsmooth_mult(t): - return int(t < discontinuity) + 1.0 - - # put in an extra heaviside with no time dependence, this should be ignored by - # the solver i.e. no extra discontinuities added - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-9, atol=1e-9, root_method="lm") - - # create two time series, one without a time point on the discontinuity, - # and one with - t_eval = np.linspace(0, 5, 10) - solution = solver.solve(model, t_eval) - - # test solution, discontinuity should not be triggered - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - def test_model_solver_dae_with_jacobian_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1.0, var2: 2.0} - model.initial_conditions_ydot = {var1: 0.1, var2: 0.2} - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Add user-supplied Jacobian to model - mesh = get_mesh_for_testing() - submesh = mesh[("negative electrode", "separator", "positive electrode")] - N = submesh.npts - - def jacobian(t, y): - return np.block( - [ - [0.1 * np.eye(N), np.zeros((N, N))], - [2.0 * np.eye(N), -1.0 * np.eye(N)], - ] - ) - - model.jacobian = jacobian - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - def test_solve_ode_model_with_dae_solver_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - var = pybamm.Variable("var") - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - - def test_model_step_ode_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=whole_cell) - model.rhs = {var: -0.1 * var} - model.initial_conditions = {var: 1} - disc = get_discretisation_for_testing() - disc.process_model(model) - - solver = pybamm.ScikitsOdeSolver(rtol=1e-9, atol=1e-9) - - # Step once - dt = 1 - step_sol = solver.step(None, model, dt) - np.testing.assert_array_equal(step_sol.t, [0, dt]) - np.testing.assert_allclose(step_sol.y[0], np.exp(-0.1 * step_sol.t)) - - # Step again (return 5 points) - step_sol_2 = solver.step(step_sol, model, dt, npts=5) - np.testing.assert_array_equal( - step_sol_2.t, np.array([0, 1, 1 + 1e-9, 1.25, 1.5, 1.75, 2]) - ) - np.testing.assert_allclose(step_sol_2.y[0], np.exp(-0.1 * step_sol_2.t)) - - # Check steps give same solution as solve - t_eval = step_sol.t - solution = solver.solve(model, t_eval) - np.testing.assert_allclose(solution.y[0], step_sol.y[0], atol=1e-6, rtol=1e-6) - - def test_model_step_dae_python(self): - model = pybamm.BaseModel() - model.convert_to_format = "python" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.use_jacobian = False - disc = get_discretisation_for_testing() - disc.process_model(model) - - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - - # Step once - dt = 1 - step_sol = solver.step(None, model, dt) - np.testing.assert_array_equal(step_sol.t, [0, dt]) - np.testing.assert_allclose(step_sol.y[0, :], np.exp(0.1 * step_sol.t)) - np.testing.assert_allclose(step_sol.y[-1, :], 2 * np.exp(0.1 * step_sol.t)) - - # Step again (return 5 points) - step_sol_2 = solver.step(step_sol, model, dt, npts=5) - np.testing.assert_array_equal( - step_sol_2.t, np.array([0, 1, 1 + 1e-9, 1.25, 1.5, 1.75, 2]) - ) - np.testing.assert_allclose(step_sol_2.y[0, :], np.exp(0.1 * step_sol_2.t)) - np.testing.assert_allclose(step_sol_2.y[-1, :], 2 * np.exp(0.1 * step_sol_2.t)) - - # Check steps give same solution as solve - t_eval = step_sol.t - solution = solver.solve(model, t_eval) - np.testing.assert_allclose(solution.y[0, :], step_sol.y[0, :]) - np.testing.assert_allclose(solution.y[-1, :], step_sol.y[-1, :]) - - def test_model_solver_ode_events_casadi(self): - # Create model - model = pybamm.BaseModel() - model.convert_to_format = "casadi" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=whole_cell) - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} - model.events = [ - pybamm.Event("2 * var = 2.5", pybamm.min(2.5 - 2 * var)), - pybamm.Event("var = 1.5", pybamm.min(1.5 - var)), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsOdeSolver(rtol=1e-9, atol=1e-9) - t_eval = np.linspace(0, 10, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_array_less(solution.y[0:, -1], 1.5) - np.testing.assert_array_less(solution.y[0:, -1], 1.25 + 1e-9) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - - def test_model_solver_dae_events_casadi(self): - # Create model - model = pybamm.BaseModel() - for use_jacobian in [True, False]: - model.use_jacobian = use_jacobian - model.convert_to_format = "casadi" - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - ] - disc = get_discretisation_for_testing() - model_disc = disc.process_model(model, inplace=False) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - t_eval = np.linspace(0, 5, 100) - solution = solver.solve(model_disc, t_eval) - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - def test_model_solver_dae_inputs_events(self): - # Create model - for form in ["python", "casadi"]: - model = pybamm.BaseModel() - model.convert_to_format = form - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: pybamm.InputParameter("rate 1") * var1} - model.algebraic = {var2: pybamm.InputParameter("rate 2") * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - if form == "python": - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm") - else: - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - t_eval = np.linspace(0, 5, 100) - solution = solver.solve(model, t_eval, inputs={"rate 1": 0.1, "rate 2": 2}) - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t)) - - def test_model_solver_dae_inputs_in_initial_conditions(self): - # Create model - model = pybamm.BaseModel() - var1 = pybamm.Variable("var1") - var2 = pybamm.Variable("var2") - model.rhs = {var1: pybamm.InputParameter("rate") * var1} - model.algebraic = {var2: var1 - var2} - model.initial_conditions = { - var1: pybamm.InputParameter("ic 1"), - var2: pybamm.InputParameter("ic 2"), - } - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - t_eval = np.linspace(0, 5, 100) - solution = solver.solve( - model, t_eval, inputs={"rate": -1, "ic 1": 0.1, "ic 2": 2} - ) - np.testing.assert_array_almost_equal( - solution.y[0], 0.1 * np.exp(-solution.t), decimal=5 - ) - np.testing.assert_array_almost_equal( - solution.y[-1], 0.1 * np.exp(-solution.t), decimal=5 - ) - - # Solve again with different initial conditions - solution = solver.solve( - model, t_eval, inputs={"rate": -0.1, "ic 1": 1, "ic 2": 3} - ) - np.testing.assert_array_almost_equal( - solution.y[0], 1 * np.exp(-0.1 * solution.t), decimal=5 - ) - np.testing.assert_array_almost_equal( - solution.y[-1], 1 * np.exp(-0.1 * solution.t), decimal=5 - ) - - def test_solve_ode_model_with_dae_solver_casadi(self): - model = pybamm.BaseModel() - model.convert_to_format = "casadi" - var = pybamm.Variable("var") - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - t_eval = np.linspace(0, 1, 100) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) - - def test_model_step_events(self): - # Create model - model = pybamm.BaseModel() - var1 = pybamm.Variable("var1") - var2 = pybamm.Variable("var2") - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - ] - disc = pybamm.Discretisation() - disc.process_model(model) - - # Solve - step_solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - dt = 0.05 - time = 0 - end_time = 5 - step_solution = None - while time < end_time: - step_solution = step_solver.step(step_solution, model, dt=dt, npts=10) - time += dt - np.testing.assert_array_less(step_solution.y[0, :-1], 1.5) - np.testing.assert_array_less(step_solution.y[-1, :-1], 2.5) - np.testing.assert_equal(step_solution.t_event[0], step_solution.t[-1]) - np.testing.assert_array_equal( - step_solution.y_event[:, 0], step_solution.y[:, -1] - ) - np.testing.assert_array_almost_equal( - step_solution.y[0], np.exp(0.1 * step_solution.t), decimal=5 - ) - np.testing.assert_array_almost_equal( - step_solution.y[-1], 2 * np.exp(0.1 * step_solution.t), decimal=5 - ) - - def test_model_step_nonsmooth_events(self): - # Create model - model = pybamm.BaseModel() - var1 = pybamm.Variable("var1") - var2 = pybamm.Variable("var2") - - a = 0.6 - discontinuities = (np.arange(3) + 1) * a - - model.rhs = {var1: pybamm.Modulo(pybamm.t, a)} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 0, var2: 0} - model.events = [ - pybamm.Event("var1 = 0.55", pybamm.min(0.55 - var1)), - pybamm.Event("var2 = 1.2", pybamm.min(1.2 - var2)), - ] - for discontinuity in discontinuities: - model.events.append( - pybamm.Event("nonsmooth rate", pybamm.Scalar(discontinuity)) - ) - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - step_solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - dt = 0.05 - time = 0 - end_time = 3 - step_solution = None - while time < end_time: - step_solution = step_solver.step(step_solution, model, dt=dt, npts=10) - time += dt - np.testing.assert_array_less(step_solution.y[0, :-1], 0.55) - np.testing.assert_array_less(step_solution.y[-1, :-1], 1.2) - np.testing.assert_equal(step_solution.t_event[0], step_solution.t[-1]) - np.testing.assert_array_equal( - step_solution.y_event[:, 0], step_solution.y[:, -1] - ) - var1_soln = (step_solution.t % a) ** 2 / 2 + a**2 / 2 * (step_solution.t // a) - var2_soln = 2 * var1_soln - np.testing.assert_array_almost_equal(step_solution.y[0], var1_soln, decimal=4) - np.testing.assert_array_almost_equal(step_solution.y[-1], var2_soln, decimal=4) - - def test_model_solver_dae_nonsmooth(self): - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2") - discontinuity = 0.6 - - # Create three different models with the same solution, each expressing the - # discontinuity in a different way - - # first model explicitly adds a discontinuity event - def nonsmooth_rate(t): - return 0.1 * (t < discontinuity) + 0.1 - - rate = pybamm.Function(nonsmooth_rate, pybamm.t) - model1 = pybamm.BaseModel() - model1.rhs = {var1: rate * var1} - model1.algebraic = {var2: var2} - model1.initial_conditions = {var1: 1, var2: 0} - model1.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event( - "nonsmooth rate", - pybamm.Scalar(discontinuity), - pybamm.EventType.DISCONTINUITY, - ), - ] - - # second model implicitly adds a discontinuity event via a heaviside function - model2 = pybamm.BaseModel() - model2.rhs = {var1: (0.1 * (pybamm.t < discontinuity) + 0.1) * var1} - model2.algebraic = {var2: var2} - model2.initial_conditions = {var1: 1, var2: 0} - model2.events = [pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1))] - - # third model implicitly adds a discontinuity event via another heaviside - # function - model3 = pybamm.BaseModel() - model3.rhs = {var1: (-0.1 * (discontinuity < pybamm.t) + 0.2) * var1} - model3.algebraic = {var2: var2} - model3.initial_conditions = {var1: 1, var2: 0} - model3.events = [pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1))] - - for model in [model1, model2, model3]: - disc = get_discretisation_for_testing() - disc.process_model(model) - - # Solve - solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8) - - # create two time series, one without a time point on the discontinuity, - # and one with - t_eval1 = np.linspace(0, 5, 10) - t_eval2 = np.insert( - t_eval1, np.searchsorted(t_eval1, discontinuity), discontinuity - ) - solution1 = solver.solve(model, t_eval1) - solution2 = solver.solve(model, t_eval2) - - # check time vectors - for solution in [solution1, solution2]: - # time vectors are ordered - self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:])) - - # time value before and after discontinuity is an epsilon away - dindex = np.searchsorted(solution.t, discontinuity) - value_before = solution.t[dindex - 1] - value_after = solution.t[dindex] - self.assertEqual( - value_before / (1 - sys.float_info.epsilon), discontinuity - ) - self.assertEqual( - value_after / (1 + sys.float_info.epsilon), discontinuity - ) - - # both solution time vectors should have same number of points - self.assertEqual(len(solution1.t), len(solution2.t)) - - # check solution - for solution in [solution1, solution2]: - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - var1_soln = np.exp(0.2 * solution.t) - y0 = np.exp(0.2 * discontinuity) - var1_soln[solution.t > discontinuity] = y0 * np.exp( - 0.1 * (solution.t[solution.t > discontinuity] - discontinuity) - ) - np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06) - - def test_ode_solver_fail_with_dae(self): - model = pybamm.BaseModel() - a = pybamm.Scalar(1) - model.algebraic = {a: a} - model.concatenated_initial_conditions = a - solver = pybamm.ScikitsOdeSolver() - with self.assertRaisesRegex(pybamm.SolverError, "Cannot use ODE solver"): - solver.set_up(model) - - def test_dae_solver_algebraic_model(self): - model = pybamm.BaseModel() - var = pybamm.Variable("var") - model.algebraic = {var: var + 1} - model.initial_conditions = {var: 0} - - disc = pybamm.Discretisation() - disc.process_model(model) - - solver = pybamm.ScikitsDaeSolver() - t_eval = np.linspace(0, 1) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.y, -1) - - -if __name__ == "__main__": - print("Add -v for more debug output") - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main()