Skip to content

Commit

Permalink
Merge pull request #4556 from aabills/coupled-variable-2
Browse files Browse the repository at this point in the history
CoupledVariable
  • Loading branch information
valentinsulzer authored Nov 4, 2024
2 parents 01a7c08 + c259bb0 commit 19a7738
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .expression_tree.parameter import Parameter, FunctionParameter
from .expression_tree.scalar import Scalar
from .expression_tree.variable import *
from .expression_tree.coupled_variable import *
from .expression_tree.independent_variable import *
from .expression_tree.independent_variable import t
from .expression_tree.vector import Vector
Expand Down
5 changes: 5 additions & 0 deletions src/pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,11 @@ def _process_symbol(self, symbol):
if symbol._expected_size is None:
symbol._expected_size = expected_size
return symbol.create_copy()

elif isinstance(symbol, pybamm.CoupledVariable):
new_symbol = self.process_symbol(symbol.children[0])
return new_symbol

else:
# Backup option: return the object
return symbol
Expand Down
55 changes: 55 additions & 0 deletions src/pybamm/expression_tree/coupled_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pybamm

from pybamm.type_definitions import DomainType


class CoupledVariable(pybamm.Symbol):
"""
A node in the expression tree representing a variable whose equation is set by a different model or submodel.
Parameters
----------
name : str
name of the node
domain : iterable of str
list of domains that this coupled variable is valid over
"""

def __init__(
self,
name: str,
domain: DomainType = None,
) -> None:
super().__init__(name, domain=domain)

def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
"""
return pybamm.evaluate_for_shape_using_domain(self.domains)

def create_copy(self):
"""Creates a new copy of the coupled variable."""
new_coupled_variable = CoupledVariable(self.name, self.domain)
return new_coupled_variable

@property
def children(self):
return self._children

@children.setter
def children(self, expr):
self._children = expr

def set_coupled_variable(self, symbol, expr):
"""Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step."""
if self == symbol:
symbol.children = [
expr,
]
else:
for child in symbol.children:
self.set_coupled_variable(child, expr)
symbol.set_id()
26 changes: 26 additions & 0 deletions src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, name="Unnamed model"):
self._boundary_conditions = {}
self._variables_by_submodel = {}
self._variables = pybamm.FuzzyDict({})
self._coupled_variables = {}
self._summary_variables = []
self._events = []
self._concatenated_rhs = None
Expand Down Expand Up @@ -182,6 +183,31 @@ def boundary_conditions(self):
def boundary_conditions(self, boundary_conditions):
self._boundary_conditions = BoundaryConditionsDict(boundary_conditions)

@property
def coupled_variables(self):
"""Returns a dictionary mapping strings to expressions representing variables needed by the model but whose equations were set by other models."""
return self._coupled_variables

@coupled_variables.setter
def coupled_variables(self, coupled_variables):
for name, var in coupled_variables.items():
if (
isinstance(var, pybamm.CoupledVariable)
and var.name != name
# Exception if the variable is also there under its own name
and not (
var.name in coupled_variables and coupled_variables[var.name] == var
)
):
raise ValueError(
f"Coupled variable with name '{var.name}' is in coupled variables dictionary with "
f"name '{name}'. Names must match."
)
self._coupled_variables = coupled_variables

def list_coupled_variables(self):
return list(self._coupled_variables.keys())

@property
def variables(self):
"""Returns a dictionary mapping strings to expressions representing the model's useful variables."""
Expand Down
94 changes: 94 additions & 0 deletions tests/unit/test_expression_tree/test_coupled_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# Tests for the CoupledVariable class
#


import numpy as np

import pybamm

import pytest


def combine_models(list_of_models):
model = pybamm.BaseModel()

for submodel in list_of_models:
model.coupled_variables.update(submodel.coupled_variables)
model.variables.update(submodel.variables)
model.rhs.update(submodel.rhs)
model.algebraic.update(submodel.algebraic)
model.initial_conditions.update(submodel.initial_conditions)
model.boundary_conditions.update(submodel.boundary_conditions)

for name, coupled_variable in model.coupled_variables.items():
if name in model.variables:
for sym in model.rhs.values():
coupled_variable.set_coupled_variable(sym, model.variables[name])
for sym in model.algebraic.values():
coupled_variable.set_coupled_variable(sym, model.variables[name])
return model


class TestCoupledVariable:
def test_coupled_variable(self):
model_1 = pybamm.BaseModel()
model_1_var_1 = pybamm.CoupledVariable("a")
model_1_var_2 = pybamm.Variable("b")
model_1.rhs[model_1_var_2] = -0.2 * model_1_var_1
model_1.variables["b"] = model_1_var_2
model_1.coupled_variables["a"] = model_1_var_1
model_1.initial_conditions[model_1_var_2] = 1.0

model_2 = pybamm.BaseModel()
model_2_var_1 = pybamm.Variable("a")
model_2_var_2 = pybamm.CoupledVariable("b")
model_2.rhs[model_2_var_1] = -0.2 * model_2_var_2
model_2.variables["a"] = model_2_var_1
model_2.coupled_variables["b"] = model_2_var_2
model_2.initial_conditions[model_2_var_1] = 1.0

model = combine_models([model_1, model_2])

params = pybamm.ParameterValues({})
geometry = {}

# Process parameters
params.process_model(model)
params.process_geometry(geometry)

# mesh and discretise
submesh_types = {}
var_pts = {}
mesh = pybamm.Mesh(geometry, submesh_types, var_pts)

spatial_methods = {}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)

# solve
solver = pybamm.CasadiSolver()
t = np.linspace(0, 10, 1000)
solution = solver.solve(model, t)

np.testing.assert_almost_equal(
solution["a"].entries, solution["b"].entries, decimal=10
)

assert set(model.list_coupled_variables()) == set(["a", "b"])

def test_create_copy(self):
a = pybamm.CoupledVariable("a")
b = a.create_copy()
assert a == b

def test_setter(self):
model = pybamm.BaseModel()
a = pybamm.CoupledVariable("a")
coupled_variables = {"a": a}
model.coupled_variables = coupled_variables
assert model.coupled_variables == coupled_variables

with pytest.raises(ValueError, match="Coupled variable with name"):
coupled_variables = {"b": a}
model.coupled_variables = coupled_variables

0 comments on commit 19a7738

Please sign in to comment.