Skip to content

Commit

Permalink
#846 coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 10, 2020
1 parent 6c73878 commit 3d7bc4a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
1 change: 0 additions & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def set_up(self, model, inputs=None):
self.root_method = None
if (
isinstance(self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver))
or self.root_method == "casadi"
) and model.convert_to_format != "casadi":
pybamm.logger.warning(
"Converting {} to CasADi for solving with CasADi solver".format(
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_expression_tree/test_unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import unittest
import numpy as np
from scipy.sparse import diags


class TestUnaryOperators(unittest.TestCase):
Expand Down Expand Up @@ -43,6 +44,13 @@ def test_sign(self):
signb = pybamm.sign(b)
self.assertEqual(signb.evaluate(), -1)

A = diags(np.linspace(-1, 1, 5))
b = pybamm.Matrix(A)
signb = pybamm.sign(b)
np.testing.assert_array_equal(
np.diag(signb.evaluate().toarray()), [-1, -1, 0, 1, 1]
)

def test_gradient(self):
a = pybamm.Symbol("a")
grad = pybamm.Gradient(a)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ def algebraic_eval(self, t, y):
):
solver.calculate_consistent_state(Model())

def test_convert_to_casadi_format(self):
# Make sure model is converted to casadi format
model = pybamm.BaseModel()
v = pybamm.Variable("v")
model.rhs = {v: -1}
model.initial_conditions = {v: 1}
model.convert_to_format = "python"

disc = pybamm.Discretisation()
disc.process_model(model)

solver = pybamm.BaseSolver()
pybamm.set_logging_level("ERROR")
solver.set_up(model, {})
self.assertEqual(model.convert_to_format, "casadi")
pybamm.set_logging_level("WARNING")


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit 3d7bc4a

Please sign in to comment.