Skip to content

Commit

Permalink
#658 merge 784
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jan 27, 2020
2 parents fd84280 + 5832658 commit 7ae50d5
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 40 deletions.
6 changes: 4 additions & 2 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,14 +856,16 @@ def _process_symbol(self, symbol):
name, parent_and_slice = list(self.external_variables.keys())[idx]
if parent_and_slice is None:
# Variable didn't come from a concatenation so we can just create a
# normal external variable
# normal external variable using the symbol's name
return pybamm.ExternalVariable(
name,
symbol.name,
size=self._get_variable_size(symbol),
domain=symbol.domain,
auxiliary_domains=symbol.auxiliary_domains,
)
else:
# We have to use a special name since the concatenation doesn't have
# a very informative name. Needs improving
parent, start, end = parent_and_slice
ext = pybamm.ExternalVariable(
name,
Expand Down
14 changes: 13 additions & 1 deletion pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,14 @@ def build(self, check_model=True):
self._model, inplace=False, check_model=check_model
)

def solve(self, t_eval=None, solver=None, inputs=None, check_model=True):
def solve(
self,
t_eval=None,
solver=None,
external_variables=None,
inputs=None,
check_model=True,
):
"""
A method to solve the model. This method will automatically build
and set the model parameters if not already done so.
Expand All @@ -252,6 +259,11 @@ def solve(self, t_eval=None, solver=None, inputs=None, check_model=True):
non-dimensional time of 1.
solver : :class:`pybamm.BaseSolver`
The solver to use to solve the model.
external_variables : dict
A dictionary of external variables and their corresponding
values at the current time. The variables must correspond to
the variables that would normally be found by solving the
submodels that have been made external.
inputs : dict, optional
Any input parameters to pass to the model when solving
check_model : bool, optional
Expand Down
12 changes: 10 additions & 2 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def __init__(
self.root_tol = root_tol
self.max_steps = max_steps

self.name = "Base solver"

self.model_step_times = {}

# Defaults, can be overwritten by specific solver
self.name = "Base solver"
self.ode_solver = False

@property
def method(self):
return self._method
Expand Down Expand Up @@ -112,6 +114,12 @@ def set_up(self, model, inputs=None):
inputs = inputs or {}
y0 = model.concatenated_initial_conditions

# Check model.algebraic for ode solvers
if self.ode_solver is True and len(model.algebraic) > 0:
raise pybamm.SolverError(
"Cannot use ODE solver '{}' to solve DAE model".format(self.name)
)

if (
isinstance(self, pybamm.CasadiSolver)
and model.convert_to_format != "casadi"
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, method="cvode", rtol=1e-6, atol=1e-6, linsolver="dense"):

super().__init__(method, rtol, atol)
self.linsolver = linsolver
self.ode_solver = True
self.name = "Scikits ODE solver ({})".format(method)

def _integrate(self, model, t_eval, inputs=None):
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ScipySolver(pybamm.BaseSolver):

def __init__(self, method="BDF", rtol=1e-6, atol=1e-6):
super().__init__(method, rtol, atol)
self.ode_solver = True
self.name = "Scipy solver ({})".format(method)

def _integrate(self, model, t_eval, inputs=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Tests for the lead-acid LOQS model with capacitance
#
import pybamm
from pybamm.solvers.scikits_ode_solver import scikits_odes_spec
import tests

import unittest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class TestExternalCC(unittest.TestCase):
@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
def test_2p1d(self):
model_options = {
"current collector": "potential pair",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_external_temperature(self):
external_variables = {"Cell temperature": T}
sim.step(dt, external_variables=external_variables)

@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
def test_dae_external_temperature(self):

model_options = {
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ def current_function(t):
self.assertEqual(sim.solution.t[0], 0)
self.assertEqual(sim.solution.t[1], dt)
self.assertEqual(sim.solution.t[2], 2 * dt)
np.testing.assert_array_equal(sim.solution.inputs["Current"], np.array([1,1,2]))
np.testing.assert_array_equal(
sim.solution.inputs["Current"], np.array([1, 1, 2])
)

def test_save_load(self):
model = pybamm.lead_acid.LOQS()
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def test_step_or_solve_empty_model(self):
with self.assertRaisesRegex(pybamm.ModelError, "Cannot solve empty model"):
solver.solve(model, None)

def test_ode_solver_fail_with_dae(self):
model = pybamm.BaseModel()
a = pybamm.Scalar(1)
model.algebraic = {a: a}
solver = pybamm.ScipySolver()
with self.assertRaisesRegex(pybamm.SolverError, "Cannot use ODE solver"):
solver.set_up(model)

def test_find_consistent_initial_conditions(self):
# Simple system: a single algebraic equation
class ScalarModel:
Expand Down
33 changes: 0 additions & 33 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,6 @@ def test_set_atol(self):
variable_tols = {"Electrolyte concentration": 1e-3}
solver.set_atol_by_variable(variable_tols, model)

def test_model_step_events(self):
# Create model
model = pybamm.BaseModel()
var1 = pybamm.Variable("var1")
var2 = pybamm.Variable("var2")
model.rhs = {var1: 0.1 * var1}
model.algebraic = {var2: 2 * var1 - var2}
model.initial_conditions = {var1: 1, var2: 2}
model.events = {
"var1 = 1.5": pybamm.min(var1 - 1.5),
"var2 = 2.5": pybamm.min(var2 - 2.5),
}
disc = pybamm.Discretisation()
disc.process_model(model)

# Solve
step_solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8)
dt = 0.05
time = 0
end_time = 5
step_solution = None
while time < end_time:
step_solution = step_solver.step(step_solution, model, dt=dt, npts=10)
time += dt
np.testing.assert_array_less(step_solution.y[0], 1.5)
np.testing.assert_array_less(step_solution.y[-1], 2.5001)
np.testing.assert_array_almost_equal(
step_solution.y[0], np.exp(0.1 * step_solution.t), decimal=5
)
np.testing.assert_array_almost_equal(
step_solution.y[-1], 2 * np.exp(0.1 * step_solution.t), decimal=5
)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,15 @@ def test_model_step_events(self):
step_solution.y[-1], 2 * np.exp(0.1 * step_solution.t), decimal=5
)

def test_ode_solver_fail_with_dae(self):
model = pybamm.BaseModel()
a = pybamm.Scalar(1)
model.algebraic = {a: a}
solver = pybamm.ScikitsOdeSolver()
with self.assertRaisesRegex(pybamm.SolverError, "Cannot use ODE solver"):
solver.set_up(model)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys
Expand Down

0 comments on commit 7ae50d5

Please sign in to comment.