Skip to content

Commit

Permalink
#735 add minimum and maximum
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Feb 21, 2020
1 parent d20b7cf commit 716f8ea
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 13 deletions.
10 changes: 10 additions & 0 deletions docs/source/expression_tree/binary_operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,14 @@ Binary Operators
.. autoclass:: pybamm.NotEqualHeaviside
:members:

.. autoclass:: pybamm.Minimum
:members:

.. autoclass:: pybamm.Maximum
:members:

.. autofunction:: pybamm.minimum

.. autofunction:: pybamm.maximum

.. autofunction:: pybamm.source
4 changes: 2 additions & 2 deletions examples/scripts/DFN.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# load model
model = pybamm.lithium_ion.DFN()

model.convert_to_format = "python"
# create geometry
geometry = model.default_geometry

Expand All @@ -30,7 +30,7 @@

# solve model
t_eval = np.linspace(0, 3600, 100)
solver = model.default_solver
solver = pybamm.IDAKLUSolver(root_method="lm")
solver.rtol = 1e-3
solver.atol = 1e-6
solution = solver.solve(model, t_eval)
Expand Down
4 changes: 4 additions & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def version(formatted=False):
Heaviside,
EqualHeaviside,
NotEqualHeaviside,
Minimum,
minimum,
Maximum,
maximum,
source,
)
from .expression_tree.concatenations import (
Expand Down
72 changes: 72 additions & 0 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,78 @@ def _binary_evaluate(self, left, right):
return left < right


class Minimum(BinaryOperator):
" Returns the smaller of two objects "

def __init__(self, left, right):
super().__init__("minimum", left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
return "minimum({!s}, {!s})".format(self.left, self.right)

def _diff(self, variable):
""" See :meth:`pybamm.Symbol._diff()`. """
left, right = self.orphans
return (left <= right) * left.diff(variable) + (left > right) * right.diff(
variable
)

def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
left, right = self.orphans
return (left <= right) * left_jac + (left > right) * right_jac

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """
# don't raise RuntimeWarning for NaNs
return np.minimum(left, right)


class Maximum(BinaryOperator):
" Returns the smaller of two objects "

def __init__(self, left, right):
super().__init__("maximum", left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
return "maximum({!s}, {!s})".format(self.left, self.right)

def _diff(self, variable):
""" See :meth:`pybamm.Symbol._diff()`. """
left, right = self.orphans
return (left >= right) * left.diff(variable) + (left < right) * right.diff(
variable
)

def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
left, right = self.orphans
return (left >= right) * left_jac + (left < right) * right_jac

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """
# don't raise RuntimeWarning for NaNs
return np.maximum(left, right)


def minimum(left, right):
"""
Returns the smaller of two objects. Not to be confused with :meth:`pybamm.min`,
which returns min function of child.
"""
return pybamm.simplify_if_constant(Minimum(left, right), keep_domains=True)


def maximum(left, right):
"""
Returns the larger of two objects. Not to be confused with :meth:`pybamm.max`,
which returns max function of child.
"""
return pybamm.simplify_if_constant(Maximum(left, right), keep_domains=True)


def source(left, right, boundary=False):
"""A convinience function for creating (part of) an expression tree representing
a source term. This is necessary for spatial methods where the mass matrix
Expand Down
10 changes: 8 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,18 @@ def log10(child):


def max(child):
" Returns max function of child. "
"""
Returns max function of child. Not to be confused with :meth:`pybamm.maximum`, which
returns the larger of two objects.
"""
return pybamm.simplify_if_constant(Function(np.max, child), keep_domains=True)


def min(child):
" Returns min function of child. "
"""
Returns min function of child. Not to be confused with :meth:`pybamm.minimum`, which
returns the smaller of two objects.
"""
return pybamm.simplify_if_constant(Function(np.min, child), keep_domains=True)


Expand Down
4 changes: 4 additions & 0 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def _convert(self, symbol, t=None, y=None, u=None):
# process children
converted_left = self.convert(left, t, y, u)
converted_right = self.convert(right, t, y, u)
if isinstance(symbol, pybamm.Minimum):
return casadi.fmin(converted_left, converted_right)
if isinstance(symbol, pybamm.Maximum):
return casadi.fmax(converted_left, converted_right)
# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)

Expand Down
4 changes: 4 additions & 0 deletions pybamm/expression_tree/operations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def find_symbols(symbol, constant_symbols, variable_symbols):
"if scipy.sparse.issparse({1}) else "
"{0} * {1}".format(children_vars[0], children_vars[1])
)
elif isinstance(symbol, pybamm.Minimum):
symbol_str = "np.minimum({},{})".format(children_vars[0], children_vars[1])
elif isinstance(symbol, pybamm.Maximum):
symbol_str = "np.maximum({},{})".format(children_vars[0], children_vars[1])
else:
symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_coupled_variables(self, variables):

# N_e = N_e_diffusion + N_e_migration + N_e_convection

N_e = N_e_diffusion + c_e * v_box
N_e = N_e_diffusion + v_box

variables.update(self._get_standard_flux_variables(N_e))

Expand Down
13 changes: 11 additions & 2 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,19 @@ def set_up(self, model, inputs=None):
if self.ode_solver is True:
self.root_method = None
if (
isinstance(self, pybamm.CasadiSolver) or self.root_method == "casadi"
isinstance(self, pybamm.CasadiSolver)
) and model.convert_to_format != "casadi":
pybamm.logger.warning(
f"Converting {model.name} to CasADi for solving with CasADi solver"
"Converting {} to CasADi for solving with CasADi solver".format(
model.name
)
)
model.convert_to_format = "casadi"
if self.root_method == "casadi" and model.convert_to_format != "casadi":
pybamm.logger.warning(
"Converting {} to CasADi for calculating ICs with CasADi".format(
model.name
)
)
model.convert_to_format = "casadi"

Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,21 @@ def test_heaviside(self):
self.assertEqual(heav.evaluate(y=np.array([0])), 1)
self.assertEqual(str(heav), "y[0:1] <= 1.0")

def test_minimum_maximum(self):
a = pybamm.Scalar(1)
b = pybamm.StateVector(slice(0, 1))
minimum = pybamm.minimum(a, b)
self.assertEqual(minimum.evaluate(y=np.array([2])), 1)
self.assertEqual(minimum.evaluate(y=np.array([1])), 1)
self.assertEqual(minimum.evaluate(y=np.array([0])), 0)
self.assertEqual(str(minimum), "minimum(1.0, y[0:1])")

maximum = pybamm.maximum(a, b)
self.assertEqual(maximum.evaluate(y=np.array([2])), 2)
self.assertEqual(maximum.evaluate(y=np.array([1])), 1)
self.assertEqual(maximum.evaluate(y=np.array([0])), 1)
self.assertEqual(str(maximum), "maximum(1.0, y[0:1])")


class TestIsZero(unittest.TestCase):
def test_is_scalar_zero(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,21 @@ def myfunction(x, y):
f = pybamm.Function(myfunction, b, d)
self.assertEqual(f.to_casadi(), casadi.MX(3))

# use classes to avoid simplification
# addition
self.assertEqual((a + b).to_casadi(), casadi.MX(1))
self.assertEqual((pybamm.Addition(a, b)).to_casadi(), casadi.MX(1))
# subtraction
self.assertEqual((c - d).to_casadi(), casadi.MX(-3))
self.assertEqual(pybamm.Subtraction(c, d).to_casadi(), casadi.MX(-3))
# multiplication
self.assertEqual((c * d).to_casadi(), casadi.MX(-2))
self.assertEqual(pybamm.Multiplication(c, d).to_casadi(), casadi.MX(-2))
# power
self.assertEqual((c ** d).to_casadi(), casadi.MX(1))
self.assertEqual(pybamm.Power(c, d).to_casadi(), casadi.MX(1))
# division
self.assertEqual((b / d).to_casadi(), casadi.MX(1 / 2))
self.assertEqual(pybamm.Division(b, d).to_casadi(), casadi.MX(1 / 2))

# minimum and maximum
self.assertEqual(pybamm.Minimum(a, b).to_casadi(), casadi.MX(0))
self.assertEqual(pybamm.Maximum(a, b).to_casadi(), casadi.MX(1))

def test_convert_array_symbols(self):
# Arrays
Expand Down
16 changes: 15 additions & 1 deletion tests/unit/test_expression_tree/test_operations/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_find_symbols(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))

# test a * b
# test a + b
constant_symbols = OrderedDict()
variable_symbols = OrderedDict()
expr = a + b
Expand Down Expand Up @@ -356,6 +356,20 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

# test something with a minimum or maximum
a = pybamm.Vector(np.array([1, 2]))
expr = pybamm.minimum(a, pybamm.StateVector(slice(0, 2)))
evaluator = pybamm.EvaluatorPython(expr)
for t, y in zip(t_tests, y_tests):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

expr = pybamm.maximum(a, pybamm.StateVector(slice(0, 2)))
evaluator = pybamm.EvaluatorPython(expr)
for t, y in zip(t_tests, y_tests):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

# test something with an index
expr = pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0)
evaluator = pybamm.EvaluatorPython(expr)
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,18 @@ def test_jac_of_heaviside(self):
((a < y) * y ** 2).jac(y).evaluate(y=-5 * np.ones(5)), 0
)

def test_jac_of_minimum_maximum(self):
y = pybamm.StateVector(slice(0, 10))
y_test = np.linspace(0, 2, 10)
np.testing.assert_array_equal(
np.diag(pybamm.minimum(1, y ** 2).jac(y).evaluate(y=y_test)),
2 * y_test * (y_test < 1),
)
np.testing.assert_array_equal(
np.diag(pybamm.maximum(1, y ** 2).jac(y).evaluate(y=y_test)),
2 * y_test * (y_test > 1),
)

def test_jac_of_domain_concatenation(self):
# create mesh
mesh = get_mesh_for_testing()
Expand Down

0 comments on commit 716f8ea

Please sign in to comment.