-
-
Notifications
You must be signed in to change notification settings - Fork 559
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4556 from aabills/coupled-variable-2
CoupledVariable
- Loading branch information
Showing
5 changed files
with
181 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pybamm | ||
|
||
from pybamm.type_definitions import DomainType | ||
|
||
|
||
class CoupledVariable(pybamm.Symbol): | ||
""" | ||
A node in the expression tree representing a variable whose equation is set by a different model or submodel. | ||
Parameters | ||
---------- | ||
name : str | ||
name of the node | ||
domain : iterable of str | ||
list of domains that this coupled variable is valid over | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
domain: DomainType = None, | ||
) -> None: | ||
super().__init__(name, domain=domain) | ||
|
||
def _evaluate_for_shape(self): | ||
""" | ||
Returns the scalar 'NaN' to represent the shape of a parameter. | ||
See :meth:`pybamm.Symbol.evaluate_for_shape()` | ||
""" | ||
return pybamm.evaluate_for_shape_using_domain(self.domains) | ||
|
||
def create_copy(self): | ||
"""Creates a new copy of the coupled variable.""" | ||
new_coupled_variable = CoupledVariable(self.name, self.domain) | ||
return new_coupled_variable | ||
|
||
@property | ||
def children(self): | ||
return self._children | ||
|
||
@children.setter | ||
def children(self, expr): | ||
self._children = expr | ||
|
||
def set_coupled_variable(self, symbol, expr): | ||
"""Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step.""" | ||
if self == symbol: | ||
symbol.children = [ | ||
expr, | ||
] | ||
else: | ||
for child in symbol.children: | ||
self.set_coupled_variable(child, expr) | ||
symbol.set_id() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# | ||
# Tests for the CoupledVariable class | ||
# | ||
|
||
|
||
import numpy as np | ||
|
||
import pybamm | ||
|
||
import pytest | ||
|
||
|
||
def combine_models(list_of_models): | ||
model = pybamm.BaseModel() | ||
|
||
for submodel in list_of_models: | ||
model.coupled_variables.update(submodel.coupled_variables) | ||
model.variables.update(submodel.variables) | ||
model.rhs.update(submodel.rhs) | ||
model.algebraic.update(submodel.algebraic) | ||
model.initial_conditions.update(submodel.initial_conditions) | ||
model.boundary_conditions.update(submodel.boundary_conditions) | ||
|
||
for name, coupled_variable in model.coupled_variables.items(): | ||
if name in model.variables: | ||
for sym in model.rhs.values(): | ||
coupled_variable.set_coupled_variable(sym, model.variables[name]) | ||
for sym in model.algebraic.values(): | ||
coupled_variable.set_coupled_variable(sym, model.variables[name]) | ||
return model | ||
|
||
|
||
class TestCoupledVariable: | ||
def test_coupled_variable(self): | ||
model_1 = pybamm.BaseModel() | ||
model_1_var_1 = pybamm.CoupledVariable("a") | ||
model_1_var_2 = pybamm.Variable("b") | ||
model_1.rhs[model_1_var_2] = -0.2 * model_1_var_1 | ||
model_1.variables["b"] = model_1_var_2 | ||
model_1.coupled_variables["a"] = model_1_var_1 | ||
model_1.initial_conditions[model_1_var_2] = 1.0 | ||
|
||
model_2 = pybamm.BaseModel() | ||
model_2_var_1 = pybamm.Variable("a") | ||
model_2_var_2 = pybamm.CoupledVariable("b") | ||
model_2.rhs[model_2_var_1] = -0.2 * model_2_var_2 | ||
model_2.variables["a"] = model_2_var_1 | ||
model_2.coupled_variables["b"] = model_2_var_2 | ||
model_2.initial_conditions[model_2_var_1] = 1.0 | ||
|
||
model = combine_models([model_1, model_2]) | ||
|
||
params = pybamm.ParameterValues({}) | ||
geometry = {} | ||
|
||
# Process parameters | ||
params.process_model(model) | ||
params.process_geometry(geometry) | ||
|
||
# mesh and discretise | ||
submesh_types = {} | ||
var_pts = {} | ||
mesh = pybamm.Mesh(geometry, submesh_types, var_pts) | ||
|
||
spatial_methods = {} | ||
disc = pybamm.Discretisation(mesh, spatial_methods) | ||
disc.process_model(model) | ||
|
||
# solve | ||
solver = pybamm.CasadiSolver() | ||
t = np.linspace(0, 10, 1000) | ||
solution = solver.solve(model, t) | ||
|
||
np.testing.assert_almost_equal( | ||
solution["a"].entries, solution["b"].entries, decimal=10 | ||
) | ||
|
||
assert set(model.list_coupled_variables()) == set(["a", "b"]) | ||
|
||
def test_create_copy(self): | ||
a = pybamm.CoupledVariable("a") | ||
b = a.create_copy() | ||
assert a == b | ||
|
||
def test_setter(self): | ||
model = pybamm.BaseModel() | ||
a = pybamm.CoupledVariable("a") | ||
coupled_variables = {"a": a} | ||
model.coupled_variables = coupled_variables | ||
assert model.coupled_variables == coupled_variables | ||
|
||
with pytest.raises(ValueError, match="Coupled variable with name"): | ||
coupled_variables = {"b": a} | ||
model.coupled_variables = coupled_variables |