Skip to content

Commit

Permalink
Merge pull request #1316 from pybamm-team/issue-1221-convert-to-casadi
Browse files Browse the repository at this point in the history
Issue 1221 convert to casadi
  • Loading branch information
valentinsulzer authored Dec 31, 2020
2 parents 66b9589 + 4b84f39 commit 0926080
Show file tree
Hide file tree
Showing 28 changed files with 480 additions and 327 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

## Optimizations

- Variables are now post-processed using CasADi ([#1316](https://github.com/pybamm-team/PyBaMM/pull/1316))
- Operations such as `1*x` and `0+x` now directly return `x` ([#1252](https://github.com/pybamm-team/PyBaMM/pull/1252))

## Bug fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,23 @@
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: You are using pip version 20.2.1; however, version 20.2.4 is available.\n",
"\u001b[33mWARNING: You are using pip version 20.2.4; however, version 20.3.3 is available.\n",
"You should consider upgrading via the '/Users/vsulzer/Documents/Energy_storage/PyBaMM/.tox/dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pybamm.solvers.solution.Solution at 0x1467c6820>"
"<pybamm.solvers.solution.Solution at 0x1479dff10>"
]
},
"execution_count": 1,
"metadata": {},
"execution_count": 1
"output_type": "execute_result"
}
],
"source": [
Expand Down Expand Up @@ -102,33 +102,33 @@
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([3.77047806, 3.75305163, 3.74567013, 3.74038819, 3.73581198,\n",
" 3.73153388, 3.72742394, 3.72343938, 3.71956644, 3.71580196,\n",
" 3.71214617, 3.70860034, 3.70516557, 3.70184247, 3.69863116,\n",
" 3.69553115, 3.69254136, 3.6896602 , 3.68688564, 3.68421527,\n",
"array([3.77047806, 3.75305182, 3.74567027, 3.74038822, 3.73581196,\n",
" 3.73153391, 3.72742393, 3.72343929, 3.71956623, 3.71580184,\n",
" 3.71214621, 3.7086004 , 3.70516561, 3.70184253, 3.69863121,\n",
" 3.69553118, 3.69254137, 3.68966018, 3.68688562, 3.68421526,\n",
" 3.68164637, 3.67917591, 3.6768006 , 3.67451688, 3.67232094,\n",
" 3.6702087 , 3.66817572, 3.66621717, 3.66432763, 3.66250091,\n",
" 3.66072975, 3.65900537, 3.65731692, 3.65565067, 3.65398896,\n",
" 3.65230898, 3.65058136, 3.6487688 , 3.64682545, 3.64469796,\n",
" 3.64232964, 3.63966968, 3.63668791, 3.63339298, 3.62984705,\n",
" 3.62616685, 3.62250444, 3.61901236, 3.61580864, 3.61295718,\n",
" 3.61046845, 3.60831404, 3.60644483, 3.60480596, 3.60334607,\n",
" 3.67020869, 3.66817572, 3.66621717, 3.66432762, 3.6625009 ,\n",
" 3.66072974, 3.65900536, 3.65731692, 3.65565066, 3.65398895,\n",
" 3.65230898, 3.65058135, 3.6487688 , 3.64682546, 3.64469798,\n",
" 3.64232968, 3.63966973, 3.63668796, 3.63339303, 3.62984711,\n",
" 3.62616692, 3.6225045 , 3.61901241, 3.61580868, 3.6129572 ,\n",
" 3.61046847, 3.60831405, 3.60644483, 3.60480596, 3.60334607,\n",
" 3.60202167, 3.60079822, 3.5996495 , 3.59855637, 3.59750531,\n",
" 3.59648723, 3.59549638, 3.59452954, 3.59358541, 3.59266405,\n",
" 3.59176646, 3.59089417, 3.59004885, 3.58923192, 3.58844407,\n",
" 3.58768477, 3.58695179, 3.58624057, 3.58554372, 3.58485045,\n",
" 3.58414611, 3.58341187, 3.58262441, 3.58175587, 3.58077378,\n",
" 3.57964098, 3.57831538, 3.5767492 , 3.57488745, 3.57266504,\n",
" 3.5700019 , 3.56679523, 3.56290767, 3.5581495 , 3.55225276,\n",
" 3.54483362, 3.53533853, 3.52296795, 3.50656968, 3.48449277,\n",
" 3.45439366, 3.41299183, 3.35578872, 3.27680073, 3.16842637])"
" 3.5700019 , 3.56679523, 3.56290766, 3.5581495 , 3.55225276,\n",
" 3.54483361, 3.53533853, 3.52296795, 3.50656968, 3.48449277,\n",
" 3.45439366, 3.41299182, 3.35578871, 3.27680072, 3.16842636])"
]
},
"execution_count": 4,
"metadata": {},
"execution_count": 4
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -148,7 +148,6 @@
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([ 0. , 36.36363636, 72.72727273, 109.09090909,\n",
Expand Down Expand Up @@ -178,8 +177,9 @@
" 3490.90909091, 3527.27272727, 3563.63636364, 3600. ])"
]
},
"execution_count": 5,
"metadata": {},
"execution_count": 5
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -199,14 +199,14 @@
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([3.72947891, 3.70860034, 3.67810702, 3.65400558])"
"array([3.72947892, 3.7086004 , 3.67810702, 3.65400557])"
]
},
"execution_count": 6,
"metadata": {},
"execution_count": 6
"output_type": "execute_result"
}
],
"source": [
Expand Down Expand Up @@ -265,16 +265,18 @@
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…",
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ceaea70446149079f0194450d7828dc",
"version_major": 2,
"version_minor": 0,
"model_id": "0b4ebac3fdd947218f9444b2b381cf04"
}
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand Down Expand Up @@ -311,26 +313,28 @@
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…",
"application/vnd.jupyter.widget-view+json": {
"model_id": "9c9e516a7aef46688f03aaea77505636",
"version_major": 2,
"version_minor": 0,
"model_id": "f4a1b65b2bf945099197135c5598084b"
}
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatSlider(value=0.0, description='t', max=1.0, step=0.01), Output()), _dom_classes=('w…"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pybamm.plotting.quick_plot.QuickPlot at 0x14be13b80>"
"<pybamm.plotting.quick_plot.QuickPlot at 0x149685400>"
]
},
"execution_count": 11,
"metadata": {},
"execution_count": 11
"output_type": "execute_result"
}
],
"source": [
Expand Down Expand Up @@ -425,9 +429,22 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5-final"
"version": "3.8.6"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}
}
3 changes: 2 additions & 1 deletion examples/notebooks/models/pouch-cell-model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@
"comsol_solution = pybamm.Solution(solutions[\"1+1D DFN\"].t, solutions[\"1+1D DFN\"].y)\n",
"comsol_model.timescale = simulations[\"1+1D DFN\"].model.timescale\n",
"comsol_model.length_scales = simulations[\"1+1D DFN\"].model.length_scales\n",
"comsol_solution.model = comsol_model"
"comsol_solution.model = comsol_model\n",
"comsol_solution.inputs = {}"
]
},
{
Expand Down
46 changes: 35 additions & 11 deletions examples/notebooks/parameter-values.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/scripts/DFN.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# solve model
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.CasadiSolver(mode="fast", atol=1e-6, rtol=1e-3)
solver = pybamm.CasadiSolver(mode="safe", atol=1e-6, rtol=1e-3)
solution = solver.solve(model, t_eval)

# plot
Expand Down
1 change: 1 addition & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def version(formatted=False):

ABSOLUTE_PATH = root_dir()
PARAMETER_PATH = [
root_dir(),
os.getcwd(),
os.path.join(root_dir(), "pybamm", "input", "parameters"),
]
Expand Down
40 changes: 36 additions & 4 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, name="Unnamed model"):
self.external_variables = []
self._parameters = None
self._input_parameters = None
self._variables_casadi = {}

