diff --git a/CHANGELOG.md b/CHANGELOG.md index 12c6ad8c84..bd9c751af7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # [Unreleased](https://github.com/pybamm-team/PyBaMM/) +## Features + +- Serialisation added so models can be written to/read from JSON ([#3397](https://github.com/pybamm-team/PyBaMM/pull/3397)) + ## Bug fixes - Fixed a bug where simulations using the CasADi-based solvers would fail randomly with the half-cell model ([#3494](https://github.com/pybamm-team/PyBaMM/pull/3494)) diff --git a/docs/source/api/expression_tree/operations/index.rst b/docs/source/api/expression_tree/operations/index.rst index c084389f1a..67beaca136 100644 --- a/docs/source/api/expression_tree/operations/index.rst +++ b/docs/source/api/expression_tree/operations/index.rst @@ -8,4 +8,5 @@ Classes and functions that operate on the expression tree evaluate jacobian convert_to_casadi + serialise unpack_symbol diff --git a/docs/source/api/expression_tree/operations/serialise.rst b/docs/source/api/expression_tree/operations/serialise.rst new file mode 100644 index 0000000000..daa1b652f1 --- /dev/null +++ b/docs/source/api/expression_tree/operations/serialise.rst @@ -0,0 +1,5 @@ +Serialise +========= + +.. autoclass:: pybamm.expression_tree.operations.serialise.Serialise + :members: diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 7c17cfc4aa..4afaa6eeeb 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -63,6 +63,7 @@ The notebooks are organised into subfolders, and can be viewed in the galleries notebooks/models/MSMR.ipynb notebooks/models/pouch-cell-model.ipynb notebooks/models/rate-capability.ipynb + notebooks/models/saving_models.ipynb notebooks/models/SEI-on-cracks.ipynb notebooks/models/simulating-ORegan-2022-parameter-set.ipynb notebooks/models/SPM.ipynb diff --git a/docs/source/examples/notebooks/models/saving_models.ipynb b/docs/source/examples/notebooks/models/saving_models.ipynb new file mode 100644 index 0000000000..91a6f2ae5c --- /dev/null +++ b/docs/source/examples/notebooks/models/saving_models.ipynb @@ -0,0 +1,376 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Saving PyBaMM models to file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Models which are discretised (i.e. ready to solve/ previously solved, see [this notebook](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/spatial_methods/finite-volumes.ipynb) for more information on the pybamm.Discretisation class) can be serialised and saved to a JSON file, ready to be read in again either in PyBaMM, or a different modelling library. \n", + "\n", + "In the example below, we build a basic DFN model, and then save the model out to `sim_model_example.json`, which should have appear in the 'models' directory." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install pybamm -q # install PyBaMM if it is not installed\n", + "import pybamm\n", + "\n", + "# do the example\n", + "dfn_model = pybamm.lithium_ion.DFN()\n", + "dfn_sim = pybamm.Simulation(dfn_model)\n", + "# discretise and build the model\n", + "dfn_sim.build()\n", + "\n", + "dfn_sim.save_model(\"sim_model_example\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This model file can then be read in and solved by choosing a solver, and running as below." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Recreate the pybamm model from the JSON file\n", + "new_dfn_model = pybamm.load_model(\"sim_model_example.json\")\n", + "\n", + "sim_reloaded = pybamm.Simulation(new_dfn_model)\n", + "sim_reloaded.solve([0, 3600])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It would be nice to plot the results of the two models, to confirm that they are producing the same result.\n", + "\n", + "However, notice that running the code below generates an error stating that the model variables were not provided during the reading in of the model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "No variables to plot", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/pipliggins/Documents/repos/pybamm-local/docs/source/examples/notebooks/models/saving_models.ipynb Cell 7\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m plot_sim\u001b[39m.\u001b[39msolve([\u001b[39m0\u001b[39m, \u001b[39m3600\u001b[39m])\n\u001b[1;32m 6\u001b[0m sims\u001b[39m.\u001b[39mappend(plot_sim)\n\u001b[0;32m----> 8\u001b[0m pybamm\u001b[39m.\u001b[39;49mdynamic_plot(sims, time_unit\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mseconds\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/repos/pybamm-local/pybamm/plotting/dynamic_plot.py:20\u001b[0m, in \u001b[0;36mdynamic_plot\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[39mCreates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[39marguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39m The 'QuickPlot' object that was created\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 19\u001b[0m kwargs_for_class \u001b[39m=\u001b[39m {k: v \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m kwargs\u001b[39m.\u001b[39mitems() \u001b[39mif\u001b[39;00m k \u001b[39m!=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m}\n\u001b[0;32m---> 20\u001b[0m plot \u001b[39m=\u001b[39m pybamm\u001b[39m.\u001b[39;49mQuickPlot(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs_for_class)\n\u001b[1;32m 21\u001b[0m plot\u001b[39m.\u001b[39mdynamic_plot(kwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mFalse\u001b[39;00m))\n\u001b[1;32m 22\u001b[0m \u001b[39mreturn\u001b[39;00m plot\n", + "File \u001b[0;32m~/Documents/repos/pybamm-local/pybamm/plotting/quick_plot.py:146\u001b[0m, in \u001b[0;36mQuickPlot.__init__\u001b[0;34m(self, solutions, output_variables, labels, colors, linestyles, shading, figsize, n_rows, time_unit, spatial_unit, variable_limits)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[39m# check variables have been provided after any serialisation\u001b[39;00m\n\u001b[1;32m 145\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39many\u001b[39m(\u001b[39mlen\u001b[39m(m\u001b[39m.\u001b[39mvariables) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m \u001b[39mfor\u001b[39;00m m \u001b[39min\u001b[39;00m models):\n\u001b[0;32m--> 146\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mNo variables to plot\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 148\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows \u001b[39m=\u001b[39m n_rows \u001b[39mor\u001b[39;00m \u001b[39mint\u001b[39m(\n\u001b[1;32m 149\u001b[0m \u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m np\u001b[39m.\u001b[39msqrt(\u001b[39mlen\u001b[39m(output_variables))\n\u001b[1;32m 150\u001b[0m )\n\u001b[1;32m 151\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_cols \u001b[39m=\u001b[39m \u001b[39mint\u001b[39m(np\u001b[39m.\u001b[39mceil(\u001b[39mlen\u001b[39m(output_variables) \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_rows))\n", + "\u001b[0;31mAttributeError\u001b[0m: No variables to plot" + ] + } + ], + "source": [ + "dfn_models = [dfn_model, new_dfn_model]\n", + "sims = []\n", + "for model in dfn_models:\n", + " plot_sim = pybamm.Simulation(model)\n", + " plot_sim.solve([0, 3600])\n", + " sims.append(plot_sim)\n", + "\n", + "pybamm.dynamic_plot(sims, time_unit=\"seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To be able to plot the results from a serialised model, the mesh and model variables need to be saved alongside the model itself.\n", + "\n", + "To do this, set the `variables` option to `True` when saving the model as in the example below; notice how the models will now plot nicely." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "81d8329fab424264bd56c65d53d34f63", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='t', max=3600.0, step=36.0), Output()), _dom_classes=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# using the first simulation, save a new file which includes a list of the model variables\n", + "dfn_sim.save_model(\"sim_model_variables\", variables=True)\n", + "\n", + "# read the model back in\n", + "model_with_vars = pybamm.load_model(\"sim_model_variables.json\")\n", + "\n", + "# plot the pre- and post-serialisation models together to prove they behave the same\n", + "models = [dfn_model, model_with_vars]\n", + "sims = []\n", + "for model in models:\n", + " sim = pybamm.Simulation(model)\n", + " sim.solve([0, 3600])\n", + " sims.append(sim)\n", + "\n", + "pybamm.dynamic_plot(sims, time_unit=\"seconds\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving from Model\n", + "\n", + "Alternatively, the model can be saved directly from the Model class.\n", + "\n", + "Note that at the moment, only models derived from the BaseBatteryModel class can be serialised; those built from scratch using pybamm.BaseModel() are currently unsupported.\n", + "\n", + "First set up the model, as explained in detail for the [SPM](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/models/SPM.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create the model\n", + "spm_model = pybamm.lithium_ion.SPM()\n", + "\n", + "# set up and discretise ready to solve\n", + "geometry = spm_model.default_geometry\n", + "param = spm_model.default_parameter_values\n", + "param.process_model(spm_model)\n", + "param.process_geometry(geometry)\n", + "mesh = pybamm.Mesh(geometry, spm_model.default_submesh_types, spm_model.default_var_pts)\n", + "disc = pybamm.Discretisation(mesh, spm_model.default_spatial_methods)\n", + "disc.process_model(spm_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then save the model. Note that in this case the model variables and the mesh must be provided directly." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# Serialise the spm_model, providing the varaibles and the mesh\n", + "spm_model.save_model(\"example_model\", variables=spm_model.variables, mesh=mesh)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now you can read the model back in, solve and plot." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ce5addf4f59c447e97d2fbee633cb6e0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read back in\n", + "new_spm_model = pybamm.load_model(\"example_model.json\")\n", + "\n", + "# select a solver and solve\n", + "new_spm_solver = new_spm_model.default_solver\n", + "new_spm_solution = new_spm_solver.solve(new_spm_model, [0, 3600])\n", + "\n", + "# plot the solution\n", + "new_spm_solution.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Making edits to a serialised model\n", + "\n", + "As mentioned at the begining of this notebook, only models which have already been discretised can be serialised and readh back in. This means that after serialisation, the model *cannot be edited*, as the non-discretised elements of the model such as the original rhs are not saved.\n", + "\n", + "If you are likely to want to save a model and then edit it in the future, you may wish to use the `Simulation.save()` functionality to pickle your simulation, as described in [tutorial 6](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/getting_started/tutorial-6-managing-simulation-outputs.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before finishing we will remove the data files we saved so that we leave the directory as we found it" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.remove(\"example_model.json\")\n", + "os.remove(\"sim_model_example.json\")\n", + "os.remove(\"sim_model_variables.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "The relevant papers for this notebook are:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1] Joel A. E. Andersson, Joris Gillis, Greg Horn, James B. Rawlings, and Moritz Diehl. CasADi – A software framework for nonlinear optimization and optimal control. Mathematical Programming Computation, 11(1):1–36, 2019. doi:10.1007/s12532-018-0139-4.\n", + "[2] Marc Doyle, Thomas F. Fuller, and John Newman. Modeling of galvanostatic charge and discharge of the lithium/polymer/insertion cell. Journal of the Electrochemical society, 140(6):1526–1533, 1993. doi:10.1149/1.2221597.\n", + "[3] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n", + "[4] Scott G. Marquis, Valentin Sulzer, Robert Timms, Colin P. Please, and S. Jon Chapman. An asymptotic derivation of a single particle model with electrolyte. Journal of The Electrochemical Society, 166(15):A3693–A3706, 2019. doi:10.1149/2.0341915jes.\n", + "[5] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. doi:10.5334/jors.309.\n", + "\n" + ] + } + ], + "source": [ + "pybamm.print_citations()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 07d8a1c0ea..cab7914cd9 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -188,6 +188,11 @@ UserSupplied2DSubMesh, ) +# +# Serialisation +# +from .models.base_model import load_model + # # Spatial Methods # diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 2736886d95..90e02e0236 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -57,6 +57,30 @@ def __init__( name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + if isinstance(snippet["entries"], dict): + matrix = csr_matrix( + ( + snippet["entries"]["data"], + snippet["entries"]["row_indices"], + snippet["entries"]["column_pointers"], + ), + shape=snippet["entries"]["shape"], + ) + else: + matrix = snippet["entries"] + + instance.__init__( + matrix, + name=snippet["name"], + domains=snippet["domains"], + ) + + return instance + @property def entries(self): return self._entries @@ -129,6 +153,30 @@ def to_equation(self): entries_list = self.entries.tolist() return sympy.Array(entries_list) + def to_json(self): + """ + Method to serialise an Array object into JSON. + """ + + if isinstance(self.entries, np.ndarray): + matrix = self.entries.tolist() + elif isinstance(self.entries, csr_matrix): + matrix = { + "shape": self.entries.shape, + "data": self.entries.data.tolist(), + "row_indices": self.entries.indices.tolist(), + "column_pointers": self.entries.indptr.tolist(), + } + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "entries": matrix, + } + + return json_dict + def linspace(start, stop, num=50, **kwargs): """ diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index bfb31596e6..be0aa2f517 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -68,6 +68,23 @@ def __init__(self, name, left, right): self.left = self.children[0] self.right = self.children[1] + @classmethod + def _from_json(cls, snippet: dict): + """Use to instantiate when deserialising; discretisation has + already occured so pre-processing of binaries is not necessary.""" + + instance = cls.__new__(cls) + + super(BinaryOperator, instance).__init__( + snippet["name"], + children=[snippet["children"][0], snippet["children"][1]], + domains=snippet["domains"], + ) + instance.left = instance.children[0] + instance.right = instance.children[1] + + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" # Possibly add brackets for clarity @@ -156,6 +173,15 @@ def to_equation(self): eq2 = child2.to_equation() return self._sympy_operator(eq1, eq2) + def to_json(self): + """ + Method to serialise a BinaryOperator object into JSON. + """ + + json_dict = {"name": self.name, "id": self.id, "domains": self.domains} + + return json_dict + class Power(BinaryOperator): """ diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index d30762ad70..d117341710 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -50,6 +50,17 @@ def _diff(self, variable): # Differentiate the child and broadcast the result in the same way return self._unary_new_copy(self.child.diff(variable)) + def to_json(self): + raise NotImplementedError( + "pybamm.Broadcast: Serialisation is only implemented for discretised models" + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.Broadcast: Please use a discretised model when reading in from JSON" + ) + class PrimaryBroadcast(Broadcast): """ diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 1c82aff122..71d776f03e 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -43,6 +43,17 @@ def __init__(self, *children, name=None, check_domain=True, concat_fun=None): super().__init__(name, children, domains=domains) + @classmethod + def _from_json(cls, *children, name, domains, concat_fun=None): + """Creates a new Concatenation instance from a json object""" + instance = cls.__new__(cls) + + instance.concatenation_function = concat_fun + + super(Concatenation, instance).__init__(name, children, domains=domains) + + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" out = self.name + "(" @@ -183,6 +194,18 @@ def __init__(self, *children): concat_fun=np.concatenate ) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.Concatenation._from_json()`.""" + instance = super()._from_json( + *snippet["children"], + name="numpy_concatenation", + domains=snippet["domains"], + concat_fun=np.concatenate + ) + + return instance + def _concatenation_jac(self, children_jacs): """See :meth:`pybamm.Concatenation.concatenation_jac()`.""" children = self.children @@ -251,6 +274,31 @@ def __init__(self, children, full_mesh, copy_this=None): self._children_slices = copy.copy(copy_this._children_slices) self.secondary_dimensions_npts = copy_this.secondary_dimensions_npts + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.Concatenation._from_json()`.""" + instance = super()._from_json( + *snippet["children"], + name="domain_concatenation", + domains=snippet["domains"] + ) + + def repack_defaultDict(slices): + slices = defaultdict(list, slices) + for domain, sls in slices.items(): + sls = [slice(s["start"], s["stop"], s["step"]) for s in sls] + slices[domain] = sls + return slices + + instance._size = snippet["size"] + instance._slices = repack_defaultDict(snippet["slices"]) + instance._children_slices = [ + repack_defaultDict(s) for s in snippet["children_slices"] + ] + instance.secondary_dimensions_npts = snippet["secondary_dimensions_npts"] + + return instance + def _get_auxiliary_domain_repeats(self, auxiliary_domains): """Helper method to read the 'auxiliary_domain' meshes.""" mesh_pts = 1 @@ -316,6 +364,32 @@ def _concatenation_new_copy(self, children): ) return new_symbol + def to_json(self): + """ + Method to serialise a DomainConcatenation object into JSON. + """ + + def unpack_defaultDict(slices): + slices = dict(slices) + for domain, sls in slices.items(): + sls = [{"start": s.start, "stop": s.stop, "step": s.step} for s in sls] + slices[domain] = sls + return slices + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "slices": unpack_defaultDict(self._slices), + "size": self._size, + "children_slices": [ + unpack_defaultDict(child_slice) for child_slice in self._children_slices + ], + "secondary_dimensions_npts": self.secondary_dimensions_npts, + } + + return json_dict + class SparseStack(Concatenation): """ diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 0c7e98b508..d6767f1aa9 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -5,6 +5,7 @@ import numpy as np from scipy import special +from typing import Callable import pybamm from pybamm.util import have_optional_dependency @@ -211,6 +212,17 @@ def to_equation(self): eq_list.append(eq) return self._sympy_operator(*eq_list) + def to_json(self): + raise NotImplementedError( + "pybamm.Function: Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.Function: Please use a discretised model when reading in from JSON." + ) + def simplified_function(func_class, child): """ @@ -244,6 +256,25 @@ class SpecificFunction(Function): def __init__(self, function, child): super().__init__(function, child) + @classmethod + def _from_json(cls, function: Callable, snippet: dict): + """ + Reconstructs a SpecificFunction instance during deserialisation of a JSON file. + + Parameters + ---------- + function : method + Function to be applied to child + snippet: dict + Contains the child to apply the function to + """ + + instance = cls.__new__(cls) + + super(SpecificFunction, instance).__init__(function, snippet["children"][0]) + + return instance + def _function_new_copy(self, children): """See :meth:`pybamm.Function._function_new_copy()`""" return pybamm.simplify_if_constant(self.__class__(*children)) @@ -255,6 +286,19 @@ def _sympy_operator(self, child): sympy_function = getattr(sympy, class_name) return sympy_function(child) + def to_json(self): + """ + Method to serialise a SpecificFunction object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "function": self.function.__name__, + } + + return json_dict + class Arcsinh(SpecificFunction): """Arcsinh function.""" @@ -262,6 +306,12 @@ class Arcsinh(SpecificFunction): def __init__(self, child): super().__init__(np.arcsinh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.arcsinh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Symbol._function_diff()`.""" return 1 / sqrt(children[0] ** 2 + 1) @@ -283,6 +333,12 @@ class Arctan(SpecificFunction): def __init__(self, child): super().__init__(np.arctan, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.arctan, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return 1 / (children[0] ** 2 + 1) @@ -304,6 +360,12 @@ class Cos(SpecificFunction): def __init__(self, child): super().__init__(np.cos, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.cos, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Symbol._function_diff()`.""" return -sin(children[0]) @@ -320,6 +382,12 @@ class Cosh(SpecificFunction): def __init__(self, child): super().__init__(np.cosh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.cosh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return sinh(children[0]) @@ -336,6 +404,12 @@ class Erf(SpecificFunction): def __init__(self, child): super().__init__(special.erf, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(special.erf, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return 2 / np.sqrt(np.pi) * exp(-children[0] ** 2) @@ -357,6 +431,12 @@ class Exp(SpecificFunction): def __init__(self, child): super().__init__(np.exp, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.exp, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return exp(children[0]) @@ -373,6 +453,12 @@ class Log(SpecificFunction): def __init__(self, child): super().__init__(np.log, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.log, snippet) + return instance + def _function_evaluate(self, evaluated_children): # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): @@ -403,6 +489,12 @@ class Max(SpecificFunction): def __init__(self, child): super().__init__(np.max, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.max, snippet) + return instance + def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" # Max will always return a scalar @@ -423,6 +515,12 @@ class Min(SpecificFunction): def __init__(self, child): super().__init__(np.min, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.min, snippet) + return instance + def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" # Min will always return a scalar @@ -448,6 +546,12 @@ class Sin(SpecificFunction): def __init__(self, child): super().__init__(np.sin, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.sin, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return cos(children[0]) @@ -464,6 +568,12 @@ class Sinh(SpecificFunction): def __init__(self, child): super().__init__(np.sinh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.sinh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return cosh(children[0]) @@ -480,6 +590,12 @@ class Sqrt(SpecificFunction): def __init__(self, child): super().__init__(np.sqrt, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.sqrt, snippet) + return instance + def _function_evaluate(self, evaluated_children): # don't raise RuntimeWarning for NaNs with np.errstate(invalid="ignore"): @@ -501,6 +617,12 @@ class Tanh(SpecificFunction): def __init__(self, child): super().__init__(np.tanh, child) + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.SpecificFunction._from_json()`.""" + instance = super()._from_json(np.tanh, snippet) + return instance + def _function_diff(self, children, idx): """See :meth:`pybamm.Function._function_diff()`.""" return sech(children[0]) ** 2 diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 2f30da9a5e..146751928e 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -33,6 +33,14 @@ def __init__(self, name, domain=None, auxiliary_domains=None, domains=None): name, domain=domain, auxiliary_domains=auxiliary_domains, domains=domains ) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__(snippet["name"], domains=snippet["domains"]) + + return instance + def _evaluate_for_shape(self): """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`""" return pybamm.evaluate_for_shape_using_domain(self.domains) @@ -58,6 +66,14 @@ class Time(IndependentVariable): def __init__(self): super().__init__("time") + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__() + + return instance + def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" return Time() diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 62c08bf0fd..e66a4c8cdc 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -35,6 +35,18 @@ def __init__(self, name, domain=None, expected_size=None): self._expected_size = expected_size super().__init__(name, domain=domain) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__( + snippet["name"], + domain=snippet["domain"], + expected_size=snippet["expected_size"], + ) + + return instance + def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" new_input_parameter = InputParameter( @@ -101,3 +113,17 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None): self._expected_size ) ) + + def to_json(self): + """ + Method to serialise an InputParameter object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domain": self.domain, + "expected_size": self._expected_size, + } + + return json_dict diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index cd0df4d077..20d4e0180b 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -202,6 +202,27 @@ def __init__( self.interpolator = interpolator self.extrapolate = extrapolate + @classmethod + def _from_json(cls, snippet: dict): + """Create an Interpolant object from JSON data""" + instance = cls.__new__(cls) + + if len(snippet["x"]) == 1: + x = [np.array(x) for x in snippet["x"]] + else: + x = tuple(np.array(x) for x in snippet["x"]) + + instance.__init__( + x, + np.array(snippet["y"]), + snippet["children"], + name=snippet["name"], + interpolator=snippet["interpolator"], + extrapolate=snippet["extrapolate"], + ) + + return instance + @property def entries_string(self): return self._entries_string @@ -290,3 +311,19 @@ def _function_evaluate(self, evaluated_children): else: # pragma: no cover raise ValueError("Invalid dimension: {0}".format(self.dimension)) + + def to_json(self): + """ + Method to serialise an Interpolant object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "x": [x_item.tolist() for x_item in self.x], + "y": self.y.tolist(), + "interpolator": self.interpolator, + "extrapolate": self.extrapolate, + } + + return json_dict diff --git a/pybamm/expression_tree/operations/serialise.py b/pybamm/expression_tree/operations/serialise.py new file mode 100644 index 0000000000..c7768217a3 --- /dev/null +++ b/pybamm/expression_tree/operations/serialise.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import pybamm +from datetime import datetime +import json +import importlib +import numpy as np +import re + +from typing import Optional + + +class Serialise: + """ + Converts a discretised model to and from a JSON file. + + """ + + def __init__(self): + pass + + class _SymbolEncoder(json.JSONEncoder): + """Converts PyBaMM symbols into a JSON-serialisable format""" + + def default(self, node: dict): + node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + if isinstance(node, pybamm.Symbol): + node_dict.update(node.to_json()) # this doesn't include children + node_dict["children"] = [] + for c in node.children: + node_dict["children"].append(self.default(c)) + + if hasattr(node, "initial_condition"): # for ExplicitTimeIntegral + node_dict["initial_condition"] = self.default( + node.initial_condition + ) + + return node_dict + + if isinstance(node, pybamm.Event): + node_dict.update(node.to_json()) + node_dict["expression"] = self.default(node._expression) + return node_dict + + node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover + return node_dict # pragma: no cover + + class _MeshEncoder(json.JSONEncoder): + """Converts PyBaMM meshes into a JSON-serialisable format""" + + def default(self, node: dict): + node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + if isinstance(node, pybamm.Mesh): + node_dict.update(node.to_json()) + + node_dict["sub_meshes"] = {} + for k, v in node.items(): + if len(k) == 1 and "ghost cell" not in k[0]: + node_dict["sub_meshes"][k[0]] = self.default(v) + + return node_dict + + if isinstance(node, pybamm.SubMesh): + node_dict.update(node.to_json()) + return node_dict + + node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover + return node_dict # pragma: no cover + + class _Empty: + """A dummy class to aid deserialisation""" + + pass + + class _EmptyDict(dict): + """A dummy dictionary class to aid deserialisation""" + + pass + + def save_model( + self, + model: pybamm.BaseModel, + mesh: Optional[pybamm.Mesh] = None, + variables: Optional[pybamm.FuzzyDict] = None, + filename: Optional[str] = None, + ): + """Saves a discretised model to a JSON file. + + As the model is discretised and ready to solve, only the right hand side, + algebraic and initial condition variables are saved. + + Parameters + ---------- + model : :class:`pybamm.BaseModel` + The discretised model to be saved + mesh : :class:`pybamm.Mesh` (optional) + The mesh the model has been discretised over. Not neccesary to solve + the model when read in, but required to use pybamm's plotting tools. + variables: :class:`pybamm.FuzzyDict` (optional) + The discretised model varaibles. Not necessary to solve a model, but + required to use pybamm's plotting tools. + filename: str (optional) + The desired name of the JSON file. If no name is provided, one will be + created based on the model name, and the current datetime. + """ + if model.is_discretised is False: + raise NotImplementedError( + "PyBaMM can only serialise a discretised, ready-to-solve model." + ) + + model_json = { + "py/object": str(type(model))[8:-2], + "py/id": id(model), + "pybamm_version": pybamm.__version__, + "name": model.name, + "options": model.options, + "bounds": [bound.tolist() for bound in model.bounds], + "concatenated_rhs": self._SymbolEncoder().default(model._concatenated_rhs), + "concatenated_algebraic": self._SymbolEncoder().default( + model._concatenated_algebraic + ), + "concatenated_initial_conditions": self._SymbolEncoder().default( + model._concatenated_initial_conditions + ), + "events": [self._SymbolEncoder().default(event) for event in model.events], + "mass_matrix": self._SymbolEncoder().default(model.mass_matrix), + "mass_matrix_inv": self._SymbolEncoder().default(model.mass_matrix_inv), + } + + if mesh: + model_json["mesh"] = self._MeshEncoder().default(mesh) + + if variables: + if model._geometry: + model_json["geometry"] = self._deconstruct_pybamm_dicts(model._geometry) + model_json["variables"] = { + k: self._SymbolEncoder().default(v) for k, v in dict(variables).items() + } + + if filename is None: + filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M") + + with open(filename + ".json", "w") as f: + json.dump(model_json, f) + + def load_model( + self, filename: str, battery_model: Optional[pybamm.BaseModel] = None + ) -> pybamm.BaseModel: + """ + Loads a discretised, ready to solve model into PyBaMM. + + A new pybamm battery model instance will be created, which can be solved + and the results plotted as usual. + + Currently only available for pybamm models which have previously been written + out using the `save_model()` option. + + Warning: This only loads in discretised models. If you wish to make edits to the + model or initial conditions, a new model will need to be constructed seperately. + + Parameters + ---------- + + filename: str + Path to the JSON file containing the serialised model file + battery_model: :class:`pybamm.BaseModel` (optional) + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will + override any model names within the file. If None, the function will look + for the saved object path, present if the original model came from PyBaMM. + + Returns + ------- + :class:`pybamm.BaseModel` + A PyBaMM model object, of type specified either in the JSON or in + `battery_model`. + """ + + with open(filename, "r") as f: + model_data = json.load(f) + + recon_model_dict = { + "name": model_data["name"], + "options": self._convert_options(model_data["options"]), + "bounds": tuple(np.array(bound) for bound in model_data["bounds"]), + "concatenated_rhs": self._reconstruct_expression_tree( + model_data["concatenated_rhs"] + ), + "concatenated_algebraic": self._reconstruct_expression_tree( + model_data["concatenated_algebraic"] + ), + "concatenated_initial_conditions": self._reconstruct_expression_tree( + model_data["concatenated_initial_conditions"] + ), + "events": [ + self._reconstruct_expression_tree(event) + for event in model_data["events"] + ], + "mass_matrix": self._reconstruct_expression_tree(model_data["mass_matrix"]), + "mass_matrix_inv": self._reconstruct_expression_tree( + model_data["mass_matrix_inv"] + ), + } + + recon_model_dict["geometry"] = ( + self._reconstruct_pybamm_dict(model_data["geometry"]) + if "geometry" in model_data.keys() + else None + ) + + recon_model_dict["mesh"] = ( + self._reconstruct_mesh(model_data["mesh"]) + if "mesh" in model_data.keys() + else None + ) + + recon_model_dict["variables"] = ( + { + k: self._reconstruct_expression_tree(v) + for k, v in model_data["variables"].items() + } + if "variables" in model_data.keys() + else None + ) + + if battery_model: + return battery_model.deserialise(recon_model_dict) + + if "py/object" in model_data.keys(): + model_framework = self._get_pybamm_class(model_data) + return model_framework.deserialise(recon_model_dict) + + raise TypeError( + """ + The PyBaMM battery model to use has not been provided. + """ + ) + + # Helper functions + + def _get_pybamm_class(self, snippet: dict): + """Find a pybamm class to initialise from object path""" + parts = snippet["py/object"].split(".") + module = importlib.import_module(".".join(parts[:-1])) + + class_ = getattr(module, parts[-1]) + + try: + empty_class = self._Empty() + empty_class.__class__ = class_ + except TypeError: + # Mesh objects have a different layouts + empty_class = self._EmptyDict() + empty_class.__class__ = class_ + + return empty_class + + def _deconstruct_pybamm_dicts(self, dct: dict): + """ + Converts dictionaries which contain pybamm classes as keys + into a json serialisable format. + + Dictionary keys present as pybamm objects are given a seperate key + as "symbol_" to store the dictionary required to reconstruct + a symbol, and their seperate key is used in the original dictionary. E.G: + + {'rod': + {SpatialVariable(name='spat_var'): {"min":0.0, "max":2.0} } + } + + converts to + + {'rod': + {'symbol_spat_var': {"min":0.0, "max":2.0} }, + 'spat_var': + {"py/object":pybamm....} + } + + Dictionaries which don't contain pybamm symbols are returned unchanged. + """ + + def nested_convert(obj): + if isinstance(obj, dict): + new_dict = {} + for k, v in obj.items(): + if isinstance(k, pybamm.Symbol): + new_k = self._SymbolEncoder().default(k) + new_dict["symbol_" + new_k["name"]] = new_k + k = new_k["name"] + new_dict[k] = nested_convert(v) + return new_dict + return obj + + try: + _ = json.dumps(dct) + return dict(dct) + except TypeError: # dct must contain pybamm objects + return nested_convert(dct) + + def _reconstruct_symbol(self, dct: dict): + """Reconstruct an individual pybamm Symbol""" + symbol_class = self._get_pybamm_class(dct) + symbol = symbol_class._from_json(dct) + return symbol + + def _reconstruct_expression_tree(self, node: dict): + """ + Loop through an expression tree creating pybamm Symbol classes + + Conducts post-order tree traversal to turn each tree node into a + `pybamm.Symbol` class, starting from leaf nodes without children and + working upwards. + + Parameters + ---------- + node: dict + A node in an expression tree. + """ + if "children" in node: + for i, c in enumerate(node["children"]): + child_obj = self._reconstruct_expression_tree(c) + node["children"][i] = child_obj + elif "expression" in node: + expression_obj = self._reconstruct_expression_tree(node["expression"]) + node["expression"] = expression_obj + + obj = self._reconstruct_symbol(node) + + return obj + + def _reconstruct_mesh(self, node: dict): + """Reconstructs a Mesh object""" + if "sub_meshes" in node: + for k, v in node["sub_meshes"].items(): + sub_mesh = self._reconstruct_symbol(v) + node["sub_meshes"][k] = sub_mesh + + new_mesh = self._reconstruct_symbol(node) + + return new_mesh + + def _reconstruct_pybamm_dict(self, obj: dict): + """ + pybamm.Geometry can contain PyBaMM symbols as dictionary keys. + + Converts + {"rod": + {"symbol_spat_var": + {"min":0.0, "max":2.0} }, + "spat_var": + {"py/object":"pybamm...."} + } + + from an exported JSON file to + + {"rod": + {SpatialVariable(name="spat_var"): {"min":0.0, "max":2.0} } + } + """ + + def recurse(obj): + if isinstance(obj, dict): + new_dict = {} + for k, v in obj.items(): + if "symbol_" in k: + new_dict[k] = self._reconstruct_symbol(v) + elif isinstance(v, dict): + new_dict[k] = recurse(v) + else: + new_dict[k] = v + + pattern = re.compile("symbol_") + symbol_keys = {k: v for k, v in new_dict.items() if pattern.match(k)} + + # rearrange the dictionary to make pybamm objects the dictionary keys + if symbol_keys: + for k, v in symbol_keys.items(): + new_dict[v] = new_dict[k.lstrip("symbol_")] + del new_dict[k] + del new_dict[k.lstrip("symbol_")] + + return new_dict + return obj + + return recurse(obj) + + def _convert_options(self, d): + """ + Converts a dictionary with nested lists to nested tuples, + used to convert model options back into correct format + """ + if isinstance(d, dict): + return {k: self._convert_options(v) for k, v in d.items()} + elif isinstance(d, list): + return tuple(self._convert_options(item) for item in d) + else: + return d diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index eebe77ad2f..ade925d608 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -50,6 +50,17 @@ def to_equation(self): else: return sympy.Symbol(self.name) + def to_json(self): + raise NotImplementedError( + "pybamm.Parameter: Serialisation is only implemented for discretised models" + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.Parameter: Please use a discretised model when reading in from JSON" + ) + class FunctionParameter(pybamm.Symbol): """ @@ -223,3 +234,16 @@ def to_equation(self): return sympy.Symbol(self.print_name) else: return sympy.Symbol(self.name) + + def to_json(self): + raise NotImplementedError( + "pybamm.FunctionParameter:" + "Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.FunctionParameter:" + "Please use a discretised model when reading in from JSON." + ) diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 0209c02a8e..64a3893fc9 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -28,6 +28,14 @@ def __init__(self, value, name=None): super().__init__(name) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__(snippet["value"], name=snippet["name"]) + + return instance + def __str__(self): return str(self.value) @@ -74,3 +82,12 @@ def to_equation(self): return sympy.Symbol(self.print_name) else: return self.value + + def to_json(self): + """ + Method to serialise a Symbol object into JSON. + """ + + json_dict = {"name": self.name, "id": self.id, "value": self.value} + + return json_dict diff --git a/pybamm/expression_tree/state_vector.py b/pybamm/expression_tree/state_vector.py index 6ef8bee904..72b1ed18a5 100644 --- a/pybamm/expression_tree/state_vector.py +++ b/pybamm/expression_tree/state_vector.py @@ -73,6 +73,21 @@ def __init__( domains=domains, ) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + y_slices = [slice(s["start"], s["stop"], s["step"]) for s in snippet["y_slice"]] + + instance.__init__( + *y_slices, + name=snippet["name"], + domains=snippet["domains"], + evaluation_array=snippet["evaluation_array"], + ) + + return instance + @property def y_slices(self): return self._y_slices @@ -194,6 +209,28 @@ def _evaluate_for_shape(self): """ return np.nan * np.ones((self.size, 1)) + def to_json(self): + """ + Method to serialise a StateVector object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + "y_slice": [ + { + "start": y.start, + "stop": y.stop, + "step": y.step, + } # are there ever more than 1? + for y in self.y_slices + ], + "evaluation_array": list(self.evaluation_array), + } + + return json_dict + class StateVector(StateVectorBase): """ diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 8f1608e7ba..8857584385 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -232,6 +232,26 @@ def __init__( ): self.test_shape() + @classmethod + def _from_json(cls, snippet: dict): + """ + Reconstructs a Symbol instance during deserialisation of a JSON file. + + Parameters + ---------- + snippet: dict + Contains the information needed to reconstruct a specific instance. + At minimum, should contain "name", "children" and "domains". + """ + + instance = cls.__new__(cls) + + instance.__init__( + snippet["name"], children=snippet["children"], domains=snippet["domains"] + ) + + return instance + @property def children(self): """ @@ -988,3 +1008,16 @@ def print_name(self, name): def to_equation(self): sympy = have_optional_dependency("sympy") return sympy.Symbol(str(self.name)) + + def to_json(self): + """ + Method to serialise a Symbol object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "domains": self.domains, + } + + return json_dict diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 81c3dc28c2..67f0a85252 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -32,6 +32,21 @@ def __init__(self, name, child, domains=None): super().__init__(name, children=[child], domains=domains) self.child = self.children[0] + @classmethod + def _from_json(cls, snippet: dict): + """Use to instantiate when deserialising""" + + instance = cls.__new__(cls) + + super(UnaryOperator, instance).__init__( + snippet["name"], + snippet["children"], + domains=snippet["domains"], + ) + instance.child = instance.children[0] + + return instance + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return "{}({!s})".format(self.name, self.child) @@ -154,6 +169,10 @@ def __init__(self, child): """See :meth:`pybamm.UnaryOperator.__init__()`.""" super().__init__("sign", child) + @classmethod + def _from_json(cls, snippet: dict): + raise NotImplementedError() + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" return pybamm.Scalar(0) @@ -271,6 +290,25 @@ def __init__(self, child, index, name=None, check_size=True): if isinstance(index, int): self.clear_domains() + @classmethod + def _from_json(cls, snippet: dict): + """See :meth:`pybamm.UnaryOperator._from_json()`.""" + instance = cls.__new__(cls) + + index = slice( + snippet["index"]["start"], + snippet["index"]["stop"], + snippet["index"]["step"], + ) + + instance.__init__( + snippet["children"][0], + index, + name=snippet["name"], + check_size=snippet["check_size"], + ) + return instance + def _unary_jac(self, child_jac): """See :meth:`pybamm.UnaryOperator._unary_jac()`.""" @@ -315,6 +353,24 @@ def _evaluates_on_edges(self, dimension): """See :meth:`pybamm.Symbol._evaluates_on_edges()`.""" return False + def to_json(self): + """ + Method to serialise an Index object into JSON. + """ + + json_dict = { + "name": self.name, + "id": self.id, + "index": { + "start": self.slice.start, + "stop": self.slice.stop, + "step": self.slice.step, + }, + "check_size": False, + } + + return json_dict + class SpatialOperator(UnaryOperator): """ @@ -338,6 +394,19 @@ class with a :class:`Matrix` def __init__(self, name, child, domains=None): super().__init__(name, child, domains) + def to_json(self): + raise NotImplementedError( + "pybamm.SpatialOperator:" + "Serialisation is only implemented for discretised models." + ) + + @classmethod + def _from_json(cls, snippet): + raise NotImplementedError( + "pybamm.SpatialOperator:" + "Please use a discretised model when reading in from JSON." + ) + class Gradient(SpatialOperator): """ @@ -687,7 +756,8 @@ class DefiniteIntegralVector(SpatialOperator): Parameters ---------- variable : :class:`pybamm.Symbol` - The variable whose basis will be integrated over the entire domain + The variable whose basis will be integrated over the entire domain (will + become self.children[0]) vector_type : str, optional Whether to return a row or column vector (default is row) """ @@ -916,12 +986,34 @@ def __init__(self, children, initial_condition): super().__init__("explicit time integral", children) self.initial_condition = initial_condition + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + instance.__init__(snippet["children"][0], snippet["initial_condition"]) + + return instance + def _unary_new_copy(self, child): return self.__class__(child, self.initial_condition) def is_constant(self): return False + def to_json(self): + """ + Convert ExplicitTimeIntegral to json for serialisation. + + Both `children` and `initial_condition` contain Symbols, and are therefore + dealt with by `pybamm.Serialise._SymbolEncoder.default()` directly. + """ + json_dict = { + "name": self.name, + "id": self.id, + } + + return json_dict + class BoundaryGradient(BoundaryOperator): """ diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index 0d1e1fd424..eb36a29604 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -130,6 +130,13 @@ def to_equation(self): else: return self.name + def to_json( + self, + ): + raise NotImplementedError( + "pybamm.Variable: Serialisation is only implemented for discretised models." + ) + class Variable(VariableBase): """ diff --git a/pybamm/meshes/meshes.py b/pybamm/meshes/meshes.py index 4c86290a2f..182282319f 100644 --- a/pybamm/meshes/meshes.py +++ b/pybamm/meshes/meshes.py @@ -120,6 +120,21 @@ def __init__(self, geometry, submesh_types, var_pts): # add ghost meshes self.add_ghost_meshes() + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + super(Mesh, instance).__init__() + + instance.submesh_pts = snippet["submesh_pts"] + instance.base_domains = snippet["base_domains"] + + for k, v in snippet["sub_meshes"].items(): + instance[k] = v + + # instance.add_ghost_meshes() + + return instance + def __getitem__(self, domains): if isinstance(domains, str): domains = (domains,) @@ -216,6 +231,14 @@ def geometry(self): def geometry(self, geometry): self._geometry = geometry + def to_json(self): + json_dict = { + "submesh_pts": self.submesh_pts, + "base_domains": self.base_domains, + } + + return json_dict + class SubMesh: """ diff --git a/pybamm/meshes/one_dimensional_submeshes.py b/pybamm/meshes/one_dimensional_submeshes.py index 2beae6bc3a..d68745daec 100644 --- a/pybamm/meshes/one_dimensional_submeshes.py +++ b/pybamm/meshes/one_dimensional_submeshes.py @@ -34,7 +34,7 @@ def __init__(self, edges, coord_sys, tabs=None): self.internal_boundaries = [] # Add tab locations in terms of "left" and "right" - if tabs: + if tabs and "negative tab" not in tabs.keys(): self.tabs = {} l_z = self.edges[-1] @@ -52,6 +52,9 @@ def near(x, point, tol=3e-16): f"{tab} tab located at {tab_location}, " f"but must be at either 0 or {l_z}" ) + elif tabs: + # tabs have already been calculated by a serialised model + self.tabs = tabs def read_lims(self, lims): # Separate limits and tabs @@ -70,6 +73,17 @@ def read_lims(self, lims): return spatial_var, spatial_lims, tabs + def to_json(self): + json_dict = { + "edges": self.edges.tolist(), + "coord_sys": self.coord_sys, + } + + if hasattr(self, "tabs"): + json_dict["tabs"] = self.tabs + + return json_dict + class Uniform1DSubMesh(SubMesh1D): """ @@ -95,6 +109,18 @@ def __init__(self, lims, npts): super().__init__(edges, coord_sys=coord_sys, tabs=tabs) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + tabs = snippet["tabs"] if "tabs" in snippet.keys() else None + + super(Uniform1DSubMesh, instance).__init__( + np.array(snippet["edges"]), snippet["coord_sys"], tabs=tabs + ) + + return instance + class Exponential1DSubMesh(SubMesh1D): """ diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index 23c024dbbb..8f80d6f5ce 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -35,6 +35,9 @@ def __init__(self, edges, coord_sys, tabs): self.npts = len(self.edges["y"]) * len(self.edges["z"]) self.coord_sys = coord_sys + # save tabs for serialisation + self.tabs = tabs + # create mesh self.fem_mesh = skfem.MeshTri.init_tensor(self.edges["y"], self.edges["z"]) @@ -142,6 +145,15 @@ def between(x, interval, tol=3e-16): else: raise pybamm.GeometryError("tab location not valid") + def to_json(self): + json_dict = { + "edges": {k: v.tolist() for k, v in self.edges.items()}, + "coord_sys": self.coord_sys, + "tabs": self.tabs, + } + + return json_dict + class ScikitUniform2DSubMesh(ScikitSubMesh2D): """ @@ -178,6 +190,18 @@ def __init__(self, lims, npts): super().__init__(edges, coord_sys, tabs) + @classmethod + def _from_json(cls, snippet: dict): + instance = cls.__new__(cls) + + edges = {k: np.array(v) for k, v in snippet["edges"].items()} + + super(ScikitUniform2DSubMesh, instance).__init__( + edges, snippet["coord_sys"], snippet["tabs"] + ) + + return instance + class ScikitExponential2DSubMesh(ScikitSubMesh2D): """ diff --git a/pybamm/meshes/zero_dimensional_submesh.py b/pybamm/meshes/zero_dimensional_submesh.py index 5b2f38e29f..dd4afe70fd 100644 --- a/pybamm/meshes/zero_dimensional_submesh.py +++ b/pybamm/meshes/zero_dimensional_submesh.py @@ -38,6 +38,23 @@ def __init__(self, position, npts=None): self.coord_sys = None self.npts = 1 + @classmethod + def _from_json(cls, snippet): + instance = cls.__new__(cls) + + instance.nodes = np.array(snippet["spatial_position"]) + instance.edges = np.array(snippet["spatial_position"]) + instance.coord_sys = None + instance.npts = 1 + + return instance + def add_ghost_meshes(self): # No ghost meshes to be added to this class pass + + def to_json(self): + json_dict = { + "spatial_position": self.nodes.tolist(), + } + return json_dict diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 08890757b7..ed26a9062a 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -2,6 +2,7 @@ # Base model class # import numbers +import warnings from collections import OrderedDict import copy @@ -9,6 +10,7 @@ import numpy as np import pybamm +from pybamm.expression_tree.operations.serialise import Serialise from pybamm.util import have_optional_dependency @@ -123,6 +125,61 @@ def __init__(self, name="Unnamed model"): self.is_discretised = False self.y_slices = None + @classmethod + def deserialise(cls, properties: dict): + """ + Create a model instance from a serialised object. + """ + instance = cls.__new__(cls) + + # append the model name with _saved to differentiate + instance.__init__(name=properties["name"] + "_saved") + + instance.options = properties["options"] + + # Initialise model with stored variables that have already been discretised + instance._concatenated_rhs = properties["concatenated_rhs"] + instance._concatenated_algebraic = properties["concatenated_algebraic"] + instance._concatenated_initial_conditions = properties[ + "concatenated_initial_conditions" + ] + + instance.len_rhs = instance.concatenated_rhs.size + instance.len_alg = instance.concatenated_algebraic.size + instance.len_rhs_and_alg = instance.len_rhs + instance.len_alg + + instance.bounds = properties["bounds"] + instance.events = properties["events"] + instance.mass_matrix = properties["mass_matrix"] + instance.mass_matrix_inv = properties["mass_matrix_inv"] + + # add optional properties not required for model to solve + if properties["variables"]: + instance._variables = pybamm.FuzzyDict(properties["variables"]) + + # assign meshes to each variable + for var in instance._variables.values(): + if var.domain != []: + var.mesh = properties["mesh"][var.domain] + else: + var.mesh = None + + if var.domains["secondary"] != []: + var.secondary_mesh = properties["mesh"][var.domains["secondary"]] + else: + var.secondary_mesh = None + + if properties["geometry"]: + instance._geometry = pybamm.Geometry(properties["geometry"]) + else: + # Delete the default variables which have not been discretised + instance._variables = pybamm.FuzzyDict({}) + + # Model has already been discretised + instance.is_discretised = True + + return instance + @property def name(self): return self._name @@ -1139,6 +1196,43 @@ def process_parameters_and_discretise(self, symbol, parameter_values, disc): return disc_symbol + def save_model(self, filename=None, mesh=None, variables=None): + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. + """ + if variables and not mesh: + warnings.warn( + """ + Serialisation: Variables are being saved without a mesh. + Plotting may not be available. + """, + pybamm.ModelWarning, + ) + + Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) + + +def load_model(filename, battery_model: BaseModel = None): + """ + Load in a saved model from a JSON file + + Parameters + ---------- + filename: str + Path to the JSON file containing the serialised model file + battery_model: :class: pybamm.BaseBatteryModel, optional + PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will + override any model names within the file. If None, the function will look + for the saved object path, present if the original model came from PyBaMM. + """ + return Serialise().load_model(filename, battery_model) + # helper functions for finding symbols def find_symbol_in_tree(tree, name): diff --git a/pybamm/models/event.py b/pybamm/models/event.py index e93262641d..5bba4cd14b 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -46,6 +46,28 @@ def __init__(self, name, expression, event_type=EventType.TERMINATION): self._expression = expression self._event_type = event_type + @classmethod + def _from_json(cls, snippet: dict): + """ + Reconstructs an Event instance during deserialisation of a JSON file. + + Parameters + ---------- + snippet: dict + Contains the information needed to reconstruct a specific instance. + Should contain "name", "expression" and "event_type". + """ + + instance = cls.__new__(cls) + + instance.__init__( + snippet["name"], + snippet["expression"], + event_type=EventType(snippet["event_type"][1]), + ) + + return instance + def evaluate(self, t=None, y=None, y_dot=None, inputs=None): """ Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate` @@ -66,3 +88,21 @@ def expression(self): @property def event_type(self): return self._event_type + + def to_json(self): + """ + Method to serialise an Event object into JSON. + + The expression is written out seperately, + See :meth:`pybamm.Serialise._SymbolEncoder.default()` + """ + + # event_type contains string name, for JSON readability, + # and value for deserialisation. + + json_dict = { + "name": self._name, + "event_type": [str(self._event_type), self._event_type.value], + } + + return json_dict diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index ee3e0b5c6f..b174ef581c 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -6,6 +6,8 @@ from functools import cached_property import warnings +from pybamm.expression_tree.operations.serialise import Serialise + def represents_positive_integer(s): """Check if a string represents a positive integer""" @@ -823,6 +825,60 @@ def __init__(self, options=None, name="Unnamed battery model"): super().__init__(name) self.options = options + @classmethod + def deserialise(cls, properties: dict): + """ + Create a model instance from a serialised object. + """ + instance = cls.__new__(cls) + + # append the model name with _saved to differentiate + instance.__init__( + options=properties["options"], name=properties["name"] + "_saved" + ) + + # Initialise model with stored variables that have already been discretised + instance._concatenated_rhs = properties["concatenated_rhs"] + instance._concatenated_algebraic = properties["concatenated_algebraic"] + instance._concatenated_initial_conditions = properties[ + "concatenated_initial_conditions" + ] + + instance.len_rhs = instance.concatenated_rhs.size + instance.len_alg = instance.concatenated_algebraic.size + instance.len_rhs_and_alg = instance.len_rhs + instance.len_alg + + instance.bounds = properties["bounds"] + instance.events = properties["events"] + instance.mass_matrix = properties["mass_matrix"] + instance.mass_matrix_inv = properties["mass_matrix_inv"] + + # add optional properties not required for model to solve + if properties["variables"]: + instance._variables = pybamm.FuzzyDict(properties["variables"]) + + # assign meshes to each variable + for var in instance._variables.values(): + if var.domain != []: + var.mesh = properties["mesh"][var.domain] + else: + var.mesh = None + + if var.domains["secondary"] != []: + var.secondary_mesh = properties["mesh"][var.domains["secondary"]] + else: + var.secondary_mesh = None + + instance._geometry = pybamm.Geometry(properties["geometry"]) + else: + # Delete the default variables which have not been discretised + instance._variables = pybamm.FuzzyDict({}) + + # Model has already been discretised + instance.is_discretised = True + + return instance + @property def default_geometry(self): return pybamm.battery_geometry(options=self.options) @@ -1409,3 +1465,20 @@ def set_soc_variables(self): This function is overriden by the base battery models """ pass + + def save_model(self, filename=None, mesh=None, variables=None): + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be created + based on the model name, and the current datetime. + """ + if variables and not mesh: + raise ValueError( + "Serialisation: Please provide the mesh if variables are required" + ) + + Serialise().save_model(self, filename=filename, mesh=mesh, variables=variables) diff --git a/pybamm/models/full_battery_models/lithium_ion/msmr.py b/pybamm/models/full_battery_models/lithium_ion/msmr.py index 3ca07c4ef8..f1ec7f90bd 100644 --- a/pybamm/models/full_battery_models/lithium_ion/msmr.py +++ b/pybamm/models/full_battery_models/lithium_ion/msmr.py @@ -19,7 +19,7 @@ def __init__(self, options=None, name="MSMR", build=True): options["open-circuit potential"] ) ) - elif "particle" in options and options["particle"] == "MSMR": + elif "particle" in options and options["particle"] != "MSMR": raise pybamm.OptionError( "'particle' must be 'MSMR' for MSMR not '{}'".format( options["particle"] @@ -27,7 +27,7 @@ def __init__(self, options=None, name="MSMR", build=True): ) elif ( "intercalation kinetics" in options - and options["intercalation kinetics"] == "MSMR" + and options["intercalation kinetics"] != "MSMR" ): raise pybamm.OptionError( "'intercalation kinetics' must be 'MSMR' for MSMR not '{}'".format( diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index ff657ee375..1521bd753f 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -141,6 +141,10 @@ def __init__( f"No default output variables provided for {models[0].name}" ) + # check variables have been provided after any serialisation + if any(len(m.variables) == 0 for m in models): + raise AttributeError("No variables to plot") + self.n_rows = n_rows or int( len(output_variables) // np.sqrt(len(output_variables)) ) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index f743f4fc0f..83a386fe98 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -10,6 +10,9 @@ from functools import lru_cache from datetime import timedelta from pybamm.util import have_optional_dependency +from typing import Optional + +from pybamm.expression_tree.operations.serialise import Serialise def is_notebook(): @@ -793,7 +796,9 @@ def solve( # Hacky patch to allow correct processing of end_time and next_starting time # For efficiency purposes, op_conds treats identical steps as the same object # regardless of the initial time. Should be refactored as part of #3176 - op_conds_unproc = self.experiment.operating_conditions_steps_unprocessed[idx] + op_conds_unproc = ( + self.experiment.operating_conditions_steps_unprocessed[idx] + ) start_time = current_solution.t[-1] @@ -1188,6 +1193,50 @@ def save(self, filename): with open(filename, "wb") as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) + def save_model( + self, + filename: Optional[str] = None, + mesh: bool = False, + variables: bool = False, + ): + """ + Write out a discretised model to a JSON file + + Parameters + ---------- + mesh: bool + The mesh used to discretise the model. If false, plotting tools will not + be available when the model is read back in and solved. + variables: bool + The discretised variables. Not required to solve a model, but if false + tools will not be availble. Will automatically save meshes as well, required + for plotting tools. + filename: str, optional + The desired name of the JSON file. If no name is provided, one will be + created based on the model name, and the current datetime. + """ + mesh = self.mesh if (mesh or variables) else None + variables = self.built_model.variables if variables else None + + if self.operating_mode == "with experiment": + raise NotImplementedError( + """ + Serialising models coupled to experiments is not yet supported. + """ + ) + + if self.built_model: + Serialise().save_model( + self.built_model, filename=filename, mesh=mesh, variables=variables + ) + else: + raise NotImplementedError( + """ + PyBaMM can only serialise a discretised model. + Ensure the model has been built (e.g. run `build()`) before saving. + """ + ) + def load_sim(filename): """Load a saved simulation""" diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index c2b81c1568..dbc2bfe875 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -762,9 +762,14 @@ def solve( # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: if not isinstance(self, pybamm.DummySolver): - raise pybamm.ModelError( - "Cannot solve empty model, use `pybamm.DummySolver` instead" - ) + # check for a discretised model without original parameters + if not ( + model.concatenated_rhs is not None + or model.concatenated_algebraic is not None + ): + raise pybamm.ModelError( + "Cannot solve empty model, use `pybamm.DummySolver` instead" + ) # t_eval can only be None if the solver is an algebraic solver. In that case # set it to 0 diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 9341122d84..d4074e15ef 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -5,6 +5,7 @@ import tests import numpy as np +import os class StandardModelTest(object): @@ -138,6 +139,44 @@ def test_sensitivities(self, param_name, param_value, output_name="Voltage [V]") atol=1e-6, ) + def test_serialisation(self, solver=None, t_eval=None): + self.model.save_model( + "test_model", variables=self.model.variables, mesh=self.disc.mesh + ) + + new_model = pybamm.load_model("test_model.json") + + # create new solver for re-created model + if solver is not None: + new_solver = solver + else: + new_solver = new_model.default_solver + + if isinstance(new_model, pybamm.lithium_ion.BaseModel): + new_solver.rtol = 1e-8 + new_solver.atol = 1e-8 + + accuracy = 5 + + Crate = abs( + self.parameter_values["Current function [A]"] + / self.parameter_values["Nominal cell capacity [A.h]"] + ) + # don't allow zero C-rate + if Crate == 0: + Crate = 1 + if t_eval is None: + t_eval = np.linspace(0, 3600 / Crate, 100) + + new_solution = new_solver.solve(new_model, t_eval) + + for x, val in enumerate(self.solution.all_ys): + np.testing.assert_array_almost_equal( + new_solution.all_ys[x], self.solution.all_ys[x], decimal=accuracy + ) + + os.remove("test_model.json") + def test_all( self, param=None, disc=None, solver=None, t_eval=None, skip_output_tests=False ): @@ -152,6 +191,7 @@ def test_all( ) and not skip_output_tests ): + self.test_serialisation(solver, t_eval) self.test_outputs() diff --git a/tests/unit/test_expression_tree/test_array.py b/tests/unit/test_expression_tree/test_array.py index da79dbb6e0..b75c313f47 100644 --- a/tests/unit/test_expression_tree/test_array.py +++ b/tests/unit/test_expression_tree/test_array.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import numpy as np import sympy @@ -41,6 +42,28 @@ def test_to_equation(self): pybamm.Array([1, 2]).to_equation(), sympy.Array([[1.0], [2.0]]) ) + def test_to_from_json(self): + arr = pybamm.Array(np.array([1, 2, 3])) + + json_dict = { + "name": "Array of shape (3, 1)", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "entries": [[1.0], [2.0], [3.0]], + } + + # array to json conversion + created_json = arr.to_json() + self.assertEqual(created_json, json_dict) + + # json to array conversion + self.assertEqual(pybamm.Array._from_json(created_json), arr) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 225f8e93c9..20decfeb6f 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import numpy as np from scipy.sparse import coo_matrix @@ -10,6 +11,13 @@ import pybamm from pybamm.util import have_optional_dependency +EMPTY_DOMAINS = { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], +} + class TestBinaryOperators(TestCase): def test_binary_operator(self): @@ -771,6 +779,72 @@ def test_to_equation(self): # Test NotEqualHeaviside self.assertEqual(pybamm.NotEqualHeaviside(2, 4).to_equation(), True) + def test_to_json(self): + # Test Addition + add_json = { + "name": "+", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + add = pybamm.Addition(2, 4) + + self.assertEqual(add.to_json(), add_json) + + add_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] + self.assertEqual(pybamm.Addition._from_json(add_json), add) + + # Test Power + pow_json = { + "name": "**", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + pow = pybamm.Power(7, 2) + self.assertEqual(pow.to_json(), pow_json) + + pow_json["children"] = [pybamm.Scalar(7), pybamm.Scalar(2)] + self.assertEqual(pybamm.Power._from_json(pow_json), pow) + + # Test Division + div_json = { + "name": "/", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + div = pybamm.Division(10, 5) + self.assertEqual(div.to_json(), div_json) + + div_json["children"] = [pybamm.Scalar(10), pybamm.Scalar(5)] + self.assertEqual(pybamm.Division._from_json(div_json), div) + + # Test EqualHeaviside + equal_json = { + "name": "<=", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + equal_h = pybamm.EqualHeaviside(2, 4) + self.assertEqual(equal_h.to_json(), equal_json) + + equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] + self.assertEqual(pybamm.EqualHeaviside._from_json(equal_json), equal_h) + + # Test notEqualHeaviside + not_equal_json = { + "name": "<", + "id": mock.ANY, + "domains": EMPTY_DOMAINS, + } + + ne_h = pybamm.NotEqualHeaviside(2, 4) + self.assertEqual(ne_h.to_json(), not_equal_json) + + not_equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] + self.assertEqual(pybamm.NotEqualHeaviside._from_json(not_equal_json), ne_h) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_broadcasts.py b/tests/unit/test_expression_tree/test_broadcasts.py index 81d1210229..be8fe1a677 100644 --- a/tests/unit/test_expression_tree/test_broadcasts.py +++ b/tests/unit/test_expression_tree/test_broadcasts.py @@ -350,6 +350,15 @@ def test_diff(self): self.assertIsInstance(d, pybamm.Scalar) self.assertEqual(d.evaluate(y=y), 0) + def test_to_from_json_error(self): + a = pybamm.StateVector(slice(0, 1)) + b = pybamm.PrimaryBroadcast(a, "separator") + with self.assertRaises(NotImplementedError): + b.to_json() + + with self.assertRaises(NotImplementedError): + pybamm.PrimaryBroadcast._from_json({}) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index 4b07b09fea..691b6a7ee2 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -2,6 +2,7 @@ # Tests for the Concatenation class and subclasses # import unittest +import unittest.mock as mock from tests import TestCase import numpy as np @@ -383,6 +384,79 @@ def test_to_equation(self): # Test concat_sym self.assertEqual(pybamm.Concatenation(a, b).to_equation(), func_symbol) + def test_to_from_json(self): + # test DomainConcatenation + mesh = get_mesh_for_testing() + a = pybamm.Symbol("a", domain=["negative electrode"]) + b = pybamm.Symbol("b", domain=["separator", "positive electrode"]) + conc = pybamm.DomainConcatenation([a, b], mesh) + + json_dict = { + "name": "domain_concatenation", + "id": mock.ANY, + "domains": { + "primary": ["negative electrode", "separator", "positive electrode"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "slices": { + "negative electrode": [{"start": 0, "stop": 40, "step": None}], + "separator": [{"start": 40, "stop": 65, "step": None}], + "positive electrode": [{"start": 65, "stop": 100, "step": None}], + }, + "size": 100, + "children_slices": [ + {"negative electrode": [{"start": 0, "stop": 40, "step": None}]}, + { + "separator": [{"start": 0, "stop": 25, "step": None}], + "positive electrode": [{"start": 25, "stop": 60, "step": None}], + }, + ], + "secondary_dimensions_npts": 1, + } + + self.assertEqual( + conc.to_json(), + json_dict, + ) + + # manually add children + json_dict["children"] = [a, b] + + # check symbol re-creation + self.assertEqual(pybamm.pybamm.DomainConcatenation._from_json(json_dict), conc) + + # ----------------------------- + # test NumpyConcatenation ----- + # ----------------------------- + + y = np.linspace(0, 1, 15)[:, np.newaxis] + a_np = pybamm.Vector(y[:5]) + b_np = pybamm.Vector(y[5:9]) + c_np = pybamm.Vector(y[9:]) + conc_np = pybamm.NumpyConcatenation(a_np, b_np, c_np) + + np_json = { + "name": "numpy_concatenation", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + } + + # test to_json + self.assertEqual(conc_np.to_json(), np_json) + + # add children + np_json["children"] = [a_np, b_np, c_np] + + # test _from_json + self.assertEqual(pybamm.NumpyConcatenation._from_json(np_json), conc_np) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index 6d22571a01..e9bd8522e6 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import numpy as np from scipy import special @@ -146,8 +147,30 @@ def test_to_equation(self): # Test Function self.assertEqual(pybamm.Function(np.log, 10).to_equation(), 10.0) + def test_to_from_json_error(self): + a = pybamm.Symbol("a") + funca = pybamm.Function(test_function, a) + + with self.assertRaises(NotImplementedError): + funca.to_json() + + with self.assertRaises(NotImplementedError): + pybamm.Function._from_json({}) + class TestSpecificFunctions(TestCase): + def test_to_json(self): + a = pybamm.InputParameter("a") + fun = pybamm.cos(a) + + expected_json = { + "name": "function (cos)", + "id": mock.ANY, + "function": "cos", + } + + self.assertEqual(fun.to_json(), expected_json) + def test_arcsinh(self): a = pybamm.InputParameter("a") fun = pybamm.arcsinh(a) @@ -181,6 +204,15 @@ def test_arcsinh(self): pybamm.PrimaryBroadcast(pybamm.PrimaryBroadcast(fun, "test"), "test2"), ) + # test creation from json + input_json = { + "name": "arcsinh", + "id": mock.ANY, + "function": "arcsinh", + "children": [a], + } + self.assertEqual(pybamm.Arcsinh._from_json(input_json), fun) + def test_arctan(self): a = pybamm.InputParameter("a") fun = pybamm.arctan(a) @@ -197,6 +229,15 @@ def test_arctan(self): places=5, ) + # test creation from json + input_json = { + "name": "arctan", + "id": mock.ANY, + "function": "arctan", + "children": [a], + } + self.assertEqual(pybamm.Arctan._from_json(input_json), fun) + def test_cos(self): a = pybamm.InputParameter("a") fun = pybamm.cos(a) @@ -214,6 +255,15 @@ def test_cos(self): places=5, ) + # test creation from json + input_json = { + "name": "cos", + "id": mock.ANY, + "function": "cos", + "children": [a], + } + self.assertEqual(pybamm.Cos._from_json(input_json), fun) + def test_cosh(self): a = pybamm.InputParameter("a") fun = pybamm.cosh(a) @@ -231,6 +281,15 @@ def test_cosh(self): places=5, ) + # test creation from json + input_json = { + "name": "cosh", + "id": mock.ANY, + "function": "cosh", + "children": [a], + } + self.assertEqual(pybamm.Cosh._from_json(input_json), fun) + def test_exp(self): a = pybamm.InputParameter("a") fun = pybamm.exp(a) @@ -248,6 +307,15 @@ def test_exp(self): places=5, ) + # test creation from json + input_json = { + "name": "exp", + "id": mock.ANY, + "function": "exp", + "children": [a], + } + self.assertEqual(pybamm.Exp._from_json(input_json), fun) + def test_log(self): a = pybamm.InputParameter("a") fun = pybamm.log(a) @@ -277,6 +345,17 @@ def test_log(self): places=5, ) + # test creation from json + a = pybamm.InputParameter("a") + fun = pybamm.log(a) + input_json = { + "name": "log", + "id": mock.ANY, + "function": "log", + "children": [a], + } + self.assertEqual(pybamm.Log._from_json(input_json), fun) + def test_max(self): a = pybamm.StateVector(slice(0, 3)) y_test = np.array([1, 2, 3]) @@ -308,6 +387,15 @@ def test_sin(self): places=5, ) + # test creation from json + input_json = { + "name": "sin", + "id": mock.ANY, + "function": "sin", + "children": [a], + } + self.assertEqual(pybamm.Sin._from_json(input_json), fun) + def test_sinh(self): a = pybamm.InputParameter("a") fun = pybamm.sinh(a) @@ -325,6 +413,15 @@ def test_sinh(self): places=5, ) + # test creation from json + input_json = { + "name": "sinh", + "id": mock.ANY, + "function": "sinh", + "children": [a], + } + self.assertEqual(pybamm.Sinh._from_json(input_json), fun) + def test_sqrt(self): a = pybamm.InputParameter("a") fun = pybamm.sqrt(a) @@ -341,6 +438,15 @@ def test_sqrt(self): places=5, ) + # test creation from json + input_json = { + "name": "sqrt", + "id": mock.ANY, + "function": "sqrt", + "children": [a], + } + self.assertEqual(pybamm.Sqrt._from_json(input_json), fun) + def test_tanh(self): a = pybamm.InputParameter("a") fun = pybamm.tanh(a) @@ -371,6 +477,15 @@ def test_erf(self): places=5, ) + # test creation from json + input_json = { + "name": "erf", + "id": mock.ANY, + "function": "erf", + "children": [a], + } + self.assertEqual(pybamm.Erf._from_json(input_json), fun) + def test_erfc(self): a = pybamm.InputParameter("a") fun = pybamm.erfc(a) diff --git a/tests/unit/test_expression_tree/test_input_parameter.py b/tests/unit/test_expression_tree/test_input_parameter.py index 82dd06fee5..a5fc79f2e2 100644 --- a/tests/unit/test_expression_tree/test_input_parameter.py +++ b/tests/unit/test_expression_tree/test_input_parameter.py @@ -6,6 +6,8 @@ import pybamm import unittest +import unittest.mock as mock + class TestInputParameter(TestCase): def test_input_parameter_init(self): @@ -49,6 +51,22 @@ def test_errors(self): with self.assertRaises(KeyError): a.evaluate() + def test_to_from_json(self): + a = pybamm.InputParameter("a") + + json_dict = { + "name": "a", + "id": mock.ANY, + "domain": [], + "expected_size": 1, + } + + # to_json + self.assertEqual(a.to_json(), json_dict) + + # from_json + self.assertEqual(pybamm.InputParameter._from_json(json_dict), a) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index e1547ef3fc..5fa078cffc 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -5,6 +5,7 @@ import pybamm import unittest +import unittest.mock as mock import numpy as np @@ -325,6 +326,65 @@ def test_processing(self): self.assertEqual(interp, interp.new_copy()) + def test_to_from_json(self): + x = np.linspace(0, 1, 10) + y = pybamm.StateVector(slice(0, 2)) + interp = pybamm.Interpolant(x, 2 * x, y) + + expected_json = { + "name": "interpolating_function", + "id": mock.ANY, + "x": [ + [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0, + ] + ], + "y": [ + 0.0, + 0.2222222222222222, + 0.4444444444444444, + 0.6666666666666666, + 0.8888888888888888, + 1.1111111111111112, + 1.3333333333333333, + 1.5555555555555554, + 1.7777777777777777, + 2.0, + ], + "interpolator": "linear", + "extrapolate": True, + } + + # check correct writing to json + self.assertEqual(interp.to_json(), expected_json) + + expected_json["children"] = [y] + # check correct re-creation + self.assertEqual(pybamm.Interpolant._from_json(expected_json), interp) + + # test to_from_json for 2d x & y + x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.01)) + xx, yy = np.meshgrid(x[0], x[1], indexing="ij") + z = np.sin(xx**2 + yy**2) + var1 = pybamm.StateVector(slice(0, 1)) + var2 = pybamm.StateVector(slice(1, 2)) + # linear + interp = pybamm.Interpolant(x, z, (var1, var2), interpolator="linear") + + interp2d_json = interp.to_json() + interp2d_json["children"] = (var1, var2) + + self.assertEqual(pybamm.Interpolant._from_json(interp2d_json), interp) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_matrix.py b/tests/unit/test_expression_tree/test_matrix.py index 39aba44483..055902b15e 100644 --- a/tests/unit/test_expression_tree/test_matrix.py +++ b/tests/unit/test_expression_tree/test_matrix.py @@ -4,8 +4,10 @@ from tests import TestCase import pybamm import numpy as np +from scipy.sparse import csr_matrix import unittest +import unittest.mock as mock class TestMatrix(TestCase): @@ -38,6 +40,29 @@ def test_matrix_operations(self): (self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]]) ) + def test_to_from_json(self): + arr = pybamm.Matrix(csr_matrix([[0, 1, 0, 0], [0, 0, 0, 1]])) + json_dict = { + "name": "Sparse Matrix (2, 4)", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "entries": { + "column_pointers": [0, 1, 2], + "data": [1.0, 1.0], + "row_indices": [1, 3], + "shape": (2, 4), + }, + } + + self.assertEqual(arr.to_json(), json_dict) + + self.assertEqual(pybamm.Matrix._from_json(json_dict), arr) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index d9a756b45d..6940ac38fe 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -31,6 +31,15 @@ def test_to_equation(self): # Test name self.assertEqual(func1.to_equation(), sympy.Symbol("test_name")) + def test_to_json_error(self): + func = pybamm.Parameter("test_string") + + with self.assertRaises(NotImplementedError): + func.to_json() + + with self.assertRaises(NotImplementedError): + pybamm.Parameter._from_json({}) + class TestFunctionParameter(TestCase): def test_function_parameter_init(self): @@ -110,6 +119,15 @@ def test_function_parameter_to_equation(self): func1.print_name = None self.assertEqual(func1.to_equation(), sympy.Symbol("func")) + def test_to_json_error(self): + func = pybamm.FunctionParameter("test", {"x": pybamm.Scalar(1)}) + + with self.assertRaises(NotImplementedError): + func.to_json() + + with self.assertRaises(NotImplementedError): + pybamm.FunctionParameter._from_json({}) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_scalar.py b/tests/unit/test_expression_tree/test_scalar.py index af0a6e80ca..34ea1aa514 100644 --- a/tests/unit/test_expression_tree/test_scalar.py +++ b/tests/unit/test_expression_tree/test_scalar.py @@ -3,6 +3,7 @@ # from tests import TestCase import unittest +import unittest.mock as mock import pybamm @@ -44,6 +45,14 @@ def test_copy(self): b = a.create_copy() self.assertEqual(a, b) + def test_to_from_json(self): + a = pybamm.Scalar(5) + json_dict = {"name": "5.0", "id": mock.ANY, "value": 5.0} + + self.assertEqual(a.to_json(), json_dict) + + self.assertEqual(pybamm.Scalar._from_json(json_dict), a) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_state_vector.py b/tests/unit/test_expression_tree/test_state_vector.py index d401487264..18025c0aa3 100644 --- a/tests/unit/test_expression_tree/test_state_vector.py +++ b/tests/unit/test_expression_tree/test_state_vector.py @@ -6,6 +6,7 @@ import numpy as np import unittest +import unittest.mock as mock class TestStateVector(TestCase): @@ -62,6 +63,39 @@ def test_failure(self): with self.assertRaisesRegex(TypeError, "all y_slices must be slice objects"): pybamm.StateVector(slice(0, 10), 1) + def test_to_from_json(self): + original_debug_mode = pybamm.settings.debug_mode + pybamm.settings.debug_mode = False + + array = np.array([1, 2, 3, 4, 5]) + sv = pybamm.StateVector(slice(0, 10), evaluation_array=array) + + json_dict = { + "name": "y[0:10]", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "y_slice": [ + { + "start": 0, + "stop": 10, + "step": None, + } + ], + "evaluation_array": [1, 2, 3, 4, 5], + } + + self.assertEqual(sv.to_json(), json_dict) + + self.assertEqual(pybamm.StateVector._from_json(json_dict), sv) + + # Turn debug mode back to what is was before + pybamm.settings.debug_mode = original_debug_mode + class TestStateVectorDot(TestCase): def test_evaluate(self): diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 3eb7adae47..9a7939c66d 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -4,6 +4,7 @@ from tests import TestCase import os import unittest +import unittest.mock as mock from tempfile import TemporaryDirectory import numpy as np @@ -491,6 +492,28 @@ def test_numpy_array_ufunc(self): x = pybamm.Symbol("x") self.assertEqual(np.exp(x), pybamm.exp(x)) + def test_to_from_json(self): + symc1 = pybamm.Symbol("child1", domain=["domain_1"]) + symc2 = pybamm.Symbol("child2", domain=["domain_2"]) + symp = pybamm.Symbol("parent", domain=["domain_3"], children=[symc1, symc2]) + + json_dict = { + "name": "parent", + "id": mock.ANY, + "domains": { + "primary": ["domain_3"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + } + + self.assertEqual(symp.to_json(), json_dict) + + json_dict["children"] = [symc1, symc2] + + self.assertEqual(pybamm.Symbol._from_json(json_dict), symp) + class TestIsZero(TestCase): def test_is_scalar_zero(self): diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index fc845cb574..3c34db7dcd 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -3,6 +3,7 @@ # import unittest from tests import TestCase +import unittest.mock as mock import numpy as np from scipy.sparse import diags @@ -50,6 +51,20 @@ def test_negation(self): pybamm.PrimaryBroadcast(pybamm.PrimaryBroadcast(nega, "test"), "test2"), ) + # Test from_json + input_json = { + "name": "-", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.Negate._from_json(input_json), nega) + def test_absolute(self): a = pybamm.Symbol("a") absa = pybamm.AbsoluteValue(a) @@ -77,6 +92,20 @@ def test_absolute(self): pybamm.PrimaryBroadcast(pybamm.PrimaryBroadcast(absa, "test"), "test2"), ) + # Test from_json + input_json = { + "name": "abs", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.AbsoluteValue._from_json(input_json), absa) + def test_smooth_absolute_value(self): a = pybamm.StateVector(slice(0, 1)) expr = pybamm.smooth_absolute_value(a, 10) @@ -113,6 +142,11 @@ def test_sign(self): ), ) + # Test from_json + with self.assertRaises(NotImplementedError): + # signs are always scalar/array types in a discretised model + pybamm.Sign._from_json({}) + def test_floor(self): a = pybamm.Symbol("a") floora = pybamm.Floor(a) @@ -127,6 +161,20 @@ def test_floor(self): floorc = pybamm.Floor(c) self.assertEqual(floorc.evaluate(), -4) + # Test from_json + input_json = { + "name": "floor", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.Floor._from_json(input_json), floora) + def test_ceiling(self): a = pybamm.Symbol("a") ceila = pybamm.Ceiling(a) @@ -141,6 +189,20 @@ def test_ceiling(self): ceilc = pybamm.Ceiling(c) self.assertEqual(ceilc.evaluate(), -3) + # Test from_json + input_json = { + "name": "ceil", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [a], + } + self.assertEqual(pybamm.Ceiling._from_json(input_json), ceila) + def test_gradient(self): # gradient of scalar symbol should fail a = pybamm.Symbol("a") @@ -671,6 +733,62 @@ def test_explicit_time_integral(self): self.assertEqual(expr.new_copy(), expr) self.assertFalse(expr.is_constant()) + def test_to_from_json(self): + # UnaryOperator + a = pybamm.Symbol("a", domain=["test"]) + un = pybamm.UnaryOperator("unary test", a) + + un_json = { + "name": "unary test", + "id": mock.ANY, + "domains": { + "primary": ["test"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + } + + self.assertEqual(un.to_json(), un_json) + + un_json["children"] = [a] + self.assertEqual(pybamm.UnaryOperator._from_json(un_json), un) + + # Index + vec = pybamm.StateVector(slice(0, 5)) + ind = pybamm.Index(vec, 3) + + ind_json = { + "name": "Index[3]", + "id": mock.ANY, + "index": {"start": 3, "stop": 4, "step": None}, + "check_size": False, + } + + self.assertEqual(ind.to_json(), ind_json) + + ind_json["children"] = [vec] + self.assertEqual(pybamm.Index._from_json(ind_json), ind) + + # SpatialOperator + spatial_vec = pybamm.SpatialOperator("name", vec) + with self.assertRaises(NotImplementedError): + spatial_vec.to_json() + + with self.assertRaises(NotImplementedError): + pybamm.SpatialOperator._from_json({}) + + # ExplicitTimeIntegral + expr = pybamm.ExplicitTimeIntegral(pybamm.Parameter("param"), pybamm.Scalar(1)) + + expr_json = {"name": "explicit time integral", "id": mock.ANY} + + self.assertEqual(expr.to_json(), expr_json) + + expr_json["children"] = [pybamm.Parameter("param")] + expr_json["initial_condition"] = [pybamm.Scalar(1)] + self.assertEqual(pybamm.ExplicitTimeIntegral._from_json(expr_json), expr) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_expression_tree/test_variable.py b/tests/unit/test_expression_tree/test_variable.py index 583008f882..0d5aa251d2 100644 --- a/tests/unit/test_expression_tree/test_variable.py +++ b/tests/unit/test_expression_tree/test_variable.py @@ -64,6 +64,11 @@ def test_to_equation(self): # Test name self.assertEqual(pybamm.Variable("name").to_equation(), sympy.Symbol("name")) + def test_to_json_error(self): + func = pybamm.Variable("test_string") + with self.assertRaises(NotImplementedError): + func.to_json() + class TestVariableDot(TestCase): def test_variable_init(self): diff --git a/tests/unit/test_meshes/test_meshes.py b/tests/unit/test_meshes/test_meshes.py index 6563ba232d..3066d14534 100644 --- a/tests/unit/test_meshes/test_meshes.py +++ b/tests/unit/test_meshes/test_meshes.py @@ -390,6 +390,30 @@ def test_1plus1D_tabs_right_left(self): # positive tab should be "left" self.assertEqual(mesh["current collector"].tabs["positive tab"], "left") + def test_to_json(self): + r = pybamm.SpatialVariable( + "r", domain=["negative particle"], coord_sys="spherical polar" + ) + + geometry = { + "negative particle": {r: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}} + } + + submesh_types = {"negative particle": pybamm.Uniform1DSubMesh} + var_pts = {r: 20} + + # create mesh + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + mesh_json = mesh.to_json() + + expected_json = { + "submesh_pts": {"negative particle": {"r": 20}}, + "base_domains": ["negative particle"], + } + + self.assertEqual(mesh_json, expected_json) + class TestMeshGenerator(TestCase): def test_init_name(self): diff --git a/tests/unit/test_meshes/test_one_dimensional_submesh.py b/tests/unit/test_meshes/test_one_dimensional_submesh.py index 207f5f2b6f..514de4248b 100644 --- a/tests/unit/test_meshes/test_one_dimensional_submesh.py +++ b/tests/unit/test_meshes/test_one_dimensional_submesh.py @@ -18,6 +18,36 @@ def test_exceptions(self): with self.assertRaises(pybamm.GeometryError): pybamm.SubMesh1D(edges, None, tabs=tabs) + def test_to_json(self): + edges = np.linspace(0, 1, 10) + tabs = {"negative": {"z_centre": 0}, "positive": {"z_centre": 1}} + mesh = pybamm.SubMesh1D(edges, None, tabs=tabs) + + mesh_json = mesh.to_json() + + expected_json = { + "edges": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0, + ], + "coord_sys": None, + "tabs": {"negative tab": "left", "positive tab": "right"}, + } + + self.assertEqual(mesh_json, expected_json) + + # check tabs work + new_mesh = pybamm.Uniform1DSubMesh._from_json(mesh_json) + self.assertEqual(mesh.tabs, new_mesh.tabs) + class TestUniform1DSubMesh(TestCase): def test_exceptions(self): diff --git a/tests/unit/test_meshes/test_scikit_fem_submesh.py b/tests/unit/test_meshes/test_scikit_fem_submesh.py index 2e646e1085..1e0839250e 100644 --- a/tests/unit/test_meshes/test_scikit_fem_submesh.py +++ b/tests/unit/test_meshes/test_scikit_fem_submesh.py @@ -180,6 +180,109 @@ def test_tab_left_right(self): param.process_geometry(geometry) pybamm.Mesh(geometry, submesh_types, var_pts) + def test_to_json(self): + param = get_param() + geometry = pybamm.battery_geometry( + include_particles=False, options={"dimensionality": 2} + ) + param.process_geometry(geometry) + + var_pts = {"x_n": 10, "x_s": 7, "x_p": 12, "y": 16, "z": 24} + + submesh_types = { + "negative electrode": pybamm.Uniform1DSubMesh, + "separator": pybamm.Uniform1DSubMesh, + "positive electrode": pybamm.Uniform1DSubMesh, + "current collector": pybamm.MeshGenerator(pybamm.ScikitUniform2DSubMesh), + } + + # create mesh + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + mesh_json = mesh.to_json() + + expected_json = { + "submesh_pts": { + "negative electrode": {"x_n": 10}, + "separator": {"x_s": 7}, + "positive electrode": {"x_p": 12}, + "current collector": {"y": 16, "z": 24}, + }, + "base_domains": [ + "negative electrode", + "separator", + "positive electrode", + "current collector", + ], + } + + self.assertEqual(mesh_json, expected_json) + + # test Uniform2DSubMesh serialisation + + submesh = mesh["current collector"].to_json() + + expected_submesh = { + "edges": { + "y": [ + 0.0, + 0.02666666666666667, + 0.05333333333333334, + 0.08, + 0.10666666666666667, + 0.13333333333333333, + 0.16, + 0.18666666666666668, + 0.21333333333333335, + 0.24000000000000002, + 0.26666666666666666, + 0.29333333333333333, + 0.32, + 0.3466666666666667, + 0.37333333333333335, + 0.4, + ], + "z": [ + 0.0, + 0.021739130434782608, + 0.043478260869565216, + 0.06521739130434782, + 0.08695652173913043, + 0.10869565217391304, + 0.13043478260869565, + 0.15217391304347827, + 0.17391304347826086, + 0.19565217391304346, + 0.21739130434782608, + 0.2391304347826087, + 0.2608695652173913, + 0.2826086956521739, + 0.30434782608695654, + 0.32608695652173914, + 0.34782608695652173, + 0.3695652173913043, + 0.3913043478260869, + 0.41304347826086957, + 0.43478260869565216, + 0.45652173913043476, + 0.4782608695652174, + 0.5, + ], + }, + "coord_sys": "cartesian", + "tabs": { + "negative": {"y_centre": 0.1, "z_centre": 0.5, "width": 0.1}, + "positive": {"y_centre": 0.3, "z_centre": 0.5, "width": 0.1}, + }, + } + + self.assertEqual(submesh, expected_submesh) + + new_submesh = pybamm.ScikitUniform2DSubMesh._from_json(submesh) + + for x, y in zip(mesh['current collector'].edges, new_submesh.edges): + np.testing.assert_array_equal(x, y) + class TestScikitFiniteElementChebyshev2DSubMesh(TestCase): def test_mesh_creation(self): diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 4167d5fff5..438b7391a7 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -9,6 +9,7 @@ import casadi import numpy as np +from numpy import testing import pybamm @@ -982,6 +983,83 @@ def test_timescale_lengthscale_get_set_not_implemented(self): with self.assertRaises(NotImplementedError): model.length_scales = 1 + def test_save_load_model(self): + # Set up model + model = pybamm.BaseModel() + var_scalar = pybamm.Variable("var_scalar") + var_1D = pybamm.Variable("var_1D", domain="negative electrode") + var_2D = pybamm.Variable( + "var_2D", + domain="negative particle", + auxiliary_domains={"secondary": "negative electrode"}, + ) + var_concat_neg = pybamm.Variable("var_concat_neg", domain="negative electrode") + var_concat_sep = pybamm.Variable("var_concat_sep", domain="separator") + var_concat = pybamm.concatenation(var_concat_neg, var_concat_sep) + model.rhs = {var_scalar: -var_scalar, var_1D: -var_1D} + model.algebraic = {var_2D: -var_2D, var_concat: -var_concat} + model.initial_conditions = {var_scalar: 1, var_1D: 1, var_2D: 1, var_concat: 1} + model.variables = { + "var_scalar": var_scalar, + "var_1D": var_1D, + "var_2D": var_2D, + "var_concat_neg": var_concat_neg, + "var_concat_sep": var_concat_sep, + "var_concat": var_concat, + } + + # Discretise + geometry = { + "negative electrode": {"x_n": {"min": 0, "max": 1}}, + "separator": {"x_s": {"min": 1, "max": 2}}, + "negative particle": {"r_n": {"min": 0, "max": 1}}, + } + submeshes = { + "negative electrode": pybamm.Uniform1DSubMesh, + "separator": pybamm.Uniform1DSubMesh, + "negative particle": pybamm.Uniform1DSubMesh, + } + var_pts = {"x_n": 10, "x_s": 10, "r_n": 5} + mesh = pybamm.Mesh(geometry, submeshes, var_pts) + spatial_methods = { + "negative electrode": pybamm.FiniteVolume(), + "separator": pybamm.FiniteVolume(), + "negative particle": pybamm.FiniteVolume(), + } + disc = pybamm.Discretisation(mesh, spatial_methods) + model_disc = disc.process_model(model, inplace=False) + t = np.linspace(0, 1) + y = np.tile(3 * t, (1 + 30 + 50, 1)) + + # Find baseline solution + solution = pybamm.Solution(t, y, model_disc, {}) + + # save model + model_disc.save_model(filename="test_base_model") + + # load without variables + new_model = pybamm.load_model("test_base_model.json") + + new_solution = pybamm.Solution(t, y, new_model, {}) + + # model solutions match + testing.assert_array_equal(solution.all_ys, new_solution.all_ys) + + # raises warning if variables are saved without mesh + with self.assertWarns(pybamm.ModelWarning): + model_disc.save_model( + filename="test_base_model", variables=model_disc.variables + ) + + model_disc.save_model( + filename="test_base_model", variables=model_disc.variables, mesh=mesh + ) + + # load with variables & mesh + new_model = pybamm.load_model("test_base_model.json") + + os.remove("test_base_model.json") + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_models/test_event.py b/tests/unit/test_models/test_event.py index 7d0d00f000..84b0dcde84 100644 --- a/tests/unit/test_models/test_event.py +++ b/tests/unit/test_models/test_event.py @@ -48,6 +48,28 @@ def test_event_types(self): event = pybamm.Event("my event", pybamm.Scalar(1), event_type) self.assertEqual(event.event_type, event_type) + def test_to_from_json(self): + expression = pybamm.Scalar(1) + event = pybamm.Event("my event", expression) + + event_json = { + "name": "my event", + "event_type": ["EventType.TERMINATION", 0], + } + + event_ser_json = event.to_json() + self.assertEqual(event_ser_json, event_json) + + event_json["expression"] = expression + + new_event = pybamm.Event._from_json(event_json) + + # check for equal expressions + self.assertEqual(new_event.expression, event.expression) + + # check for equal event types + self.assertEqual(new_event.event_type, event.event_type) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py index 79c6d8a720..91bcfc28cc 100644 --- a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py @@ -7,6 +7,7 @@ import unittest import io from contextlib import redirect_stdout +import os OPTIONS_DICT = { "surface form": "differential", @@ -449,6 +450,29 @@ def test_option_type(self): model = pybamm.BaseBatteryModel(options) self.assertEqual(model.options, options) + def test_save_load_model(self): + model = ( + pybamm.lithium_ion.SPM() + ) + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + + # save model + model.save_model(filename="test_base_battery_model", mesh=mesh, + variables=model.variables) + + # raises error if variables are saved without mesh + with self.assertRaises(ValueError): + model.save_model(filename="test_base_battery_model", + variables=model.variables) + + os.remove("test_base_battery_model.json") + class TestOptions(TestCase): def test_print_options(self): diff --git a/tests/unit/test_serialisation/__init__.py b/tests/unit/test_serialisation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py new file mode 100644 index 0000000000..6c43eaa9d7 --- /dev/null +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -0,0 +1,603 @@ +# +# Tests for the serialisation class +# +from tests import TestCase +import json +import os +import unittest +import unittest.mock as mock +from datetime import datetime +import numpy as np +import pybamm + +from numpy import testing +from pybamm.expression_tree.operations.serialise import Serialise + + +def scalar_var_dict(): + """variable, json pair for a pybamm.Scalar instance""" + a = pybamm.Scalar(5) + a_dict = { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scalar", + "name": "5.0", + "id": mock.ANY, + "value": 5.0, + "children": [], + } + + return a, a_dict + + +def mesh_var_dict(): + """mesh, json pair for a pybamm.Mesh instance""" + + r = pybamm.SpatialVariable( + "r", domain=["negative particle"], coord_sys="spherical polar" + ) + + geometry = { + "negative particle": {r: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}} + } + + submesh_types = {"negative particle": pybamm.Uniform1DSubMesh} + var_pts = {r: 20} + + # create mesh + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + mesh_json = { + "py/object": "pybamm.meshes.meshes.Mesh", + "py/id": mock.ANY, + "submesh_pts": {"negative particle": {"r": 20}}, + "base_domains": ["negative particle"], + "sub_meshes": { + "negative particle": { + "py/object": "pybamm.meshes.one_dimensional_submeshes.Uniform1DSubMesh", + "py/id": mock.ANY, + "edges": [ + 0.0, + 0.05, + 0.1, + 0.15000000000000002, + 0.2, + 0.25, + 0.30000000000000004, + 0.35000000000000003, + 0.4, + 0.45, + 0.5, + 0.55, + 0.6000000000000001, + 0.65, + 0.7000000000000001, + 0.75, + 0.8, + 0.8500000000000001, + 0.9, + 0.9500000000000001, + 1.0, + ], + "coord_sys": "spherical polar", + } + }, + } + + return mesh, mesh_json + + +class TestSerialiseModels(TestCase): + def test_user_defined_model_recreaction(self): + # Start with a base model + model = pybamm.BaseModel() + + # Define the variables and parameters + x = pybamm.SpatialVariable("x", domain="rod", coord_sys="cartesian") + T = pybamm.Variable("Temperature", domain="rod") + k = pybamm.Parameter("Thermal diffusivity") + + # Write the governing equations + N = -k * pybamm.grad(T) # Heat flux + Q = 1 - pybamm.Function(np.abs, x - 1) # Source term + dTdt = -pybamm.div(N) + Q + model.rhs = {T: dTdt} # add to model + + # Add the boundary and initial conditions + model.boundary_conditions = { + T: { + "left": (pybamm.Scalar(0), "Dirichlet"), + "right": (pybamm.Scalar(0), "Dirichlet"), + } + } + model.initial_conditions = {T: 2 * x - x**2} + + # Add desired output variables, geometry, parameters + model.variables = {"Temperature": T, "Heat flux": N, "Heat source": Q} + geometry = {"rod": {x: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(2)}}} + param = pybamm.ParameterValues({"Thermal diffusivity": 0.75}) + + # Process model and geometry + param.process_model(model) + param.process_geometry(geometry) + + # Pick mesh, spatial method, and discretise + submesh_types = {"rod": pybamm.Uniform1DSubMesh} + var_pts = {x: 30} + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + spatial_methods = {"rod": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + # Solve + solver = pybamm.ScipySolver() + t = np.linspace(0, 1, 100) + solution = solver.solve(model, t) + + model.save_model("heat_equation", variables=model._variables, mesh=mesh) + new_model = pybamm.load_model("heat_equation.json") + + new_solver = pybamm.ScipySolver() + new_solution = new_solver.solve(new_model, t) + + for x, val in enumerate(solution.all_ys): + np.testing.assert_array_almost_equal( + solution.all_ys[x], new_solution.all_ys[x] + ) + os.remove("heat_equation.json") + + +class TestSerialise(TestCase): + # test the symbol encoder + + def test_symbol_encoder_symbol(self): + """test basic symbol encoder with & without children""" + + # without children + a, a_dict = scalar_var_dict() + + a_ser_json = Serialise._SymbolEncoder().default(a) + + self.assertEqual(a_ser_json, a_dict) + + # with children + add = pybamm.Addition(2, 4) + add_json = { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.binary_operators.Addition", + "name": "+", + "id": mock.ANY, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [ + { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scalar", + "name": "2.0", + "id": mock.ANY, + "value": 2.0, + "children": [], + }, + { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scalar", + "name": "4.0", + "id": mock.ANY, + "value": 4.0, + "children": [], + }, + ], + } + + add_ser_json = Serialise._SymbolEncoder().default(add) + + self.assertEqual(add_ser_json, add_json) + + def test_symbol_encoder_explicitTimeIntegral(self): + """test symbol encoder with initial conditions""" + expr = pybamm.ExplicitTimeIntegral(pybamm.Scalar(5), pybamm.Scalar(1)) + + expr_json = { + "py/object": "pybamm.expression_tree.unary_operators.ExplicitTimeIntegral", + "py/id": mock.ANY, + "name": "explicit time integral", + "id": mock.ANY, + "children": [ + { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": mock.ANY, + "name": "5.0", + "id": mock.ANY, + "value": 5.0, + "children": [], + } + ], + "initial_condition": { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": mock.ANY, + "name": "1.0", + "id": mock.ANY, + "value": 1.0, + "children": [], + }, + } + + expr_ser_json = Serialise._SymbolEncoder().default(expr) + + self.assertEqual(expr_json, expr_ser_json) + + def test_symbol_encoder_event(self): + """test symbol encoder with event""" + + expression = pybamm.Scalar(1) + event = pybamm.Event("my event", expression) + + event_json = { + "py/object": "pybamm.models.event.Event", + "py/id": mock.ANY, + "name": "my event", + "event_type": ["EventType.TERMINATION", 0], + "expression": { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": mock.ANY, + "name": "1.0", + "id": mock.ANY, + "value": 1.0, + "children": [], + }, + } + + event_ser_json = Serialise._SymbolEncoder().default(event) + self.assertEqual(event_ser_json, event_json) + + # test the mesh encoder + def test_mesh_encoder(self): + mesh, mesh_json = mesh_var_dict() + + # serialise mesh + mesh_ser_json = Serialise._MeshEncoder().default(mesh) + + self.assertEqual(mesh_ser_json, mesh_json) + + def test_deconstruct_pybamm_dicts(self): + """tests serialisation of dictionaries with pybamm classes as keys""" + + x = pybamm.SpatialVariable("x", "negative electrode") + + test_dict = {"rod": {x: {"min": 0.0, "max": 2.0}}} + + ser_dict = { + "rod": { + "symbol_x": { + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", + "py/id": mock.ANY, + "name": "x", + "id": mock.ANY, + "domains": { + "primary": ["negative electrode"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [], + }, + "x": {"min": 0.0, "max": 2.0}, + } + } + + self.assertEqual(Serialise()._deconstruct_pybamm_dicts(test_dict), ser_dict) + + def test_get_pybamm_class(self): + # symbol + _, scalar_dict = scalar_var_dict() + + scalar_class = Serialise()._get_pybamm_class(scalar_dict) + + self.assertIsInstance(scalar_class, pybamm.Scalar) + + # mesh + _, mesh_dict = mesh_var_dict() + + mesh_class = Serialise()._get_pybamm_class(mesh_dict) + + self.assertIsInstance(mesh_class, pybamm.Mesh) + + with self.assertRaises(AttributeError): + unrecognised_symbol = { + "py/id": mock.ANY, + "py/object": "pybamm.expression_tree.scalar.Scale", + "name": "5.0", + "id": mock.ANY, + "value": 5.0, + "children": [], + } + Serialise()._get_pybamm_class(unrecognised_symbol) + + def test_reconstruct_symbol(self): + scalar, scalar_dict = scalar_var_dict() + + new_scalar = Serialise()._reconstruct_symbol(scalar_dict) + + self.assertEqual(new_scalar, scalar) + + def test_reconstruct_expression_tree(self): + y = pybamm.StateVector(slice(0, 1)) + t = pybamm.t + equation = 2 * y + t + + equation_json = { + "py/object": "pybamm.expression_tree.binary_operators.Addition", + "py/id": 139691619709376, + "name": "+", + "id": -2595875552397011963, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [ + { + "py/object": "pybamm.expression_tree.binary_operators.Multiplication", + "py/id": 139691619709232, + "name": "*", + "id": 6094209803352873499, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [ + { + "py/object": "pybamm.expression_tree.scalar.Scalar", + "py/id": 139691619709040, + "name": "2.0", + "id": 1254626814648295285, + "value": 2.0, + "children": [], + }, + { + "py/object": "pybamm.expression_tree.state_vector.StateVector", + "py/id": 139691619589760, + "name": "y[0:1]", + "id": 5063056989669636089, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "y_slice": [{"start": 0, "stop": 1, "step": None}], + "evaluation_array": [True], + "children": [], + }, + ], + }, + { + "py/object": "pybamm.expression_tree.independent_variable.Time", + "py/id": 139692083289392, + "name": "time", + "id": -3301344124754766351, + "domains": { + "primary": [], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [], + }, + ], + } + + new_equation = Serialise()._reconstruct_expression_tree(equation_json) + + self.assertEqual(new_equation, equation) + + def test_reconstruct_mesh(self): + mesh, mesh_dict = mesh_var_dict() + + new_mesh = Serialise()._reconstruct_mesh(mesh_dict) + + testing.assert_array_equal( + new_mesh["negative particle"].edges, mesh["negative particle"].edges + ) + testing.assert_array_equal( + new_mesh["negative particle"].nodes, mesh["negative particle"].nodes + ) + + # reconstructed meshes are only used for plotting, geometry not reconstructed. + with self.assertRaisesRegex( + AttributeError, "'Mesh' object has no attribute '_geometry'" + ): + self.assertEqual(new_mesh.geometry, mesh.geometry) + + def test_reconstruct_pybamm_dict(self): + x = pybamm.SpatialVariable("x", "negative electrode") + + test_dict = {"rod": {x: {"min": 0.0, "max": 2.0}}} + + ser_dict = { + "rod": { + "symbol_x": { + "py/object": "pybamm.expression_tree.independent_variable.SpatialVariable", + "py/id": mock.ANY, + "name": "x", + "id": mock.ANY, + "domains": { + "primary": ["negative electrode"], + "secondary": [], + "tertiary": [], + "quaternary": [], + }, + "children": [], + }, + "x": {"min": 0.0, "max": 2.0}, + } + } + + new_dict = Serialise()._reconstruct_pybamm_dict(ser_dict) + + self.assertEqual(new_dict, test_dict) + + # test recreation if not passed a dict + test_list = ["left", "right"] + new_list = Serialise()._reconstruct_pybamm_dict(test_list) + + self.assertEqual(test_list, new_list) + + def test_convert_options(self): + options_dict = { + "current collector": "uniform", + "particle phases": ["2", "1"], + "open-circuit potential": [["single", "current sigmoid"], "single"], + } + + options_result = { + "current collector": "uniform", + "particle phases": ("2", "1"), + "open-circuit potential": (("single", "current sigmoid"), "single"), + } + + self.assertEqual(Serialise()._convert_options(options_dict), options_result) + + def test_save_load_model(self): + model = pybamm.lithium_ion.SPM(name="test_spm") + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) + + # test error if not discretised + with self.assertRaisesRegex( + NotImplementedError, + "PyBaMM can only serialise a discretised, ready-to-solve model", + ): + Serialise().save_model(model, filename="test_model") + + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + + # default save + Serialise().save_model(model, filename="test_model") + self.assertTrue(os.path.exists("test_model.json")) + + # default save where filename isn't provided + Serialise().save_model(model) + filename = "test_spm_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M") + ".json" + self.assertTrue(os.path.exists(filename)) + os.remove(filename) + + # default load + new_model = Serialise().load_model("test_model.json") + + # check new model solves + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, [0, 3600]) + + # check an error is raised when plotting the solution + with self.assertRaisesRegex( + AttributeError, + "No variables to plot", + ): + new_solution.plot() + + # load when specifying the battery model to use + newest_model = Serialise().load_model( + "test_model.json", battery_model=pybamm.lithium_ion.SPM + ) + + # Test for error if no model type is provided + with open("test_model.json", "r") as f: + model_data = json.load(f) + del model_data["py/object"] + + with open("test_model.json", "w") as f: + json.dump(model_data, f) + + with self.assertRaises(TypeError): + Serialise().load_model("test_model.json") + + os.remove("test_model.json") + + # check new model solves + newest_solver = newest_model.default_solver + newest_solver.solve(newest_model, [0, 3600]) + + def test_save_experiment_model_error(self): + model = pybamm.lithium_ion.SPM() + experiment = pybamm.Experiment(["Discharge at 1C for 1 hour"]) + sim = pybamm.Simulation(model, experiment=experiment) + sim.solve() + + with self.assertRaisesRegex( + NotImplementedError, + "Serialising models coupled to experiments is not yet supported.", + ): + sim.save_model("spm_experiment", mesh=False, variables=False) + + def test_serialised_model_plotting(self): + # models without a mesh + model = pybamm.BaseModel() + c = pybamm.Variable("c") + model.rhs = {c: -c} + model.initial_conditions = {c: 1} + model.variables["c"] = c + model.variables["2c"] = 2 * c + + # setup and discretise + _ = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) + + Serialise().save_model( + model, + variables=model.variables, + filename="test_base_model", + ) + + new_model = Serialise().load_model("test_base_model.json") + os.remove("test_base_model.json") + + new_solution = pybamm.ScipySolver().solve(new_model, np.linspace(0, 1)) + + # check dynamic plot loads + new_solution.plot(["c", "2c"], testing=True) + + # models with a mesh ---------------- + model = pybamm.lithium_ion.SPM(name="test_spm_plotting") + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + + Serialise().save_model( + model, + variables=model.variables, + mesh=mesh, + filename="test_plotting_model", + ) + + new_model = Serialise().load_model("test_plotting_model.json") + os.remove("test_plotting_model.json") + + new_solver = new_model.default_solver + new_solution = new_solver.solve(new_model, [0, 3600]) + + # check dynamic plot loads + new_solution.plot(testing=True) + + +if __name__ == "__main__": + print("Add -v for more debug output") + import sys + + if "-v" in sys.argv: + debug = True + pybamm.settings.debug_mode = True + unittest.main() diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py index dac94a2538..ac70f0b43b 100644 --- a/tests/unit/test_simulation.py +++ b/tests/unit/test_simulation.py @@ -363,6 +363,26 @@ def test_save_load_dae(self): sim_load = pybamm.load_sim(test_name) self.assertEqual(sim.model.name, sim_load.model.name) + def test_save_load_model(self): + model = pybamm.lead_acid.LOQS({"surface form": "algebraic"}) + model.use_jacobian = True + sim = pybamm.Simulation(model) + + # test exception if not discretised + with self.assertRaises(NotImplementedError): + sim.save_model("sim_save") + + # save after solving + sim.solve([0, 600]) + sim.save_model("sim_save") + + # load model + saved_model = pybamm.load_model("sim_save.json") + + self.assertEqual(model.options, saved_model.options) + + os.remove("sim_save.json") + def test_plot(self): sim = pybamm.Simulation(pybamm.lithium_ion.SPM())