Skip to content

Commit

Permalink
#1230 adding simplifications, need to be careful about domains
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Nov 5, 2020
1 parent 2e69bd6 commit 23266fd
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 212 deletions.
5 changes: 4 additions & 1 deletion pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,10 @@ def _process_symbol(self, symbol):

elif isinstance(symbol, pybamm.Concatenation):
new_children = [self.process_symbol(child) for child in symbol.children]
new_symbol = spatial_method.concatenation(new_children)
try:
new_symbol = spatial_method.concatenation(new_children)
except:
self._process_symbol(symbol.children[1])

return new_symbol

Expand Down
177 changes: 3 additions & 174 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,68 +8,6 @@
from scipy.sparse import issparse, csr_matrix


def is_scalar_zero(expr):
"""
Utility function to test if an expression evaluates to a constant scalar zero
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors(t=None)
return isinstance(result, numbers.Number) and result == 0
else:
return False


def is_matrix_zero(expr):
"""
Utility function to test if an expression evaluates to a constant matrix zero
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors(t=None)
return (issparse(result) and result.count_nonzero() == 0) or (
isinstance(result, np.ndarray) and np.all(result == 0)
)
else:
return False


def is_scalar_one(expr):
"""
Utility function to test if an expression evaluates to a constant scalar one
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors(t=None)
return isinstance(result, numbers.Number) and result == 1
else:
return False


def is_matrix_one(expr):
"""
Utility function to test if an expression evaluates to a constant matrix one
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors(t=None)
return (issparse(result) and np.all(result.toarray() == 1)) or (
isinstance(result, np.ndarray) and np.all(result == 1)
)
else:
return False


def zeros_of_shape(shape):
"""
Utility function to create a scalar zero, or a vector or matrix of zeros of
the correct shape
"""
if shape == ():
return pybamm.Scalar(0)
else:
if len(shape) == 1 or shape[1] == 1:
return pybamm.Vector(np.zeros(shape))
else:
return pybamm.Matrix(csr_matrix(shape))


class BinaryOperator(pybamm.Symbol):
"""A node in the expression tree representing a binary operator (e.g. `+`, `*`)
Expand Down Expand Up @@ -222,7 +160,9 @@ def _binary_jac(self, left_jac, right_jac):

def _binary_simplify(self, new_left, new_right):
""" Simplify a binary operator. Default behaviour: unchanged"""
return self._binary_new_copy(new_left, new_right)
return pybamm.simplify_if_constant(
self._binary_new_copy(new_left, new_right), clear_domains=False
)

def _binary_evaluate(self, left, right):
""" Perform binary operation on nodes 'left' and 'right'. """
Expand Down Expand Up @@ -282,23 +222,6 @@ def _binary_evaluate(self, left, right):
with np.errstate(invalid="ignore"):
return left ** right

def _binary_simplify(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_simplify()`. """

# anything to the power of zero is one
if is_scalar_zero(right):
return pybamm.Scalar(1)

# zero to the power of anything is zero
if is_scalar_zero(left):
return pybamm.Scalar(0)

# anything to the power of one is itself
if is_scalar_one(right):
return left

return self.__class__(left, right)


class Addition(BinaryOperator):
"""A node in the expression tree representing an addition operator
Expand All @@ -325,31 +248,7 @@ def _binary_evaluate(self, left, right):
def _binary_simplify(self, left, right):
"""
See :meth:`pybamm.BinaryOperator._binary_simplify()`.
Note
----
We check for scalars first, then matrices. This is because
(Zero Matrix) + (Zero Scalar)
should return (Zero Matrix), not (Zero Scalar).
"""

# anything added by a scalar zero returns the other child
if is_scalar_zero(left):
return right
if is_scalar_zero(right):
return left
# Check matrices after checking scalars
if is_matrix_zero(left):
if isinstance(right, pybamm.Scalar):
return pybamm.Array(right.value * np.ones(left.shape_for_testing))
else:
return right
if is_matrix_zero(right):
if isinstance(left, pybamm.Scalar):
return pybamm.Array(left.value * np.ones(right.shape_for_testing))
else:
return left

return pybamm.simplify_addition_subtraction(self.__class__, left, right)


Expand Down Expand Up @@ -379,31 +278,7 @@ def _binary_evaluate(self, left, right):
def _binary_simplify(self, left, right):
"""
See :meth:`pybamm.BinaryOperator._binary_simplify()`.
Note
----
We check for scalars first, then matrices. This is because
(Zero Matrix) - (Zero Scalar)
should return (Zero Matrix), not -(Zero Scalar).
"""

# anything added by a scalar zero returns the other child
if is_scalar_zero(left):
return -right
if is_scalar_zero(right):
return left
# Check matrices after checking scalars
if is_matrix_zero(left):
if isinstance(right, pybamm.Scalar):
return pybamm.Array(-right.value * np.ones(left.shape_for_testing))
else:
return -right
if is_matrix_zero(right):
if isinstance(left, pybamm.Scalar):
return pybamm.Array(left.value * np.ones(right.shape_for_testing))
else:
return left

return pybamm.simplify_addition_subtraction(self.__class__, left, right)


Expand Down Expand Up @@ -453,24 +328,6 @@ def _binary_evaluate(self, left, right):

def _binary_simplify(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_simplify()`. """

# simplify multiply by scalar zero, being careful about shape
if is_scalar_zero(left):
return zeros_of_shape(right.shape_for_testing)
if is_scalar_zero(right):
return zeros_of_shape(left.shape_for_testing)

# if one of the children is a zero matrix, we have to be careful about shapes
if is_matrix_zero(left) or is_matrix_zero(right):
shape = (left * right).shape
return zeros_of_shape(shape)

# anything multiplied by a scalar one returns itself
if is_scalar_one(left):
return right
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)


Expand Down Expand Up @@ -518,10 +375,6 @@ def _binary_evaluate(self, left, right):

def _binary_simplify(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_simplify()`. """
if is_matrix_zero(left) or is_matrix_zero(right):
shape = (left @ right).shape
return zeros_of_shape(shape)

return pybamm.simplify_multiplication_division(self.__class__, left, right)


Expand Down Expand Up @@ -569,30 +422,6 @@ def _binary_evaluate(self, left, right):

def _binary_simplify(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_simplify()`. """

# zero divided by zero returns nan scalar
if is_scalar_zero(left) and is_scalar_zero(right):
return pybamm.Scalar(np.nan)

# zero divided by anything returns zero (being careful about shape)
if is_scalar_zero(left):
return zeros_of_shape(right.shape_for_testing)

# matrix zero divided by anything returns matrix zero (i.e. itself)
if is_matrix_zero(left):
return left

# anything divided by zero returns inf
if is_scalar_zero(right):
if left.shape_for_testing == ():
return pybamm.Scalar(np.inf)
else:
return pybamm.Array(np.inf * np.ones(left.shape_for_testing))

# anything divided by one is itself
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)


Expand Down
Loading

0 comments on commit 23266fd

Please sign in to comment.