# Default behaviour is to use the jacobian and simplify
self.use_jacobian = True
Expand All @@ -117,6 +118,7 @@ def __init__(self, name="Unnamed model"):

# Model is not initially discretised
self.is_discretised = False
self.y_slices = None

# Default timescale is 1 second
self.timescale = pybamm.Scalar(1)
Expand Down Expand Up @@ -327,6 +329,14 @@ def new_empty_copy(self):
new_model.convert_to_format = self.convert_to_format
new_model.timescale = self.timescale
new_model.length_scales = self.length_scales

# Variables from discretisation
new_model.is_discretised = self.is_discretised
new_model.y_slices = self.y_slices
new_model.concatenated_rhs = self.concatenated_rhs
new_model.concatenated_algebraic = self.concatenated_algebraic
new_model.concatenated_initial_conditions = self.concatenated_initial_conditions

return new_model

def new_copy(self):
Expand Down Expand Up @@ -412,6 +422,31 @@ def set_initial_conditions_from(self, solution, inplace=True):
"Variable must have type 'Variable' or 'Concatenation'"
)

# Also update the concatenated initial conditions if the model is already
# discretised
if model.is_discretised:
# Unpack slices for sorting
y_slices = {var.id: slce for var, slce in model.y_slices.items()}
slices = []
for symbol in model.initial_conditions.keys():
if isinstance(symbol, pybamm.Concatenation):
# must append the slice for the whole concatenation, so that
# equations get sorted correctly
slices.append(
slice(
y_slices[symbol.children[0].id][0].start,
y_slices[symbol.children[-1].id][0].stop,
)
)
else:
slices.append(y_slices[symbol.id][0])
equations = list(model.initial_conditions.values())
# sort equations according to slices
sorted_equations = [eq for _, eq in sorted(zip(slices, equations))]
model.concatenated_initial_conditions = pybamm.NumpyConcatenation(
*sorted_equations
)

