Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: auto-add state variables to output variables #4700

Merged
merged 6 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
- Made composite electrode model compatible with particle size distribution ([#4687](https://github.com/pybamm-team/PyBaMM/pull/4687))
- Added `Symbol.post_order()` method to return an iterable that steps through the tree in post-order fashion. ([#4684](https://github.com/pybamm-team/PyBaMM/pull/4684))
- Added two more submodels (options) for the SEI: Lars von Kolzenberg (2020) model and Tunneling Limit model ([#4394](https://github.com/pybamm-team/PyBaMM/pull/4394))

- Automatically add state variables of the model to the output variables if they are not already present ([#4700](https://github.com/pybamm-team/PyBaMM/pull/4700))

## Breaking changes

Expand Down
83 changes: 47 additions & 36 deletions src/pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def process_model(self, model, inplace=True):

model_disc.bcs = self.bcs

# pre-process variables so that all state variables are included
pre_processed_variables = self._pre_process_variables(
model.variables, model.initial_conditions
)

pybamm.logger.verbose(f"Discretise initial conditions for {model.name}")
ics, concat_ics = self.process_initial_conditions(model)
model_disc.initial_conditions = ics
Expand All @@ -202,7 +207,8 @@ def process_model(self, model, inplace=True):
# Note that we **do not** discretise the keys of model.rhs,
# model.initial_conditions and model.boundary_conditions
pybamm.logger.verbose(f"Discretise variables for {model.name}")
model_disc.variables = self.process_dict(model.variables)

model_disc.variables = self.process_dict(pre_processed_variables)

# Process parabolic and elliptic equations
pybamm.logger.verbose(f"Discretise model equations for {model.name}")
Expand Down Expand Up @@ -657,6 +663,46 @@ def create_mass_matrix(self, model):

return mass_matrix, mass_matrix_inv

def _pre_process_variables(
self,
variables: dict[str, pybamm.Symbol],
initial_conditions: dict[pybamm.Variable, pybamm.Symbol],
):
"""
Pre-process variables before discretisation. This involves:
- ensuring that all the state variables are included in the variables,
any missing are added

Parameters
----------
variables : dict
Dictionary of variables to pre-process
initial_conditions : dict
Dictionary of initial conditions

Returns
-------
dict
Pre-processed variables (copy of input variables with any missing state)

Raises
------
:class:`pybamm.ModelError`
If any state variable names are already included but with
incorrect expressions
"""
new_variables = {k: v for k, v in variables.items()}
for var in initial_conditions.keys():
if var.name not in new_variables:
new_variables[var.name] = var
else:
if new_variables[var.name] != var:
raise pybamm.ModelError(
f"Variable '{var.name}' should have expression "
f"'{var}', but has expression '{new_variables[var.name]}'"
)
return new_variables

def process_dict(self, var_eqn_dict, ics=False):
"""Discretise a dictionary of {variable: equation}, broadcasting if necessary
(can be model.rhs, model.algebraic, model.initial_conditions or
Expand Down Expand Up @@ -1008,7 +1054,6 @@ def _concatenate_in_order(self, var_eqn_dict, check_complete=False, sparse=False
def check_model(self, model):
"""Perform some basic checks to make sure the discretised model makes sense."""
self.check_initial_conditions(model)
self.check_variables(model)

def check_initial_conditions(self, model):
# Check initial conditions are a numpy array
Expand Down Expand Up @@ -1049,40 +1094,6 @@ def check_initial_conditions(self, model):
f"{model.algebraic[var].shape} and initial_conditions.shape = {model.initial_conditions[var].shape} for variable '{var}'."
)

def check_variables(self, model):
"""
Check variables in variable list against rhs.
Be lenient with size check if the variable in model.variables is broadcasted, or
a concatenation
(if broadcasted, variable is a multiplication with a vector of ones)
"""
for rhs_var in model.rhs.keys():
if rhs_var.name in model.variables.keys():
var = model.variables[rhs_var.name]

different_shapes = not np.array_equal(
model.rhs[rhs_var].shape, var.shape
)

not_concatenation = not isinstance(var, pybamm.Concatenation)

not_mult_by_one_vec = not (
isinstance(
var, (pybamm.Multiplication, pybamm.MatrixMultiplication)
)
and (
pybamm.is_matrix_one(var.left)
or pybamm.is_matrix_one(var.right)
)
)

if different_shapes and not_concatenation and not_mult_by_one_vec:
raise pybamm.ModelError(
"variable and its eqn must have the same shape after "
"discretisation but variable.shape = "
f"{var.shape} and rhs.shape = {model.rhs[rhs_var].shape} for variable '{var}'. "
)

def is_variable_independent(self, var, all_vars_in_eqns):
pybamm.logger.verbose("Removing independent blocks.")
if not isinstance(var, pybamm.Variable):
Expand Down
42 changes: 22 additions & 20 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations
import copy
from collections import defaultdict
from typing import Optional

import numpy as np
import sympy
Expand Down Expand Up @@ -146,9 +147,9 @@ def _concatenation_new_copy(self, children, perform_simplifications: bool = True
children before creating the new copy.
"""
if perform_simplifications:
return concatenation(*children)
return concatenation(*children, name=self.name)
else:
return self.__class__(*children)
return self.__class__(*children, name=self.name)

def _concatenation_jac(self, children_jacs):
"""Calculate the Jacobian of a concatenation."""
Expand Down Expand Up @@ -468,17 +469,18 @@ def _concatenation_new_copy(self, children, perform_simplifications=True):
class ConcatenationVariable(Concatenation):
"""A Variable representing a concatenation of variables."""

def __init__(self, *children):
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
name = None
# name is unchanged if its length is 1
elif len(name) > 1:
name = name[0].capitalize() + name[1:]
def __init__(self, *children, name: Optional[str] = None):
if name is None:
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
name = None
# name is unchanged if its length is 1
elif len(name) > 1:
name = name[0].capitalize() + name[1:]

if len(children) > 0:
if all(child.scale == children[0].scale for child in children):
Expand Down Expand Up @@ -523,7 +525,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children):
def simplified_concatenation(*children, name: Optional[str] = None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand All @@ -534,29 +536,29 @@ def simplified_concatenation(*children):
elif len(children) == 1:
return children[0]
elif all(isinstance(child, pybamm.Variable) for child in children):
return pybamm.ConcatenationVariable(*children)
return pybamm.ConcatenationVariable(*children, name=name)
else:
# Create Concatenation to easily read domains
concat = Concatenation(*children)
concat = Concatenation(*children, name=name)
if all(
isinstance(child, pybamm.Broadcast) and child.child == children[0].child
for child in children
):
unique_child = children[0].orphans[0]
if isinstance(children[0], pybamm.PrimaryBroadcast):
return pybamm.PrimaryBroadcast(unique_child, concat.domain)
return pybamm.PrimaryBroadcast(unique_child, concat.domain, name=name)
else:
return pybamm.FullBroadcast(
unique_child, broadcast_domains=concat.domains
unique_child, broadcast_domains=concat.domains, name=name
)
else:
return concat


def concatenation(*children):
def concatenation(*children, name: Optional[str] = None):
"""Helper function to create concatenations."""
# TODO: add option to turn off simplifications
return simplified_concatenation(*children)
return simplified_concatenation(*children, name=name)


def simplified_numpy_concatenation(*children):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def get_fundamental_variables(self):
domain="positive electrode",
auxiliary_domains={"secondary": "current collector"},
)
c_ox_s_p = pybamm.concatenation(c_ox_s, c_ox_p)
c_ox_s_p = pybamm.concatenation(
c_ox_s,
c_ox_p,
name="Separator and positive electrode oxygen concentration [mol.m-3]",
)
variables = {
"Separator and positive electrode oxygen concentration [mol.m-3]": c_ox_s_p
}
Expand Down
46 changes: 28 additions & 18 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,9 +1036,9 @@ def test_concatenation_2D(self):
assert expr.children[2].evaluate(0, y).shape == (105, 1)

def test_exceptions(self):
c_n = pybamm.Variable("c", domain=["negative electrode"])
c_n = pybamm.Variable("c_n", domain=["negative electrode"])
N_n = pybamm.grad(c_n)
c_s = pybamm.Variable("c", domain=["separator"])
c_s = pybamm.Variable("c_s", domain=["separator"])
N_s = pybamm.grad(c_s)
model = pybamm.BaseModel()
model.rhs = {c_n: pybamm.div(N_n), c_s: pybamm.div(N_s)}
Expand All @@ -1049,22 +1049,6 @@ def test_exceptions(self):
}

disc = get_discretisation_for_testing()

# check raises error if different sized key and output var
model.variables = {c_n.name: c_s}
with pytest.raises(pybamm.ModelError, match="variable and its eqn"):
disc.process_model(model)

# check doesn't raise if concatenation
model.variables = {c_n.name: pybamm.concatenation(2 * c_n, 3 * c_s)}
disc.process_model(model, inplace=False)

# check doesn't raise if broadcast
model.variables = {
c_n.name: pybamm.PrimaryBroadcast(
pybamm.InputParameter("a"), ["negative electrode"]
)
}
disc.process_model(model)

# Check setting up a 0D spatial method with 1D mesh raises error
Expand Down Expand Up @@ -1277,3 +1261,29 @@ def test_independent_rhs_with_event(self):
disc = pybamm.Discretisation(remove_independent_variables_from_rhs=True)
disc.process_model(model)
assert len(model.rhs) == 3

def test_pre_process_variables(self):
a = pybamm.Variable("a")
b = pybamm.Variable("b")
model = pybamm.BaseModel()
model.rhs = {a: b, b: a}
model.initial_conditions = {
a: pybamm.Scalar(0),
b: pybamm.Scalar(1),
}
model.variables = {
"a": a, # correct
# b missing
}
disc = pybamm.Discretisation()
disc_model = disc.process_model(model, inplace=False)
assert list(disc_model.variables.keys()) == ["a", "b"]

model.variables = {
"a": a,
"b": 2 * a,
}
with pytest.raises(
pybamm.ModelError, match="Variable 'b' should have expression"
):
disc.process_model(model, inplace=False)
Loading