Skip to content

Commit

Permalink
feat: auto-add state variables to output variables (#4700)
Browse files Browse the repository at this point in the history
* feat: auto-add state variables to output variables

* add changelog

* fix lead acid model tests

---------

Co-authored-by: Eric G. Kratz <[email protected]>
  • Loading branch information
martinjrobins and kratman authored Jan 8, 2025
1 parent 8aaaab1 commit 3dccd52
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 76 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
- Made composite electrode model compatible with particle size distribution ([#4687](https://github.com/pybamm-team/PyBaMM/pull/4687))
- Added `Symbol.post_order()` method to return an iterable that steps through the tree in post-order fashion. ([#4684](https://github.com/pybamm-team/PyBaMM/pull/4684))
- Added two more submodels (options) for the SEI: Lars von Kolzenberg (2020) model and Tunneling Limit model ([#4394](https://github.com/pybamm-team/PyBaMM/pull/4394))

- Automatically add state variables of the model to the output variables if they are not already present ([#4700](https://github.com/pybamm-team/PyBaMM/pull/4700))

## Breaking changes

Expand Down
83 changes: 47 additions & 36 deletions src/pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def process_model(self, model, inplace=True):

model_disc.bcs = self.bcs

# pre-process variables so that all state variables are included
pre_processed_variables = self._pre_process_variables(
model.variables, model.initial_conditions
)

pybamm.logger.verbose(f"Discretise initial conditions for {model.name}")
ics, concat_ics = self.process_initial_conditions(model)
model_disc.initial_conditions = ics
Expand All @@ -202,7 +207,8 @@ def process_model(self, model, inplace=True):
# Note that we **do not** discretise the keys of model.rhs,
# model.initial_conditions and model.boundary_conditions
pybamm.logger.verbose(f"Discretise variables for {model.name}")
model_disc.variables = self.process_dict(model.variables)

model_disc.variables = self.process_dict(pre_processed_variables)

# Process parabolic and elliptic equations
pybamm.logger.verbose(f"Discretise model equations for {model.name}")
Expand Down Expand Up @@ -657,6 +663,46 @@ def create_mass_matrix(self, model):

return mass_matrix, mass_matrix_inv

def _pre_process_variables(
self,
variables: dict[str, pybamm.Symbol],
initial_conditions: dict[pybamm.Variable, pybamm.Symbol],
):
"""
Pre-process variables before discretisation. This involves:
- ensuring that all the state variables are included in the variables,
any missing are added
Parameters
----------
variables : dict
Dictionary of variables to pre-process
initial_conditions : dict
Dictionary of initial conditions
Returns
-------
dict
Pre-processed variables (copy of input variables with any missing state)
Raises
------
:class:`pybamm.ModelError`
If any state variable names are already included but with
incorrect expressions
"""
new_variables = {k: v for k, v in variables.items()}
for var in initial_conditions.keys():
if var.name not in new_variables:
new_variables[var.name] = var
else:
if new_variables[var.name] != var:
raise pybamm.ModelError(
f"Variable '{var.name}' should have expression "
f"'{var}', but has expression '{new_variables[var.name]}'"
)
return new_variables

def process_dict(self, var_eqn_dict, ics=False):
"""Discretise a dictionary of {variable: equation}, broadcasting if necessary
(can be model.rhs, model.algebraic, model.initial_conditions or
Expand Down Expand Up @@ -1008,7 +1054,6 @@ def _concatenate_in_order(self, var_eqn_dict, check_complete=False, sparse=False
def check_model(self, model):
"""Perform some basic checks to make sure the discretised model makes sense."""
self.check_initial_conditions(model)
self.check_variables(model)

def check_initial_conditions(self, model):
# Check initial conditions are a numpy array
Expand Down Expand Up @@ -1049,40 +1094,6 @@ def check_initial_conditions(self, model):
f"{model.algebraic[var].shape} and initial_conditions.shape = {model.initial_conditions[var].shape} for variable '{var}'."
)

def check_variables(self, model):
"""
Check variables in variable list against rhs.
Be lenient with size check if the variable in model.variables is broadcasted, or
a concatenation
(if broadcasted, variable is a multiplication with a vector of ones)
"""
for rhs_var in model.rhs.keys():
if rhs_var.name in model.variables.keys():
var = model.variables[rhs_var.name]

different_shapes = not np.array_equal(
model.rhs[rhs_var].shape, var.shape
)

not_concatenation = not isinstance(var, pybamm.Concatenation)

not_mult_by_one_vec = not (
isinstance(
var, (pybamm.Multiplication, pybamm.MatrixMultiplication)
)
and (
pybamm.is_matrix_one(var.left)
or pybamm.is_matrix_one(var.right)
)
)

if different_shapes and not_concatenation and not_mult_by_one_vec:
raise pybamm.ModelError(
"variable and its eqn must have the same shape after "
"discretisation but variable.shape = "
f"{var.shape} and rhs.shape = {model.rhs[rhs_var].shape} for variable '{var}'. "
)

def is_variable_independent(self, var, all_vars_in_eqns):
pybamm.logger.verbose("Removing independent blocks.")
if not isinstance(var, pybamm.Variable):
Expand Down
42 changes: 22 additions & 20 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations
import copy
from collections import defaultdict
from typing import Optional

import numpy as np
import sympy
Expand Down Expand Up @@ -146,9 +147,9 @@ def _concatenation_new_copy(self, children, perform_simplifications: bool = True
children before creating the new copy.
"""
if perform_simplifications:
return concatenation(*children)
return concatenation(*children, name=self.name)
else:
return self.__class__(*children)
return self.__class__(*children, name=self.name)

def _concatenation_jac(self, children_jacs):
"""Calculate the Jacobian of a concatenation."""
Expand Down Expand Up @@ -468,17 +469,18 @@ def _concatenation_new_copy(self, children, perform_simplifications=True):
class ConcatenationVariable(Concatenation):
"""A Variable representing a concatenation of variables."""

def __init__(self, *children):
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
name = None
# name is unchanged if its length is 1
elif len(name) > 1:
name = name[0].capitalize() + name[1:]
def __init__(self, *children, name: Optional[str] = None):
if name is None:
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
name = None
# name is unchanged if its length is 1
elif len(name) > 1:
name = name[0].capitalize() + name[1:]

if len(children) > 0:
if all(child.scale == children[0].scale for child in children):
Expand Down Expand Up @@ -523,7 +525,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children):
def simplified_concatenation(*children, name: Optional[str] = None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand All @@ -534,29 +536,29 @@ def simplified_concatenation(*children):
elif len(children) == 1:
return children[0]
elif all(isinstance(child, pybamm.Variable) for child in children):
return pybamm.ConcatenationVariable(*children)
return pybamm.ConcatenationVariable(*children, name=name)
else:
# Create Concatenation to easily read domains
concat = Concatenation(*children)
concat = Concatenation(*children, name=name)
if all(
isinstance(child, pybamm.Broadcast) and child.child == children[0].child
for child in children
):
unique_child = children[0].orphans[0]
if isinstance(children[0], pybamm.PrimaryBroadcast):
return pybamm.PrimaryBroadcast(unique_child, concat.domain)
return pybamm.PrimaryBroadcast(unique_child, concat.domain, name=name)
else:
return pybamm.FullBroadcast(
unique_child, broadcast_domains=concat.domains
unique_child, broadcast_domains=concat.domains, name=name
)
else:
return concat


def concatenation(*children):
def concatenation(*children, name: Optional[str] = None):
"""Helper function to create concatenations."""
# TODO: add option to turn off simplifications
return simplified_concatenation(*children)
return simplified_concatenation(*children, name=name)


def simplified_numpy_concatenation(*children):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def get_fundamental_variables(self):
domain="positive electrode",
auxiliary_domains={"secondary": "current collector"},
)
c_ox_s_p = pybamm.concatenation(c_ox_s, c_ox_p)
c_ox_s_p = pybamm.concatenation(
c_ox_s,
c_ox_p,
name="Separator and positive electrode oxygen concentration [mol.m-3]",
)
variables = {
"Separator and positive electrode oxygen concentration [mol.m-3]": c_ox_s_p
}
Expand Down
46 changes: 28 additions & 18 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,9 +1036,9 @@ def test_concatenation_2D(self):
assert expr.children[2].evaluate(0, y).shape == (105, 1)

def test_exceptions(self):
c_n = pybamm.Variable("c", domain=["negative electrode"])
c_n = pybamm.Variable("c_n", domain=["negative electrode"])
N_n = pybamm.grad(c_n)
c_s = pybamm.Variable("c", domain=["separator"])
c_s = pybamm.Variable("c_s", domain=["separator"])
N_s = pybamm.grad(c_s)
model = pybamm.BaseModel()
model.rhs = {c_n: pybamm.div(N_n), c_s: pybamm.div(N_s)}
Expand All @@ -1049,22 +1049,6 @@ def test_exceptions(self):
}

disc = get_discretisation_for_testing()

# check raises error if different sized key and output var
model.variables = {c_n.name: c_s}
with pytest.raises(pybamm.ModelError, match="variable and its eqn"):
disc.process_model(model)

# check doesn't raise if concatenation
model.variables = {c_n.name: pybamm.concatenation(2 * c_n, 3 * c_s)}
disc.process_model(model, inplace=False)

# check doesn't raise if broadcast
model.variables = {
c_n.name: pybamm.PrimaryBroadcast(
pybamm.InputParameter("a"), ["negative electrode"]
)
}
disc.process_model(model)

# Check setting up a 0D spatial method with 1D mesh raises error
Expand Down Expand Up @@ -1277,3 +1261,29 @@ def test_independent_rhs_with_event(self):
disc = pybamm.Discretisation(remove_independent_variables_from_rhs=True)
disc.process_model(model)
assert len(model.rhs) == 3

def test_pre_process_variables(self):
a = pybamm.Variable("a")
b = pybamm.Variable("b")
model = pybamm.BaseModel()
model.rhs = {a: b, b: a}
model.initial_conditions = {
a: pybamm.Scalar(0),
b: pybamm.Scalar(1),
}
model.variables = {
"a": a, # correct
# b missing
}
disc = pybamm.Discretisation()
disc_model = disc.process_model(model, inplace=False)
assert list(disc_model.variables.keys()) == ["a", "b"]

model.variables = {
"a": a,
"b": 2 * a,
}
with pytest.raises(
pybamm.ModelError, match="Variable 'b' should have expression"
):
disc.process_model(model, inplace=False)

0 comments on commit 3dccd52

Please sign in to comment.