return model

def check_and_combine_dict(self, dict1, dict2):
Expand Down Expand Up @@ -888,10 +923,7 @@ def default_spatial_methods(self):
@property
def default_solver(self):
"Return default solver based on whether model is ODE model or DAE model"
if len(self.algebraic) == 0:
return pybamm.ScipySolver()
else:
return pybamm.CasadiSolver(mode="safe")
return pybamm.CasadiSolver(mode="safe")


# helper functions for finding symbols
Expand Down
31 changes: 2 additions & 29 deletions pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,35 +517,6 @@ def step(self, dt, solver=None, npts=2, save=True, **kwargs):

return self.solution

def get_variable_array(self, *variables):
"""
A helper function to easily obtain a dictionary of arrays of values
for a list of variables at the latest timestep.
Parameters
----------
variable: str
The name of the variable/variables you wish to obtain the arrays for.
Returns
-------
variable_arrays: dict
A dictionary of the variable names and their corresponding
arrays.
"""

variable_arrays = [
self.built_model.variables[var].evaluate(
self.solution.t[-1], self.solution.y[:, -1]
)
for var in variables
]

if len(variable_arrays) == 1:
return variable_arrays[0]
else:
return tuple(variable_arrays)

def plot(self, output_variables=None, quick_plot_vars=None, **kwargs):
"""
A method to quickly plot the outputs of the simulation. Creates a
Expand Down Expand Up @@ -695,6 +666,8 @@ def save(self, filename):
and self._solver.integrator_specs != {}
):
self._solver.integrator_specs = {}
if self.solution is not None:
self.solution.clear_casadi_attributes()
with open(filename, "wb") as f:
pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

Expand Down
7 changes: 5 additions & 2 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,11 @@ def get_termination_reason(self, solution, events):
# causes an error later in ProcessedVariable)
if solution.t_event - solution._t[-1] > self.atol:
solution._t = np.concatenate((solution._t, solution.t_event))
solution._y = np.concatenate((solution._y, solution.y_event), axis=1)
if isinstance(solution.y, casadi.DM):
solution._y = casadi.horzcat(solution.y, solution.y_event)
else:
solution._y = np.hstack((solution._y, solution.y_event))

for name, inp in solution.inputs.items():
solution._inputs[name] = np.c_[inp, inp[:, -1]]

Expand Down Expand Up @@ -956,7 +960,6 @@ def __init__(self, function, name, model):
self.timescale = self.model.timescale_eval

def __call__(self, t, y, inputs):
y = y.reshape(-1, 1)
if self.name in ["RHS", "algebraic", "residuals"]:
pybamm.logger.debug(
"Evaluating {} for {} at t={}".format(
Expand Down
Loading

0 comments on commit 0926080

Please sign in to comment.