Skip to content

Commit

Permalink
#899 fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 20, 2020
1 parent cd20d99 commit a702299
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 29 deletions.
11 changes: 8 additions & 3 deletions pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Algebraic solver class
#
import casadi
import pybamm
import numpy as np
from scipy import optimize
Expand Down Expand Up @@ -49,6 +50,10 @@ def _integrate(self, model, t_eval, inputs=None):
inputs : dict, optional
Any input parameters to pass to the model when solving
"""
inputs = inputs or {}
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(*[x for x in inputs.values()])

algebraic = model.algebraic_eval
y0 = model.y0

Expand All @@ -58,7 +63,7 @@ def _integrate(self, model, t_eval, inputs=None):

def root_fun(y):
"Evaluates algebraic using y"
out = algebraic(t, y)
out = algebraic(t, y, inputs)
pybamm.logger.debug(
"Evaluating algebraic equations at t={}, L2-norm is {}".format(
t, np.linalg.norm(out)
Expand All @@ -69,14 +74,14 @@ def root_fun(y):
if model.jacobian_eval is not None:

def jac(y):
return model.jacobian_eval(t, y)
return model.jacobian_eval(t, y, inputs)

else:
jac = None

# Evaluate algebraic with new t and previous y0, if it's already close
# enough then keep it
if np.all(abs(algebraic(t, y0)) < self.tol):
if np.all(abs(algebraic(t, y0, inputs)) < self.tol):
pybamm.logger.debug("Keeping same solution at t={}".format(t))
y[:, idx] = y0
# Otherwise calculate new y0
Expand Down
4 changes: 2 additions & 2 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def calculate_consistent_state(self, model, time=0, y0_guess=None, inputs=None):
y0_guess = model.concatenated_initial_conditions.flatten()

inputs = inputs or {}
if model.rhs_eval.form == "casadi":
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(*[x for x in inputs.values()])

# Split y0_guess into differential and algebraic
Expand Down Expand Up @@ -834,7 +834,7 @@ def function(self, t, y, inputs):
# keep jacobians sparse
return states_eval
else:
return self._function(t, y, inputs, known_evals={})[0]
return self._function(t, y, params=inputs, known_evals={})[0]


class Residuals(SolverCallable):
Expand Down
12 changes: 6 additions & 6 deletions pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def _integrate(self, model, t_eval, inputs=None):
y = np.empty((len(y0), len(t_eval)))

# Set up
p_stacked = casadi.vertcat(*[x for x in inputs.values()])
inputs = casadi.vertcat(*[x for x in inputs.values()])
t_sym = casadi.MX.sym("t")
y_sym = casadi.MX.sym("y_alg", y0.shape[0])
p_sym = casadi.MX.sym("p", p_stacked.shape[0])
p_sym = casadi.MX.sym("p", inputs.shape[0])

t_p_sym = casadi.vertcat(t_sym, p_sym)
alg = model.casadi_algebraic(t_sym, y_sym, p_sym)
Expand All @@ -71,21 +71,21 @@ def _integrate(self, model, t_eval, inputs=None):
for idx, t in enumerate(t_eval):
# Evaluate algebraic with new t and previous y0, if it's already close
# enough then keep it
if np.all(abs(model.algebraic_eval(t, y0)) < self.tol):
if np.all(abs(model.algebraic_eval(t, y0, inputs)) < self.tol):
pybamm.logger.debug(
"Keeping same solution at t={}".format(t * model.timescale_eval)
)
y[:, idx] = y0
# Otherwise calculate new y0
else:
t_p_stacked = casadi.vertcat(t, p_stacked)
t_inputs = casadi.vertcat(t, inputs)
# Solve
try:
y_sol = roots(y0, t_p_stacked).full().flatten()
y_sol = roots(y0, t_inputs).full().flatten()
success = True
message = None
# Check final output
fun = model.casadi_algebraic(t, y_sol, p_stacked)
fun = model.casadi_algebraic(t, y_sol, inputs)
except RuntimeError as err:
success = False
message = err.args[0]
Expand Down
11 changes: 7 additions & 4 deletions tests/unit/test_solvers/test_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ def test_simple_root_find(self):
class Model:
y0 = np.array([2])
jacobian_eval = None
convert_to_format = "python"

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
return y + 2

solver = pybamm.AlgebraicSolver()
Expand All @@ -51,8 +52,9 @@ def test_root_find_fail(self):
class Model:
y0 = np.array([2])
jacobian_eval = None
convert_to_format = "casadi"

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
# algebraic equation has no real root
return y ** 2 + 1

Expand All @@ -77,11 +79,12 @@ def test_with_jacobian(self):

class Model:
y0 = np.zeros(2)
convert_to_format = "python"

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
return A @ y - b

def jacobian_eval(self, t, y):
def jacobian_eval(self, t, y, inputs):
return A

model = Model()
Expand Down
25 changes: 14 additions & 11 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ def __init__(self):
y = casadi.MX.sym("y")
p = casadi.MX.sym("p")
self.casadi_algebraic = casadi.Function(
"alg", [t, y, p], [self.algebraic_eval(t, y)]
"alg", [t, y, p], [self.algebraic_eval(t, y, p)]
)
self.convert_to_format = "casadi"

def rhs_eval(self, t, y):
def rhs_eval(self, t, y, inputs):
return np.array([])

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
return y + 2

solver = pybamm.BaseSolver(root_method="lm")
Expand All @@ -101,13 +102,14 @@ def __init__(self):
y = casadi.MX.sym("y", vec.size)
p = casadi.MX.sym("p")
self.casadi_algebraic = casadi.Function(
"alg", [t, y, p], [self.algebraic_eval(t, y)]
"alg", [t, y, p], [self.algebraic_eval(t, y, p)]
)
self.convert_to_format = "casadi"

def rhs_eval(self, t, y):
def rhs_eval(self, t, y, inputs):
return y[0:1]

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
return (y[1:] - vec[1:]) ** 2

model = VectorModel()
Expand All @@ -118,15 +120,15 @@ def algebraic_eval(self, t, y):
np.testing.assert_array_almost_equal(init_cond, vec)

# With jacobian
def jac_dense(t, y):
def jac_dense(t, y, inputs):
return 2 * np.hstack([np.zeros((3, 1)), np.diag(y[1:] - vec[1:])])

model.jac_algebraic_eval = jac_dense
init_cond = solver.calculate_consistent_state(model)
np.testing.assert_array_almost_equal(init_cond, vec)

# With sparse jacobian
def jac_sparse(t, y):
def jac_sparse(t, y, inputs):
return 2 * csr_matrix(
np.hstack([np.zeros((3, 1)), np.diag(y[1:] - vec[1:])])
)
Expand All @@ -145,13 +147,14 @@ def __init__(self):
y = casadi.MX.sym("y")
p = casadi.MX.sym("p")
self.casadi_algebraic = casadi.Function(
"alg", [t, y, p], [self.algebraic_eval(t, y)]
"alg", [t, y, p], [self.algebraic_eval(t, y, p)]
)
self.convert_to_format = "casadi"

def rhs_eval(self, t, y):
def rhs_eval(self, t, y, inputs):
return np.array([])

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
# algebraic equation has no root
return y ** 2 + 1

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Model:
p = casadi.MX.sym("p")
casadi_algebraic = casadi.Function("alg", [t, y, p], [y ** 2 + 1])

def algebraic_eval(self, t, y):
def algebraic_eval(self, t, y, inputs):
# algebraic equation has no real root
return y ** 2 + 1

Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,6 @@ def test_dae_solver_algebraic_model(self):

if __name__ == "__main__":
print("Add -v for more debug output")
import warnings
warnings.simplefilter("error")
if "-v" in sys.argv:
debug = True
pybamm.set_logging_level("DEBUG")
Expand Down

0 comments on commit a702299

Please sign in to comment.