diff --git a/CHANGELOG.md b/CHANGELOG.md index c73d1cd503..4a422dbad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ ## Breaking changes +- Removed `Simplification` class and `.simplify()` function ([#1369](https://github.com/pybamm-team/PyBaMM/pull/1369)) - All example notebooks in PyBaMM's GitHub repository must now include the command `pybamm.print_citations()`, otherwise the tests will fail. This is to encourage people to use this command to cite the relevant papers ([#1340](https://github.com/pybamm-team/PyBaMM/pull/1340)) - Notation has been homogenised to use positive and negative electrode (instead of cathode and anode). This applies to the parameter folders (now called `'positive_electrodes'` and `'negative_electrodes'`) and the options of `active_material` and `particle_cracking` submodels (now called `'positive'` and `'negative'`) ([#1337](https://github.com/pybamm-team/PyBaMM/pull/1337)) - `Interpolant` now takes `x` and `y` instead of a single `data` entry ([#1312](https://github.com/pybamm-team/PyBaMM/pull/1312)) diff --git a/docs/source/expression_tree/operations/index.rst b/docs/source/expression_tree/operations/index.rst index 2064dcaae2..c084389f1a 100644 --- a/docs/source/expression_tree/operations/index.rst +++ b/docs/source/expression_tree/operations/index.rst @@ -5,7 +5,6 @@ Classes and functions that operate on the expression tree .. toctree:: - simplify evaluate jacobian convert_to_casadi diff --git a/docs/source/expression_tree/operations/simplify.rst b/docs/source/expression_tree/operations/simplify.rst deleted file mode 100644 index cfbf11e253..0000000000 --- a/docs/source/expression_tree/operations/simplify.rst +++ /dev/null @@ -1,11 +0,0 @@ -Simplify -======== - -.. autoclass:: pybamm.Simplification - :members: - -.. autofunction:: pybamm.simplify_if_constant - -.. autofunction:: pybamm.simplify_addition_subtraction - -.. autofunction:: pybamm.simplify_multiplication_division diff --git a/docs/source/expression_tree/symbol.rst b/docs/source/expression_tree/symbol.rst index afe3991826..be252c9b06 100644 --- a/docs/source/expression_tree/symbol.rst +++ b/docs/source/expression_tree/symbol.rst @@ -1,5 +1,6 @@ Symbol ====== +.. autofunction:: pybamm.simplify_if_constant .. autoclass:: pybamm.Symbol :special-members: diff --git a/examples/notebooks/expression_tree/expression-tree.ipynb b/examples/notebooks/expression_tree/expression-tree.ipynb index bb5316316b..79bc175a17 100644 --- a/examples/notebooks/expression_tree/expression-tree.ipynb +++ b/examples/notebooks/expression_tree/expression-tree.ipynb @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "diff_wrt_equation = equation.diff(t).simplify()\n", + "diff_wrt_equation = equation.diff(t)\n", "diff_wrt_equation.visualise('expression_tree2.png')" ] }, @@ -274,4 +274,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 239a8c55a6..223a75106e 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -96,13 +96,6 @@ def version(formatted=False): from .expression_tree.exceptions import * # Operations -from .expression_tree.operations.simplify import ( - Simplification, - simplify_if_constant, - simplify_addition_subtraction, - simplify_multiplication_division, -) - from .expression_tree.operations.evaluate import ( find_symbols, id_to_python_variable, diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index a240e1bf83..a2a81d26dc 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -162,12 +162,6 @@ def _binary_jac(self, left_jac, right_jac): """ Calculate the jacobian of a binary operator. """ raise NotImplementedError - def _binary_simplify(self, new_left, new_right): - """ Simplify a binary operator """ - return pybamm.simplify_if_constant( - self._binary_new_copy(new_left, new_right), clear_domains=False - ) - def _binary_evaluate(self, left, right): """ Perform binary operation on nodes 'left' and 'right'. """ raise NotImplementedError @@ -226,10 +220,6 @@ def _binary_evaluate(self, left, right): with np.errstate(invalid="ignore"): return left ** right - def _binary_simplify(self, new_left, new_right): - """ See :meth:`pybamm.BinaryOperator._binary_simplify()`. """ - return pybamm.simplified_power(new_left, new_right) - class Addition(BinaryOperator): """A node in the expression tree representing an addition operator @@ -253,10 +243,6 @@ def _binary_evaluate(self, left, right): """ See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """ return left + right - def _binary_simplify(self, left, right): - """ See :meth:`pybamm.BinaryOperator._binary_simplify()`. """ - return pybamm.simplify_addition_subtraction(self.__class__, left, right) - class Subtraction(BinaryOperator): """A node in the expression tree representing a subtraction operator @@ -281,12 +267,6 @@ def _binary_evaluate(self, left, right): """ See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """ return left - right - def _binary_simplify(self, left, right): - """ - See :meth:`pybamm.BinaryOperator._binary_simplify()`. - """ - return pybamm.simplify_addition_subtraction(self.__class__, left, right) - class Multiplication(BinaryOperator): """ @@ -332,10 +312,6 @@ def _binary_evaluate(self, left, right): else: return left * right - def _binary_simplify(self, left, right): - """ See :meth:`pybamm.BinaryOperator._binary_simplify()`. """ - return pybamm.simplify_multiplication_division(self.__class__, left, right) - class MatrixMultiplication(BinaryOperator): """A node in the expression tree representing a matrix multiplication operator @@ -378,10 +354,6 @@ def _binary_evaluate(self, left, right): """ See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """ return left @ right - def _binary_simplify(self, left, right): - """ See :meth:`pybamm.BinaryOperator._binary_simplify()`. """ - return pybamm.simplify_multiplication_division(self.__class__, left, right) - class Division(BinaryOperator): """A node in the expression tree representing a division operator @@ -425,10 +397,6 @@ def _binary_evaluate(self, left, right): else: return left / right - def _binary_simplify(self, left, right): - """ See :meth:`pybamm.BinaryOperator._binary_simplify()`. """ - return pybamm.simplify_multiplication_division(self.__class__, left, right) - class Inner(BinaryOperator): """ @@ -486,10 +454,6 @@ def _binary_new_copy(self, left, right): """ See :meth:`pybamm.BinaryOperator._binary_new_copy()`. """ return pybamm.inner(left, right) - def _binary_simplify(self, left, right): - """ See :meth:`pybamm.BinaryOperator._binary_simplify()`. """ - return pybamm.simplify_multiplication_division(self.__class__, left, right) - def _evaluates_on_edges(self, dimension): """ See :meth:`pybamm.Symbol._evaluates_on_edges()`. """ return False diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index db144a0b5b..6abfc29265 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -94,12 +94,6 @@ def _concatenation_jac(self, children_jacs): """ Calculate the jacobian of a concatenation """ return NotImplementedError - def _concatenation_simplify(self, children): - """ See :meth:`pybamm.Symbol.simplify()`. """ - new_symbol = self.__class__(*children) - new_symbol.clear_domains() - return new_symbol - def _evaluate_for_shape(self): """ See :meth:`pybamm.Symbol.evaluate_for_shape` """ if len(self.children) == 0: @@ -155,12 +149,6 @@ def _concatenation_jac(self, children_jacs): else: return SparseStack(*children_jacs) - def _concatenation_simplify(self, children): - """ See :meth:`pybamm.Concatenation._concatenation_simplify()`. """ - new_symbol = simplified_numpy_concatenation(*children) - new_symbol.clear_domains() - return new_symbol - class DomainConcatenation(Concatenation): """A node in the expression tree representing a concatenation of symbols, being @@ -308,17 +296,6 @@ def _concatenation_new_copy(self, children): ) return new_symbol - def _concatenation_simplify(self, children): - """ See :meth:`pybamm.Concatenation._concatenation_simplify()`. """ - new_symbol = simplified_domain_concatenation( - children, self.full_mesh, copy_this=self - ) - # TODO: this should not be needed, but somehow we are still getting domains in - # the simplified children - new_symbol.clear_domains() - - return new_symbol - class SparseStack(Concatenation): """A node in the expression tree representing a concatenation of sparse diff --git a/pybamm/expression_tree/operations/simplify.py b/pybamm/expression_tree/operations/simplify.py deleted file mode 100644 index b831cbcdfe..0000000000 --- a/pybamm/expression_tree/operations/simplify.py +++ /dev/null @@ -1,648 +0,0 @@ -# -# Simplify a symbol -# -import pybamm - -import numpy as np -import numbers -from scipy.sparse import issparse, csr_matrix - - -def simplify_if_constant(symbol, clear_domains=True): - """ - Utility function to simplify an expression tree if it evalutes to a constant - scalar, vector or matrix - """ - if clear_domains is True: - domain = None - auxiliary_domains = None - else: - domain = symbol.domain - auxiliary_domains = symbol.auxiliary_domains - if symbol.is_constant(): - result = symbol.evaluate_ignoring_errors() - if result is not None: - if ( - isinstance(result, numbers.Number) - or (isinstance(result, np.ndarray) and result.ndim == 0) - or isinstance(result, np.bool_) - ): - return pybamm.Scalar(result) - elif isinstance(result, np.ndarray) or issparse(result): - if result.ndim == 1 or result.shape[1] == 1: - return pybamm.Vector( - result, domain=domain, auxiliary_domains=auxiliary_domains - ) - else: - # Turn matrix of zeros into sparse matrix - if isinstance(result, np.ndarray) and np.all(result == 0): - result = csr_matrix(result) - return pybamm.Matrix( - result, domain=domain, auxiliary_domains=auxiliary_domains - ) - - return symbol - - -def simplify_addition_subtraction(myclass, left, right): - """ - if children are associative (addition, subtraction, etc) then try to find groups of - constant children (that produce a value) and simplify them to a single term - - The purpose of this function is to simplify expressions like (1 + (1 + p)), which - should be simplified to (2 + p). The former expression consists of an Addition, with - a left child of Scalar type, and a right child of another Addition containing a - Scalar and a Parameter. For this case, this function will first flatten the - expression to a list of the bottom level children (i.e. [Scalar(1), Scalar(2), - Parameter(p)]), and their operators (i.e. [None, Addition, Addition]), and then - combine all the constant children (i.e. Scalar(1) and Scalar(1)) to a single child - (i.e. Scalar(2)) - - Note that this function will flatten the expression tree until a symbol is found - that is not either an Addition or a Subtraction, so this function would simplify - (3 - (2 + a*b*c)) to (1 + a*b*c) - - This function is useful if different children expressions contain non-constant terms - that prevent them from being simplified, so for example (1 + a) + (b - 2) - (6 + c) - will be simplified to (-7 + a + b - c) - - Parameters - ---------- - - myclass: class - the binary operator class (pybamm.Addition or pybamm.Subtraction) operating on - children left and right - left: derived from pybamm.Symbol - the left child of the binary operator - right: derived from pybamm.Symbol - the right child of the binary operator - - """ - numerator = [] - numerator_types = [] - - def flatten(this_class, left_child, right_child, in_subtraction): - """ - recursive function to flatten a term involving only additions or subtractions - - outputs to lists `numerator` and `numerator_types` - - Note that domains are all set to [] as we do not wish to consider domains once - simplifications are applied - - e.g. - - (1 + 2) + 3 -> [1, 2, 3] and [None, Addition, Addition] - 1 + (2 - 3) -> [1, 2, 3] and [None, Addition, Subtraction] - 1 - (2 + 3) -> [1, 2, 3] and [None, Subtraction, Subtraction] - (1 + 2) - (2 + 3) -> [1, 2, 2, 3] and [None, Addition, Subtraction, Subtraction] - """ - - left_child.clear_domains() - right_child.clear_domains() - for side, child in [("left", left_child), ("right", right_child)]: - if isinstance(child, (pybamm.Addition, pybamm.Subtraction)): - left, right = child.orphans - flatten(child.__class__, left, right, in_subtraction) - - else: - numerator.append(child) - if in_subtraction is None: - numerator_types.append(None) - elif in_subtraction: - numerator_types.append(pybamm.Subtraction) - else: - numerator_types.append(pybamm.Addition) - - if side == "left": - if in_subtraction is None: - in_subtraction = this_class == pybamm.Subtraction - elif this_class == pybamm.Subtraction: - in_subtraction = not in_subtraction - - flatten(myclass, left, right, None) - - def partition_by_constant(source, types): - """ - function to partition a source list of symbols into those that return a constant - value, and those that do not - """ - constant = [] - nonconstant = [] - constant_types = [] - nonconstant_types = [] - - for child, op_type in zip(source, types): - if child.is_constant() and child.evaluate_ignoring_errors() is not None: - constant.append(child) - constant_types.append(op_type) - else: - nonconstant.append(child) - nonconstant_types.append(op_type) - return constant, nonconstant, constant_types, nonconstant_types - - def fold_add_subtract(array, types): - """ - performs a fold operation on the children nodes in `array`, using the operator - types given in `types` - - e.g. if the input was: - array = [1, 2, 3, 4] - types = [None, +, -, +] - - the result would be 1 + 2 - 3 + 4 - """ - ret = None - if len(array) > 0: - if types[0] in [None, pybamm.Addition]: - ret = array[0] - elif types[0] == pybamm.Subtraction: - ret = -array[0] - for child, typ in zip(array[1:], types[1:]): - if typ == pybamm.Addition: - ret += child - else: - ret -= child - return ret - - # simplify identical terms - i = 0 - while i < len(numerator) - 1: - if isinstance(numerator[i], pybamm.Multiplication) and isinstance( - numerator[i].children[0], pybamm.Scalar - ): - term_i = numerator[i].orphans[1] - term_i_count = numerator[i].children[0].evaluate() - else: - term_i = numerator[i] - term_i_count = 1 - - # loop through rest of numerator counting up and deleting identical terms - for j, (term_j, typ_j) in enumerate( - zip(numerator[i + 1 :], numerator_types[i + 1 :]) - ): - if isinstance(term_j, pybamm.Multiplication) and isinstance( - term_j.left, pybamm.Scalar - ): - factor = term_j.left.evaluate() - term_j = term_j.right - else: - factor = 1 - if term_i.id == term_j.id: - if typ_j == pybamm.Addition: - term_i_count += factor - elif typ_j == pybamm.Subtraction: - term_i_count -= factor - del numerator[j + i + 1] - del numerator_types[j + i + 1] - - # replace this term by count * term if count > 1 - if term_i_count != 1: - # simplify the result just in case - # (e.g. count == 0, or can fold constant into the term) - numerator[i] = (term_i_count * term_i).simplify() - - i += 1 - - # can reorder the numerator - (constant, nonconstant, constant_types, nonconstant_types) = partition_by_constant( - numerator, numerator_types - ) - - constant_expr = fold_add_subtract(constant, constant_types) - nonconstant_expr = fold_add_subtract(nonconstant, nonconstant_types) - - if constant_expr is not None and nonconstant_expr is None: - # might be no nonconstants - new_expression = pybamm.simplify_if_constant(constant_expr) - elif constant_expr is None and nonconstant_expr is not None: - # might be no constants - new_expression = nonconstant_expr - else: - # or mix of both - constant_expr = pybamm.simplify_if_constant(constant_expr) - new_expression = constant_expr + nonconstant_expr - - return new_expression - - -def simplify_multiplication_division(myclass, left, right): - """ - if children are associative (multiply, division, etc) then try to find - groups of constant children (that produce a value) and simplify them - - The purpose of this function is to simplify expressions of the type (1 * c / 2), - which should simplify to (0.5 * c). The former expression consists of a Division, - with a left child of a Multiplication containing a Scalar and a Parameter, and a - right child consisting of a Scalar. For this case, this function will first flatten - the expression to a list of the bottom level children on the numerator (i.e. - [Scalar(1), Parameter(c)]) and their operators (i.e. [None, Multiplication]), as - well as those children on the denominator (i.e. [Scalar(2)]. After this, all the - constant children on the numerator and denominator (i.e. Scalar(1) and Scalar(2)) - will be combined appropriately, in this case to Scalar(0.5), and combined with the - nonconstant children (i.e. Parameter(c)) - - Note that this function will flatten the expression tree until a symbol is found - that is not either an Multiplication, Division or MatrixMultiplication, so this - function would simplify (3*(1 + d)*2) to (6 * (1 + d)) - - As well as Multiplication and Division, this function can handle - MatrixMultiplication. If any MatrixMultiplications are found on the - numerator/denominator, no reordering of children is done to find groups of constant - children. In this case only neighbouring constant children on the numerator are - simplified - - Parameters - ---------- - - myclass: class - the binary operator class (pybamm.Addition or pybamm.Subtraction) operating on - children left and right - left: derived from pybamm.Symbol - the left child of the binary operator - right: derived from pybamm.Symbol - the right child of the binary operator - - """ - numerator = [] - denominator = [] - numerator_types = [] - denominator_types = [] - - # recursive function to flatten a term involving only multiplications or divisions - def flatten( - previous_class, - this_class, - left_child, - right_child, - in_numerator, - in_matrix_multiplication, - ): - """ - recursive function to flatten a term involving only Multiplication, Division or - MatrixMultiplication. keeps track of wether a term is on the numerator or - denominator. For those terms on the numerator, their operator type - (Multiplication or MatrixMultiplication) is stored - - Note that multiplication *within* matrix multiplications, e.g. a@(b*c), are not - flattened into a@b*c, as this would be incorrect (see #253) - - Note that the domains are all set to [] as we do not wish to consider domains - once simplifications are applied - - outputs to lists `numerator`, `denominator` and `numerator_types` - - e.g. - expression numerator denominator numerator_types - (1 * 2) / 3 -> [1, 2] [3] [None, Multiplication] - (1 @ 2) / 3 -> [1, 2] [3] [None, MatrixMultiplication] - 1 / (c / 2) -> [1, 2] [c] [None, Multiplication] - """ - - left_child.clear_domains() - right_child.clear_domains() - for side, child in [("left", left_child), ("right", right_child)]: - - if side == "left": - other_child = right_child - else: - other_child = left_child - - # flatten if all matrix multiplications - # flatten if one child is a matrix mult if the other term is a scalar or - # vector - if isinstance(child, pybamm.MatrixMultiplication) and ( - in_matrix_multiplication - or isinstance(other_child, (pybamm.Scalar, pybamm.Vector)) - ): - left, right = child.orphans - if ( - side == "left" - and this_class == pybamm.Multiplication - and isinstance(other_child, pybamm.Vector) - ): - # change (m @ v1) * v2 -> v2 * m @ v so can simplify correctly - # (#341) - numerator.append(other_child) - numerator_types.append(previous_class) - flatten( - this_class, child.__class__, left, right, in_numerator, True - ) - break - if side == "left": - flatten( - previous_class, child.__class__, left, right, in_numerator, True - ) - else: - flatten( - this_class, child.__class__, left, right, in_numerator, True - ) - # flatten if all multiplies and divides - elif ( - isinstance(child, (pybamm.Multiplication, pybamm.Division)) - and not in_matrix_multiplication - ): - left, right = child.orphans - if side == "left": - flatten( - previous_class, - child.__class__, - left, - right, - in_numerator, - False, - ) - else: - flatten( - this_class, child.__class__, left, right, in_numerator, False - ) - # everything else don't flatten - else: - if in_numerator: - numerator.append(child) - if side == "left": - numerator_types.append(previous_class) - else: - numerator_types.append(this_class) - else: - denominator.append(child) - if side == "left": - denominator_types.append(previous_class) - else: - denominator_types.append(this_class) - - if side == "left" and this_class == pybamm.Division: - in_numerator = not in_numerator - - flatten(None, myclass, left, right, True, myclass == pybamm.MatrixMultiplication) - - # check if there is a matrix multiply in the numerator (if so we can't reorder it) - numerator_has_mat_mul = any( - [typ == pybamm.MatrixMultiplication for typ in numerator_types + [myclass]] - ) - - denominator_has_mat_mul = any( - [typ == pybamm.MatrixMultiplication for typ in denominator_types] - ) - - def partition_by_constant(source, types=None): - """ - function to partition a source list of symbols into those that return a constant - value, and those that do not - """ - constant = [] - nonconstant = [] - - for child in source: - if child.is_constant() and child.evaluate_ignoring_errors() is not None: - constant.append(child) - else: - nonconstant.append(child) - return constant, nonconstant - - def fold_multiply(array, types=None): - """ - performs a fold operation on the children nodes in `array`, using the operator - types given in `types` - - e.g. if the input was: - array = [1, 2, 3, 4] - types = [None, *, @, *] - - the result would be 1 * 2 @ 3 * 4 - """ - ret = None - if len(array) > 0: - if types is None: - ret = array[0] - for child in array[1:]: - ret *= child - else: - # work backwards through 'array' and 'types' so that multiplications - # and matrix multiplications are performed in the most efficient order - ret = array[-1] - for child, typ in zip(reversed(array[:-1]), reversed(types[1:])): - if typ == pybamm.MatrixMultiplication: - ret = child @ ret - else: - ret = child * ret - return ret - - def simplify_with_mat_mul(nodes, types): - new_nodes = [nodes[0]] - new_types = [types[0]] - for child, typ in zip(nodes[1:], types[1:]): - if ( - new_nodes[-1].is_constant() - and child.is_constant() - and new_nodes[-1].evaluate_ignoring_errors() is not None - and child.evaluate_ignoring_errors() is not None - ): - if typ == pybamm.MatrixMultiplication: - new_nodes[-1] = new_nodes[-1] @ child - else: - new_nodes[-1] *= child - new_nodes[-1] = pybamm.simplify_if_constant(new_nodes[-1]) - else: - new_nodes.append(child) - new_types.append(typ) - new_nodes = fold_multiply(new_nodes, new_types) - return new_nodes - - if numerator_has_mat_mul and denominator_has_mat_mul: - new_numerator = simplify_with_mat_mul(numerator, numerator_types) - new_denominator = simplify_with_mat_mul(denominator, denominator_types) - if new_denominator is None: - result = new_numerator - else: - result = new_numerator / new_denominator - - elif numerator_has_mat_mul and not denominator_has_mat_mul: - # can reorder the denominator since no matrix multiplies - denominator_constant, denominator_nonconst = partition_by_constant(denominator) - - constant_denominator_expr = fold_multiply(denominator_constant) - nonconst_denominator_expr = fold_multiply(denominator_nonconst) - - # fold constant denominator expr into numerator if possible - if constant_denominator_expr is not None: - for i, child in enumerate(numerator): - if child.is_constant() and child.evaluate_ignoring_errors() is not None: - numerator[i] = child / constant_denominator_expr - numerator[i] = pybamm.simplify_if_constant(numerator[i]) - constant_denominator_expr = None - - new_numerator = simplify_with_mat_mul(numerator, numerator_types) - - # result = constant_numerator_expr * new_numerator / nonconst_denominator_expr - # need to take into accound that terms can be None - if constant_denominator_expr is None: - if nonconst_denominator_expr is None: - result = new_numerator - else: - result = new_numerator / nonconst_denominator_expr - else: - # invert constant denominator terms for speed - constant_numerator_expr = pybamm.simplify_if_constant( - 1 / constant_denominator_expr - ) - - if nonconst_denominator_expr is None: - result = constant_numerator_expr * new_numerator - else: - result = ( - constant_numerator_expr * new_numerator / nonconst_denominator_expr - ) - - elif not numerator_has_mat_mul and denominator_has_mat_mul: - new_denominator = simplify_with_mat_mul(denominator, denominator_types) - - # can reorder the numerator since no matrix multiplies - numerator_constant, numerator_nonconst = partition_by_constant(numerator) - - constant_numerator_expr = fold_multiply(numerator_constant) - nonconst_numerator_expr = fold_multiply(numerator_nonconst) - - # result = constant_numerator_expr * nonconst_numerator_expr / new_denominator - # need to take into account that terms can be None - if constant_numerator_expr is None: - result = nonconst_numerator_expr / new_denominator - else: - constant_numerator_expr = pybamm.simplify_if_constant( - constant_numerator_expr - ) - if nonconst_numerator_expr is None: - result = constant_numerator_expr / new_denominator - else: - result = ( - constant_numerator_expr * nonconst_numerator_expr / new_denominator - ) - - else: - # can reorder the numerator since no matrix multiplies - numerator_constant, numerator_nonconstant = partition_by_constant(numerator) - - constant_numerator_expr = fold_multiply(numerator_constant) - nonconst_numerator_expr = fold_multiply(numerator_nonconstant) - - # can reorder the denominator since no matrix multiplies - denominator_constant, denominator_nonconst = partition_by_constant(denominator) - - constant_denominator_expr = fold_multiply(denominator_constant) - nonconst_denominator_expr = fold_multiply(denominator_nonconst) - - if constant_numerator_expr is not None: - if constant_denominator_expr is not None: - constant_numerator_expr = pybamm.simplify_if_constant( - constant_numerator_expr / constant_denominator_expr - ) - else: - constant_numerator_expr = pybamm.simplify_if_constant( - constant_numerator_expr - ) - else: - if constant_denominator_expr is not None: - constant_numerator_expr = pybamm.simplify_if_constant( - 1 / constant_denominator_expr - ) - - # result = constant_numerator_expr * nonconst_numerator_expr - # / nonconst_denominator_expr - # need to take into account that terms can be None - if constant_numerator_expr is None: - result = nonconst_numerator_expr - else: - if nonconst_numerator_expr is None: - result = constant_numerator_expr - else: - result = constant_numerator_expr * nonconst_numerator_expr - - if nonconst_denominator_expr is not None: - result = result / nonconst_denominator_expr - - return result - - -class Simplification(object): - def __init__(self, simplified_symbols=None): - self._simplified_symbols = simplified_symbols or {} - - def simplify(self, symbol, clear_domains=True): - """ - This function recurses down the tree, applying any simplifications necessary. - - Parameters - ---------- - symbol : :class:`pybamm.Symbol` - The symbol to simplify - clear_domains : bool - Whether to remove a symbol's domain when simplifying. Default is True. - - Returns - ------- - :class:`pybamm.Symbol` - Simplified symbol - """ - - try: - return self._simplified_symbols[symbol.id] - except KeyError: - simplified_symbol = self._simplify(symbol, clear_domains) - - self._simplified_symbols[symbol.id] = simplified_symbol - - return simplified_symbol - - def _simplify(self, symbol, clear_domains=True): - """ See :meth:`Simplification.simplify()`. """ - if clear_domains: - symbol.clear_domains() - - if isinstance(symbol, pybamm.BinaryOperator): - left, right = symbol.children - # process children - new_left = self.simplify(left) - new_right = self.simplify(right) - # _binary_simplify defined in derived classes for specific rules - new_symbol = symbol._binary_simplify(new_left, new_right) - - elif isinstance(symbol, pybamm.UnaryOperator): - # Reassign domain for gradient and divergence - if isinstance( - symbol, (pybamm.Gradient, pybamm.Divergence, pybamm.Integral) - ): - new_child = self.simplify(symbol.child, clear_domains=False) - else: - new_child = self.simplify(symbol.child) - # _unary_simplify defined in derived classes for specific rules - new_symbol = symbol._unary_simplify(new_child) - - elif isinstance(symbol, pybamm.Function): - simplified_children = [None] * len(symbol.children) - for i, child in enumerate(symbol.children): - simplified_children[i] = self.simplify(child) - # _function_simplify defined in function class - new_symbol = symbol._function_simplify(simplified_children) - - elif isinstance(symbol, pybamm.Concatenation): - new_children = [self.simplify(child) for child in symbol.children] - new_symbol = symbol._concatenation_simplify(new_children) - - else: - # Backup option: return new copy of the object - try: - new_symbol = symbol.new_copy() - return new_symbol - except NotImplementedError: - raise NotImplementedError( - "Cannot simplify symbol of type '{}'".format(type(symbol)) - ) - - new_symbol = simplify_if_constant(new_symbol) - - if clear_domains: - new_symbol.clear_domains() - else: - new_symbol.copy_domains(symbol) - - return new_symbol diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 99e805eaff..55b1c103dd 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -8,7 +8,7 @@ import copy import numpy as np from anytree.exporter import DotExporter -from scipy.sparse import issparse +from scipy.sparse import issparse, csr_matrix def domain_size(domain): @@ -122,6 +122,42 @@ def is_matrix_one(expr): return False +def simplify_if_constant(symbol, clear_domains=True): + """ + Utility function to simplify an expression tree if it evalutes to a constant + scalar, vector or matrix + """ + if clear_domains is True: + domain = None + auxiliary_domains = None + else: + domain = symbol.domain + auxiliary_domains = symbol.auxiliary_domains + if symbol.is_constant(): + result = symbol.evaluate_ignoring_errors() + if result is not None: + if ( + isinstance(result, numbers.Number) + or (isinstance(result, np.ndarray) and result.ndim == 0) + or isinstance(result, np.bool_) + ): + return pybamm.Scalar(result) + elif isinstance(result, np.ndarray) or issparse(result): + if result.ndim == 1 or result.shape[1] == 1: + return pybamm.Vector( + result, domain=domain, auxiliary_domains=auxiliary_domains + ) + else: + # Turn matrix of zeros into sparse matrix + if isinstance(result, np.ndarray) and np.all(result == 0): + result = csr_matrix(result) + return pybamm.Matrix( + result, domain=domain, auxiliary_domains=auxiliary_domains + ) + + return symbol + + class Symbol(anytree.NodeMixin): """Base node class for the expression tree @@ -787,8 +823,8 @@ def has_symbol_of_classes(self, symbol_classes): return any(isinstance(symbol, symbol_classes) for symbol in self.pre_order()) def simplify(self, simplified_symbols=None, clear_domains=True): - """ Simplify the expression tree. See :class:`pybamm.Simplification`. """ - return pybamm.Simplification(simplified_symbols).simplify(self, clear_domains) + """ `simplify()` has now been removed. """ + raise pybamm.ModelError("simplify is deprecated as it now has no effect") def to_casadi(self, t=None, y=None, y_dot=None, inputs=None, casadi_symbols=None): """ diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 669e80ebde..58adfbcc44 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -54,14 +54,6 @@ def _unary_jac(self, child_jac): """ Calculate the jacobian of a unary operator. """ raise NotImplementedError - def _unary_simplify(self, simplified_child): - """ - Simplify a unary operator. Default behaviour is to make a new copy, with - simplified child. - """ - - return self._unary_new_copy(simplified_child) - def _unary_evaluate(self, child): """Perform unary operation on a child. """ raise NotImplementedError @@ -344,19 +336,6 @@ def diff(self, variable): # We shouldn't need this raise NotImplementedError - def _unary_simplify(self, simplified_child): - """ See :meth:`pybamm.UnaryOperator.simplify()`. """ - - # if there are none of these nodes in the child tree, then this expression - # does not depend on space, and therefore the spatial operator result is zero - search_types = (pybamm.Variable, pybamm.StateVector, pybamm.SpatialVariable) - - # do the search, return a scalar zero node if no relevent nodes are found - if not self.has_symbol_of_classes(search_types): - return pybamm.Scalar(0) - else: - return self.__class__(simplified_child) - class Gradient(SpatialOperator): """A node in the expression tree representing a grad operator @@ -572,11 +551,6 @@ def set_id(self): + tuple(self.domain) ) - def _unary_simplify(self, simplified_child): - """ See :meth:`UnaryOperator._unary_simplify()`. """ - - return self.__class__(simplified_child, self.integration_variable) - def _unary_new_copy(self, child): """ See :meth:`UnaryOperator._unary_new_copy()`. """ @@ -721,11 +695,6 @@ def set_id(self): + tuple(self.domain) ) - def _unary_simplify(self, simplified_child): - """ See :meth:`UnaryOperator._unary_simplify()`. """ - - return self.__class__(simplified_child, vector_type=self.vector_type) - def _unary_new_copy(self, child): """ See :meth:`UnaryOperator._unary_new_copy()`. """ @@ -782,11 +751,6 @@ def set_id(self): (self.__class__, self.name) + (self.children[0].id,) + tuple(self.domain) ) - def _unary_simplify(self, simplified_child): - """ See :meth:`UnaryOperator._unary_simplify()`. """ - - return self.__class__(simplified_child, region=self.region) - def _unary_new_copy(self, child): """ See :meth:`UnaryOperator._unary_new_copy()`. """ @@ -836,10 +800,6 @@ def _evaluates_on_edges(self, dimension): """ See :meth:`pybamm.Symbol._evaluates_on_edges()`. """ return False - def _unary_simplify(self, simplified_child): - """ See :meth:`UnaryOperator._unary_simplify()`. """ - return self.__class__(simplified_child, self.side, self.domain) - def _unary_new_copy(self, child): """ See :meth:`UnaryOperator._unary_new_copy()`. """ return self.__class__(child, self.side, self.domain) @@ -904,10 +864,6 @@ def set_id(self): + tuple([(k, tuple(v)) for k, v in self.auxiliary_domains.items()]) ) - def _unary_simplify(self, simplified_child): - """ See :meth:`UnaryOperator._unary_simplify()`. """ - return self.__class__(simplified_child, self.side) - def _unary_new_copy(self, child): """ See :meth:`UnaryOperator._unary_new_copy()`. """ return self.__class__(child, self.side) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 984047ee53..9e1a5fe5c7 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -70,10 +70,6 @@ class BaseModel(object): solver set up use_jacobian : bool Whether to use the Jacobian when solving the model (default is True) - use_simplify : bool - Whether to simplify the expression tress representing the rhs and - algebraic equations, Jacobain (if using) and events, before solving the - model (default is True) convert_to_format : str Whether to convert the expression trees representing the rhs and algebraic equations, Jacobain (if using) and events into a different format: @@ -111,9 +107,8 @@ def __init__(self, name="Unnamed model"): self._input_parameters = None self._variables_casadi = {} - # Default behaviour is to use the jacobian and simplify + # Default behaviour is to use the jacobian self.use_jacobian = True - self.use_simplify = True self.convert_to_format = "casadi" # Model is not initially discretised @@ -325,7 +320,6 @@ def new_empty_copy(self): """ new_model = self.__class__(name=self.name) new_model.use_jacobian = self.use_jacobian - new_model.use_simplify = self.use_simplify new_model.convert_to_format = self.convert_to_format new_model.timescale = self.timescale new_model.length_scales = self.length_scales diff --git a/pybamm/models/full_battery_models/base_battery_model.py b/pybamm/models/full_battery_models/base_battery_model.py index 1369b6708b..b76d9bb7f3 100644 --- a/pybamm/models/full_battery_models/base_battery_model.py +++ b/pybamm/models/full_battery_models/base_battery_model.py @@ -679,7 +679,6 @@ def new_empty_copy(self): "See :meth:`pybamm.BaseModel.new_empty_copy()`" new_model = self.__class__(name=self.name, options=self.options, build=False) new_model.use_jacobian = self.use_jacobian - new_model.use_simplify = self.use_simplify new_model.convert_to_format = self.convert_to_format new_model.timescale = self.timescale new_model.length_scales = self.length_scales diff --git a/pybamm/models/full_battery_models/lithium_ion/basic_dfn_half_cell.py b/pybamm/models/full_battery_models/lithium_ion/basic_dfn_half_cell.py index 3e478e18a7..dbfab5cff0 100644 --- a/pybamm/models/full_battery_models/lithium_ion/basic_dfn_half_cell.py +++ b/pybamm/models/full_battery_models/lithium_ion/basic_dfn_half_cell.py @@ -428,7 +428,6 @@ def default_spatial_methods(self): def new_copy(self, build=False): new_model = self.__class__(name=self.name, options=self.options) new_model.use_jacobian = self.use_jacobian - new_model.use_simplify = self.use_simplify new_model.convert_to_format = self.convert_to_format new_model.timescale = self.timescale new_model.length_scales = self.length_scales diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 82ea1b2319..e2029687bc 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -124,7 +124,7 @@ def copy(self): return new_solver def set_up(self, model, inputs=None, t_eval=None): - """Unpack model, perform checks, simplify and calculate jacobian. + """Unpack model, perform checks, and calculate jacobian. Parameters ---------- @@ -199,7 +199,6 @@ def set_up(self, model, inputs=None, t_eval=None): model.convert_to_format = "casadi" if model.convert_to_format != "casadi": - simp = pybamm.Simplification() # Create Jacobian from concatenated rhs and algebraic y = pybamm.StateVector(slice(0, model.concatenated_initial_conditions.size)) # set up Jacobian object, for re-use of dict @@ -228,9 +227,6 @@ def report(string): use_jacobian = model.use_jacobian if model.convert_to_format != "casadi": # Process with pybamm functions - if model.use_simplify: - report(f"Simplifying {name}") - func = simp.simplify(func) if model.convert_to_format == "jax": report(f"Converting {name} to jax") @@ -239,9 +235,6 @@ def report(string): if use_jacobian: report(f"Calculating jacobian for {name}") jac = jacobian.jac(func, y) - if model.use_simplify: - report(f"Simplifying jacobian for {name}") - jac = simp.simplify(jac) if model.convert_to_format == "python": report(f"Converting jacobian for {name} to python") jac = pybamm.EvaluatorPython(jac) diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 35718f717a..e109980e17 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -129,12 +129,10 @@ def __init__(self, model, parameter_values=None, disc=None): self.model = model def evaluate_model( - self, simplify=False, use_known_evals=False, to_python=False, to_jax=False + self, use_known_evals=False, to_python=False, to_jax=False ): result = np.empty((0, 1)) for eqn in [self.model.concatenated_rhs, self.model.concatenated_algebraic]: - if simplify: - eqn = eqn.simplify() y = self.model.concatenated_initial_conditions.evaluate(t=0) if use_known_evals: @@ -155,8 +153,7 @@ def evaluate_model( return result - def set_up_model(self, simplify=False, to_python=False): - self.model.use_simplify = simplify + def set_up_model(self, to_python=False): if to_python is True: self.model.convert_to_format = "python" self.model.default_solver.set_up(self.model) diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_composite.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_composite.py index 6c0d5df552..8ca0247f48 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_composite.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_composite.py @@ -28,22 +28,16 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) def test_set_up(self): model = pybamm.lead_acid.Composite() optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_basic_processing_1plus1D(self): options = {"current collector": "potential pair", "dimensionality": 1} diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_foqs.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_foqs.py index 9d54dc86dc..c29b984d27 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_foqs.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_foqs.py @@ -29,22 +29,16 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) def test_set_up(self): model = pybamm.lead_acid.FOQS() optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) class TestLeadAcidFOQSSurfaceForm(unittest.TestCase): diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py index bc67a3617b..5c4a4fdfb1 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_full.py @@ -31,23 +31,17 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) def test_set_up(self): options = {"thermal": "isothermal"} model = pybamm.lead_acid.Full(options) optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_basic_processing_1plus1D(self): options = {"current collector": "potential pair", "dimensionality": 1} @@ -86,21 +80,15 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - np.testing.assert_array_almost_equal(original, simplified, decimal=5) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known, decimal=5) def test_set_up(self): options = {"surface form": "differential"} model = pybamm.lead_acid.Full(options) optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - # optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_thermal(self): options = {"thermal": "lumped"} diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs.py index 1aa83c6315..56268c0eff 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs.py @@ -19,22 +19,16 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) def test_set_up(self): model = pybamm.lead_acid.LOQS() optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_charge(self): model = pybamm.lead_acid.LOQS() diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs_surface_form.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs_surface_form.py index 1e37c35358..7fcf30e235 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs_surface_form.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_loqs_surface_form.py @@ -49,23 +49,17 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified, decimal=5) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known, decimal=5) - np.testing.assert_array_almost_equal(original, simp_and_python, decimal=5) + np.testing.assert_array_almost_equal(original, to_python, decimal=5) def test_set_up(self): options = {"surface form": "differential"} model = pybamm.lead_acid.LOQS(options) optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) if __name__ == "__main__": diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_composite_side_reactions.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_composite_side_reactions.py index e0528c4816..82740322b6 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_composite_side_reactions.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_composite_side_reactions.py @@ -51,12 +51,8 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - np.testing.assert_array_almost_equal(original, simplified, decimal=5) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known, decimal=5) if __name__ == "__main__": diff --git a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_full_side_reactions.py b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_full_side_reactions.py index b712f40910..89f7b5ab7b 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_full_side_reactions.py +++ b/tests/integration/test_models/test_full_battery_models/test_lead_acid/test_side_reactions/test_full_side_reactions.py @@ -51,12 +51,8 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - np.testing.assert_array_almost_equal(original, simplified, decimal=5) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known, decimal=5) if __name__ == "__main__": diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py index 1608d30a4f..8e92f1f406 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py @@ -55,23 +55,16 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) def test_set_up(self): model = pybamm.lithium_ion.DFN() optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_full_thermal(self): options = {"thermal": "x-full"} diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index 08a3b8d248..1dccef4c48 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -54,26 +54,20 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) if system() != "Windows": - simp_and_jax = optimtest.evaluate_model(simplify=True, to_jax=True) - np.testing.assert_array_almost_equal(original, simp_and_jax) + to_jax = optimtest.evaluate_model(to_jax=True) + np.testing.assert_array_almost_equal(original, to_jax) def test_set_up(self): model = pybamm.lithium_ion.SPM() optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_charge(self): options = {"thermal": "isothermal"} diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py index 2ca68d887f..b651d7ffc2 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py @@ -61,26 +61,20 @@ def test_optimisations(self): optimtest = tests.OptimisationsTest(model) original = optimtest.evaluate_model() - simplified = optimtest.evaluate_model(simplify=True) using_known_evals = optimtest.evaluate_model(use_known_evals=True) - simp_and_known = optimtest.evaluate_model(simplify=True, use_known_evals=True) - simp_and_python = optimtest.evaluate_model(simplify=True, to_python=True) - np.testing.assert_array_almost_equal(original, simplified) + to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, using_known_evals) - np.testing.assert_array_almost_equal(original, simp_and_known) - np.testing.assert_array_almost_equal(original, simp_and_python) + np.testing.assert_array_almost_equal(original, to_python) if system() != "Windows": - simp_and_jax = optimtest.evaluate_model(simplify=True, to_jax=True) - np.testing.assert_array_almost_equal(original, simp_and_jax) + to_jax = optimtest.evaluate_model(to_jax=True) + np.testing.assert_array_almost_equal(original, to_jax) def test_set_up(self): model = pybamm.lithium_ion.SPMe() optimtest = tests.OptimisationsTest(model) - optimtest.set_up_model(simplify=False, to_python=True) - optimtest.set_up_model(simplify=True, to_python=True) - optimtest.set_up_model(simplify=False, to_python=False) - optimtest.set_up_model(simplify=True, to_python=False) + optimtest.set_up_model(to_python=True) + optimtest.set_up_model(to_python=False) def test_thermal(self): pybamm.settings.debug_mode = True diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index 25e1574e63..6eb6018f78 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -312,20 +312,10 @@ def test_domain_error(self): with self.assertRaisesRegex(pybamm.DomainError, "domain cannot be empty"): pybamm.DomainConcatenation([a, b], None) - def test_numpy_concatenation_simplify(self): + def test_numpy_concatenation(self): a = pybamm.Variable("a") b = pybamm.Variable("b") c = pybamm.Variable("c") - # simplifying flattens the concatenations into a single concatenation - self.assertEqual( - pybamm.NumpyConcatenation(pybamm.NumpyConcatenation(a, b), c).simplify().id, - pybamm.NumpyConcatenation(a, b, c).id, - ) - self.assertEqual( - pybamm.NumpyConcatenation(a, pybamm.NumpyConcatenation(b, c)).simplify().id, - pybamm.NumpyConcatenation(a, b, c).id, - ) - # check it works when calling numpy_concatenation self.assertEqual( pybamm.numpy_concatenation(pybamm.numpy_concatenation(a, b), c).id, pybamm.NumpyConcatenation(a, b, c).id, diff --git a/tests/unit/test_expression_tree/test_d_dt.py b/tests/unit/test_expression_tree/test_d_dt.py index 06309d1272..ce8f86d1e9 100644 --- a/tests/unit/test_expression_tree/test_d_dt.py +++ b/tests/unit/test_expression_tree/test_d_dt.py @@ -19,11 +19,9 @@ def test_time_derivative(self): a = (pybamm.t ** 2).diff(pybamm.t) self.assertEqual(a.id, (2 * pybamm.t ** 1 * 1).id) - self.assertEqual(a.simplify().id, (2 * pybamm.t).id) self.assertEqual(a.evaluate(t=1), 2) a = (2 + pybamm.t ** 2).diff(pybamm.t) - self.assertEqual(a.simplify().id, (2 * pybamm.t).id) self.assertEqual(a.evaluate(t=1), 2) def test_time_derivative_of_variable(self): @@ -35,10 +33,9 @@ def test_time_derivative_of_variable(self): p = pybamm.Parameter("p") a = 1 + p * pybamm.Variable("a") diff_a = a.diff(pybamm.t) - diff_a_simp = diff_a.simplify() - self.assertIsInstance(diff_a_simp, pybamm.Multiplication) - self.assertEqual(diff_a_simp.children[0].name, "p") - self.assertEqual(diff_a_simp.children[1].name, "a'") + self.assertIsInstance(diff_a, pybamm.Multiplication) + self.assertEqual(diff_a.children[0].name, "p") + self.assertEqual(diff_a.children[1].name, "a'") with self.assertRaises(pybamm.ModelError): a = (pybamm.Variable("a")).diff(pybamm.t).diff(pybamm.t) diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index 14901582fc..35b285b8a0 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -178,11 +178,6 @@ def test_cos(self): places=5, ) - # test simplify - y = pybamm.StateVector(slice(0, 1)) - fun = pybamm.cos(y) - self.assertEqual(fun.id, fun.simplify().id) - def test_cosh(self): a = pybamm.InputParameter("a") fun = pybamm.cosh(a) diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 999c440d41..02076633d7 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -123,7 +123,6 @@ def test_processing(self): interp = pybamm.Interpolant(x, 2 * x, y) self.assertEqual(interp.id, interp.new_copy().id) - self.assertEqual(interp.id, interp.simplify().id) if __name__ == "__main__": diff --git a/tests/unit/test_expression_tree/test_operations/test_jac.py b/tests/unit/test_expression_tree/test_operations/test_jac.py index 83851d592f..2f641a7b7e 100644 --- a/tests/unit/test_expression_tree/test_operations/test_jac.py +++ b/tests/unit/test_expression_tree/test_operations/test_jac.py @@ -51,7 +51,7 @@ def test_linear(self): A = pybamm.Matrix(2 * eye(2)) func = A @ u jacobian = np.array([[2, 0, 0, 0], [0, 2, 0, 0]]) - dfunc_dy = func.jac(y).simplify().evaluate(y=y0) + dfunc_dy = func.jac(y).evaluate(y=y0) np.testing.assert_array_equal(jacobian, dfunc_dy.toarray()) func = u @ pybamm.StateVector(slice(0, 1)) diff --git a/tests/unit/test_expression_tree/test_operations/test_jac_2D.py b/tests/unit/test_expression_tree/test_operations/test_jac_2D.py index 5fea85ee74..e118435181 100644 --- a/tests/unit/test_expression_tree/test_operations/test_jac_2D.py +++ b/tests/unit/test_expression_tree/test_operations/test_jac_2D.py @@ -79,7 +79,7 @@ def test_linear(self): [0, 0, 0, 0, 0, 2, 0, 0], ] ) - dfunc_dy = func.jac(y).simplify().evaluate(y=y0) + dfunc_dy = func.jac(y).evaluate(y=y0) np.testing.assert_array_equal(jacobian, dfunc_dy.toarray()) # when differentiating by independent part of the state vector diff --git a/tests/unit/test_expression_tree/test_operations/test_simplify.py b/tests/unit/test_expression_tree/test_operations/test_simplify.py deleted file mode 100644 index c1d5fcce6c..0000000000 --- a/tests/unit/test_expression_tree/test_operations/test_simplify.py +++ /dev/null @@ -1,556 +0,0 @@ -# -# Test for the Simplify class -# This test file is a little bit out of date now that many simplifications are -# performed automatically -# -import math -import numpy as np -import pybamm -import unittest -from tests import get_discretisation_for_testing - - -class TestSimplify(unittest.TestCase): - def test_symbol_simplify(self): - a = pybamm.Scalar(0, domain="domain") - b = pybamm.Scalar(1) - c = pybamm.Parameter("c") - d = pybamm.Scalar(-1) - e = pybamm.Scalar(2) - g = pybamm.Variable("g") - gdot = pybamm.VariableDot("g'") - - # function - def sin(x): - return math.sin(x) - - f = pybamm.Function(sin, b) - self.assertIsInstance((f).simplify(), pybamm.Scalar) - self.assertEqual((f).simplify().evaluate(), math.sin(1)) - - def myfunction(x, y): - return x * y - - f = pybamm.Function(myfunction, a, b) - self.assertIsInstance((f).simplify(), pybamm.Scalar) - self.assertEqual((f).simplify().evaluate(), 0) - - # FunctionParameter - f = pybamm.FunctionParameter("function", {"b": b}) - self.assertIsInstance((f).simplify(), pybamm.FunctionParameter) - self.assertEqual((f).simplify().children[0].id, b.id) - - f = pybamm.FunctionParameter("function", {"a": a, "b": b}) - self.assertIsInstance((f).simplify(), pybamm.FunctionParameter) - self.assertEqual((f).simplify().children[0].id, a.id) - self.assertEqual((f).simplify().children[1].id, b.id) - - # Gradient - self.assertIsInstance((pybamm.grad(a)).simplify(), pybamm.Scalar) - self.assertEqual((pybamm.grad(a)).simplify().evaluate(), 0) - v = pybamm.Variable("v", domain="domain") - grad_v = pybamm.grad(v) - self.assertIsInstance(grad_v.simplify(), pybamm.Gradient) - - # Divergence - div_b = pybamm.div(pybamm.PrimaryBroadcastToEdges(b, "domain")) - self.assertIsInstance(div_b.simplify(), pybamm.PrimaryBroadcast) - self.assertEqual(div_b.simplify().child.child.evaluate(), 0) - self.assertIsInstance( - (pybamm.div(pybamm.grad(v))).simplify(), pybamm.Divergence - ) - - # Integral - self.assertIsInstance( - ( - pybamm.Integral(a, pybamm.SpatialVariable("x", domain="domain")) - ).simplify(), - pybamm.Integral, - ) - - def_int = (pybamm.DefiniteIntegralVector(a, vector_type="column")).simplify() - self.assertIsInstance(def_int, pybamm.DefiniteIntegralVector) - self.assertEqual(def_int.vector_type, "column") - - bound_int = (pybamm.BoundaryIntegral(a, region="negative tab")).simplify() - self.assertIsInstance(bound_int, pybamm.BoundaryIntegral) - self.assertEqual(bound_int.region, "negative tab") - - # BoundaryValue - v_neg = pybamm.Variable("v", domain=["negative electrode"]) - self.assertIsInstance( - (pybamm.boundary_value(v_neg, "right")).simplify(), pybamm.BoundaryValue - ) - - # Delta function - self.assertIsInstance( - (pybamm.DeltaFunction(v_neg, "right", "domain")).simplify(), - pybamm.DeltaFunction, - ) - - # More complex expressions - expr = (e * (e * c)).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (e / (e * c)).simplify() - self.assertIsInstance(expr, pybamm.Division) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 1.0) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (e * (e / c)).simplify() - self.assertIsInstance(expr, pybamm.Division) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 4.0) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (e * (c / e)).simplify() - self.assertEqual(expr.id, c.id) - - expr = ((e * c) * (c / e)).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Parameter) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (e + (e + c)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 4.0) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (e + (e - c)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 4.0) - self.assertIsInstance(expr.children[1], pybamm.Negate) - self.assertIsInstance(expr.children[1].children[0], pybamm.Parameter) - - expr = (e * g * b).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 2.0) - self.assertIsInstance(expr.children[1], pybamm.Variable) - - expr = (e * gdot * b).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 2.0) - self.assertIsInstance(expr.children[1], pybamm.VariableDot) - - expr = (e + (g - c)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 2.0) - self.assertIsInstance(expr.children[1], pybamm.Subtraction) - self.assertIsInstance(expr.children[1].children[0], pybamm.Variable) - self.assertIsInstance(expr.children[1].children[1], pybamm.Parameter) - - expr = ((2 + c) + (c + 2)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 4.0) - self.assertIsInstance(expr.children[1], pybamm.Multiplication) - self.assertIsInstance(expr.children[1].children[0], pybamm.Scalar) - self.assertEqual(expr.children[1].children[0].evaluate(), 2) - self.assertIsInstance(expr.children[1].children[1], pybamm.Parameter) - - expr = ((-1 + c) - (c + 1) + (c - 1)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), -3.0) - - # check these don't simplify - self.assertIsInstance((c * e).simplify(), pybamm.Multiplication) - self.assertIsInstance((e / c).simplify(), pybamm.Division) - self.assertIsInstance((c).simplify(), pybamm.Parameter) - c1 = pybamm.Parameter("c1") - self.assertIsInstance((c1 * c).simplify(), pybamm.Multiplication) - - # should simplify division to multiply - self.assertIsInstance((c / e).simplify(), pybamm.Multiplication) - - self.assertIsInstance((c / b).simplify(), pybamm.Parameter) - self.assertIsInstance((c * b).simplify(), pybamm.Parameter) - - # negation with parameter - self.assertIsInstance((-c).simplify(), pybamm.Negate) - - self.assertIsInstance((a + b + a).simplify(), pybamm.Scalar) - self.assertEqual((a + b + a).simplify().evaluate(), 1) - self.assertIsInstance((b + a + a).simplify(), pybamm.Scalar) - self.assertEqual((b + a + a).simplify().evaluate(), 1) - self.assertIsInstance((a * b * b).simplify(), pybamm.Scalar) - self.assertEqual((a * b * b).simplify().evaluate(), 0) - self.assertIsInstance((b * a * b).simplify(), pybamm.Scalar) - self.assertEqual((b * a * b).simplify().evaluate(), 0) - - # power simplification - self.assertIsInstance((c ** a).simplify(), pybamm.Scalar) - self.assertEqual((c ** a).simplify().evaluate(), 1) - self.assertIsInstance((a ** c).simplify(), pybamm.Scalar) - self.assertEqual((a ** c).simplify().evaluate(), 0) - d = pybamm.Scalar(2) - self.assertIsInstance((c ** d).simplify(), pybamm.Power) - - # division - self.assertIsInstance((a / b).simplify(), pybamm.Scalar) - self.assertEqual((a / b).simplify().evaluate(), 0) - self.assertIsInstance((b / b).simplify(), pybamm.Scalar) - self.assertEqual((b / b).simplify().evaluate(), 1) - - with self.assertRaises(ZeroDivisionError): - b / a - - # not implemented for Symbol - sym = pybamm.Symbol("sym") - with self.assertRaises(NotImplementedError): - sym.simplify() - - # A + A = 2A (#323) - a = pybamm.Parameter("A") - expr = (a + a).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 2) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (a + a + a + a).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 4) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - expr = (a - a + a - a + a + a).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 2) - self.assertIsInstance(expr.children[1], pybamm.Parameter) - - # A - A = 0 (#323) - expr = (a - a).simplify() - self.assertIsInstance(expr, pybamm.Scalar) - self.assertEqual(expr.evaluate(), 0) - - # B - (A+A) = B - 2*A (#323) - expr = (b - (a + a)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.right, pybamm.Negate) - self.assertIsInstance(expr.right.child, pybamm.Multiplication) - self.assertEqual(expr.right.child.left.id, pybamm.Scalar(2).id) - self.assertEqual(expr.right.child.right.id, a.id) - - # B - (1*A + 2*A) = B - 3*A (#323) - expr = (b - (1 * a + 2 * a)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.right, pybamm.Negate) - self.assertIsInstance(expr.right.child, pybamm.Multiplication) - self.assertEqual(expr.right.child.left.id, pybamm.Scalar(3).id) - self.assertEqual(expr.right.child.right.id, a.id) - - # B - (A + C) = B - (A + C) (not B - (A - C)) - expr = (b - (a + c)).simplify() - self.assertIsInstance(expr, pybamm.Addition) - self.assertIsInstance(expr.right, pybamm.Subtraction) - self.assertEqual(expr.right.left.id, (-a).id) - self.assertEqual(expr.right.right.id, c.id) - - def test_vector_zero_simplify(self): - a1 = pybamm.Scalar(0) - v1 = pybamm.Vector(np.zeros(10)) - a2 = pybamm.Scalar(1) - v2 = pybamm.Vector(np.ones(10)) - - for expr in [a1 * v1, v1 * a1, a2 * v1, v1 * a2, a1 * v2, v2 * a1, v1 * v2]: - self.assertIsInstance(expr.simplify(), pybamm.Vector) - np.testing.assert_array_equal(expr.simplify().entries, np.zeros((10, 1))) - - def test_matrix_simplifications(self): - a = pybamm.Matrix(np.zeros((2, 2))) - b = pybamm.Matrix(np.ones((2, 2))) - - # matrix multiplication - A = pybamm.Matrix([[1, 0], [0, 1]]) - self.assertIsInstance((a @ A).simplify(), pybamm.Matrix) - np.testing.assert_array_equal( - (a @ A).simplify().evaluate().toarray(), np.zeros((2, 2)) - ) - self.assertIsInstance((A @ a).simplify(), pybamm.Matrix) - np.testing.assert_array_equal( - (A @ a).simplify().evaluate().toarray(), np.zeros((2, 2)) - ) - - # matrix * matrix - m1 = pybamm.Matrix([[2, 0], [0, 2]]) - m2 = pybamm.Matrix([[3, 0], [0, 3]]) - v = pybamm.StateVector(slice(0, 2)) - v2 = pybamm.StateVector(slice(2, 4)) - - for expr in [((m2 @ m1) @ v).simplify(), (m2 @ (m1 @ v)).simplify()]: - self.assertIsInstance(expr.children[0], pybamm.Matrix) - self.assertIsInstance(expr.children[1], pybamm.StateVector) - np.testing.assert_array_equal( - expr.children[0].entries, np.array([[6, 0], [0, 6]]) - ) - - # div by a constant - for expr in [((m2 @ m1) @ v / 2).simplify(), (m2 @ (m1 @ v) / 2).simplify()]: - self.assertIsInstance(expr.children[0], pybamm.Matrix) - self.assertIsInstance(expr.children[1], pybamm.StateVector) - np.testing.assert_array_equal( - expr.children[0].entries, np.array([[3, 0], [0, 3]]) - ) - - expr = ((v * v) / 2).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 0.5) - self.assertIsInstance(expr.children[1], pybamm.Multiplication) - - # mat-mul on numerator and denominator - expr = (m2 @ (m1 @ v) / (m2 @ (m1 @ v))).simplify() - for child in expr.children: - self.assertIsInstance(child.children[0], pybamm.Matrix) - self.assertIsInstance(child.children[1], pybamm.StateVector) - np.testing.assert_array_equal( - child.children[0].entries, np.array([[6, 0], [0, 6]]) - ) - - # mat-mul just on denominator - expr = (1 / (m2 @ (m1 @ v))).simplify() - self.assertIsInstance(expr.children[1].children[0], pybamm.Matrix) - self.assertIsInstance(expr.children[1].children[1], pybamm.StateVector) - np.testing.assert_array_equal( - expr.children[1].children[0].entries, np.array([[6, 0], [0, 6]]) - ) - expr = (v2 / (m2 @ (m1 @ v))).simplify() - self.assertIsInstance(expr.children[0], pybamm.StateVector) - self.assertIsInstance(expr.children[1].children[0], pybamm.Matrix) - self.assertIsInstance(expr.children[1].children[1], pybamm.StateVector) - np.testing.assert_array_equal( - expr.children[1].children[0].entries, np.array([[6, 0], [0, 6]]) - ) - - # scalar * matrix - b = pybamm.Scalar(1) - for expr in [ - ((b * m1) @ v).simplify(), - (b * (m1 @ v)).simplify(), - ((m1 * b) @ v).simplify(), - (m1 @ (b * v)).simplify(), - ]: - self.assertIsInstance(expr.children[0], pybamm.Matrix) - self.assertIsInstance(expr.children[1], pybamm.StateVector) - np.testing.assert_array_equal( - expr.children[0].entries, np.array([[2, 0], [0, 2]]) - ) - - # matrix * vector - m1 = pybamm.Matrix([[2, 0], [0, 2]]) - v1 = pybamm.Vector([1, 1]) - - for expr in [(m1 @ v1).simplify()]: - self.assertIsInstance(expr, pybamm.Vector) - np.testing.assert_array_equal(expr.entries, np.array([[2], [2]])) - - # dont expant mult within mat-mult (issue #253) - m1 = pybamm.Matrix(np.ones((300, 299))) - m2 = pybamm.Matrix(np.ones((299, 300))) - m3 = pybamm.Matrix(np.ones((300, 300))) - v1 = pybamm.StateVector(slice(0, 299)) - v2 = pybamm.StateVector(slice(0, 300)) - v3 = pybamm.Vector(np.ones(299)) - - expr = m1 @ (v1 * m2) - self.assertEqual( - expr.simplify().evaluate(y=np.ones((299, 1))).shape, (300, 300) - ) - np.testing.assert_array_equal( - expr.evaluate(y=np.ones((299, 1))), - expr.simplify().evaluate(y=np.ones((299, 1))), - ) - - # more complex expression - expr2 = m1 @ (v1 * (m2 @ v2)) - expr2simp = expr2.simplify() - np.testing.assert_array_equal( - expr2.evaluate(y=np.ones(300)), expr2simp.evaluate(y=np.ones(300)) - ) - self.assertEqual(expr2.id, expr2simp.id) - - expr3 = m1 @ ((m2 @ v2) * (m2 @ v2)) - expr3simp = expr3.simplify() - self.assertEqual(expr3.id, expr3simp.id) - - # more complex expression, with simplification - expr3 = m1 @ (v3 * (m2 @ v2)) - expr3simp = expr3.simplify() - self.assertEqual(expr3.id, expr3simp.id) - np.testing.assert_array_equal( - expr3.evaluate(y=np.ones(300)), expr3simp.evaluate(y=np.ones(300)) - ) - - m1 = pybamm.Matrix(np.ones((300, 300))) - m2 = pybamm.Matrix(np.ones((300, 300))) - m3 = pybamm.Matrix(np.ones((300, 300))) - m4 = pybamm.Matrix(np.ones((300, 300))) - v1 = pybamm.StateVector(slice(0, 300)) - v2 = pybamm.StateVector(slice(300, 600)) - v3 = pybamm.StateVector(slice(600, 900)) - v4 = pybamm.StateVector(slice(900, 1200)) - expr4 = (m1 @ v1) * ((m2 @ v2) / (m3 @ v3) - m4 @ v4) - expr4simp = expr4.simplify() - self.assertEqual(expr4.id, expr4simp.id) - - m2 = pybamm.Matrix(np.ones((299, 300))) - v2 = pybamm.StateVector(slice(0, 300)) - v3 = pybamm.Vector(np.ones(299)) - exprs = [(m2 @ v2) * v3, (m2 @ v2) / v3] - for expr in exprs: - exprsimp = expr.simplify() - self.assertIsInstance(exprsimp, pybamm.MatrixMultiplication) - self.assertIsInstance(exprsimp.children[0], pybamm.Matrix) - self.assertIsInstance(exprsimp.children[1], pybamm.StateVector) - np.testing.assert_array_equal( - expr.evaluate(y=np.ones(300)), exprsimp.evaluate(y=np.ones(300)) - ) - - # A + A = 2A (#323) - a = pybamm.StateVector(slice(0, 300)) - expr = (a + a).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 2) - self.assertIsInstance(expr.children[1], pybamm.StateVector) - - expr = (a + a + a + a).simplify() - self.assertIsInstance(expr, pybamm.Multiplication) - self.assertIsInstance(expr.children[0], pybamm.Scalar) - self.assertEqual(expr.children[0].evaluate(), 4) - self.assertIsInstance(expr.children[1], pybamm.StateVector) - - # A - A = 0 (#323) - expr = (a - a).simplify() - self.assertIsInstance(expr, pybamm.Vector) - self.assertEqual(expr.shape, a.shape) - np.testing.assert_array_equal(expr.evaluate(), 0) - - # zero matrix - m1 = pybamm.Matrix(np.zeros((300, 300))) - for expr in [m1 * v1, v1 * m1]: - expr_simp = expr.simplify() - self.assertIsInstance(expr_simp, pybamm.Matrix) - np.testing.assert_array_equal( - expr_simp.evaluate(y=np.ones(300)).toarray(), m1.evaluate() - ) - - # adding zero - m2 = pybamm.Matrix(np.random.rand(300, 300)) - for expr in [m1 + m2, m2 + m1]: - expr_simp = expr.simplify() - self.assertIsInstance(expr_simp, pybamm.Matrix) - np.testing.assert_array_equal( - expr_simp.evaluate(y=np.ones(300)), m2.evaluate() - ) - - # subtracting zero - for expr in [m1 - m2, -m2 - m1]: - expr_simp = expr.simplify() - self.assertIsInstance(expr_simp, pybamm.Matrix) - np.testing.assert_array_equal( - expr_simp.evaluate(y=np.ones(300)), -m2.evaluate() - ) - - def test_matrix_divide_simplify(self): - m = pybamm.Matrix(np.random.rand(30, 20)) - zero = pybamm.Scalar(0) - - expr1 = (zero / m).simplify() - self.assertIsInstance(expr1, pybamm.Matrix) - self.assertEqual(expr1.shape, m.shape) - np.testing.assert_array_equal(expr1.evaluate().toarray(), np.zeros((30, 20))) - - m = pybamm.Matrix(np.zeros((10, 10))) - a = pybamm.Scalar(7) - expr3 = (m / a).simplify() - self.assertIsInstance(expr3, pybamm.Matrix) - self.assertEqual(expr3.shape, m.shape) - np.testing.assert_array_equal(expr3.evaluate().toarray(), np.zeros((10, 10))) - - def test_domain_concatenation_simplify(self): - # create discretisation - disc = get_discretisation_for_testing() - mesh = disc.mesh - - a_dom = ["negative electrode"] - b_dom = ["positive electrode"] - a = 2 * pybamm.Vector(np.ones_like(mesh[a_dom[0]].nodes), domain=a_dom) - b = pybamm.Vector(np.ones_like(mesh[b_dom[0]].nodes), domain=b_dom) - - conc = pybamm.DomainConcatenation([a, b], mesh) - conc_simp = conc.simplify() - - # should be simplified to a vector - self.assertIsInstance(conc_simp, pybamm.Vector) - np.testing.assert_array_equal( - conc_simp.evaluate(), - np.concatenate( - [ - np.full((mesh[a_dom[0]].npts, 1), 2), - np.full((mesh[b_dom[0]].npts, 1), 1), - ] - ), - ) - - # check it works when calling domain_concatenation - self.assertIsInstance(pybamm.domain_concatenation([a, b], mesh), pybamm.Vector) - - def test_simplify_concatenation_state_vectors(self): - disc = get_discretisation_for_testing() - mesh = disc.mesh - - a = pybamm.Variable("a", domain=["negative electrode"]) - b = pybamm.Variable("b", domain=["separator"]) - c = pybamm.Variable("c", domain=["positive electrode"]) - conc = pybamm.Concatenation(a, b, c) - disc.set_variable_slices([a, b, c]) - conc_disc = disc.process_symbol(conc) - conc_simp = conc_disc.simplify() - - y = mesh.combine_submeshes(*conc.domain).nodes ** 2 - self.assertIsInstance(conc_simp, pybamm.StateVector) - self.assertEqual(len(conc_simp.y_slices), 1) - self.assertEqual(conc_simp.y_slices[0].start, 0) - self.assertEqual(conc_simp.y_slices[0].stop, len(y)) - np.testing.assert_array_equal(conc_disc.evaluate(y=y), conc_simp.evaluate(y=y)) - - def test_simplify_broadcast(self): - v = pybamm.StateVector(slice(0, 1)) - broad = pybamm.PrimaryBroadcast(v, "test") - broad_simp = broad.simplify() - self.assertEqual(broad_simp.id, broad.id) - - def test_simplify_heaviside(self): - a = pybamm.Scalar(1) - b = pybamm.Scalar(2) - self.assertEqual((a < b).simplify().id, pybamm.Scalar(1).id) - self.assertEqual((a >= b).simplify().id, pybamm.Scalar(0).id) - - def test_simplify_inner(self): - a1 = pybamm.Scalar(0) - M2 = pybamm.Matrix(np.ones((10, 10))) - - expr = pybamm.Inner(a1, M2).simplify() - self.assertIsInstance(expr, pybamm.Matrix) - np.testing.assert_array_equal(expr.entries.toarray(), np.zeros((10, 10))) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 00a1092034..5aef095c33 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -292,6 +292,14 @@ def test_symbol_evaluates_to_constant_number(self): a = 3 * pybamm.t + 2 self.assertFalse(a.evaluates_to_constant_number()) + def test_simplify(self): + a = pybamm.Parameter("A") + #test error + with self.assertRaisesRegex( + pybamm.ModelError, "simplify is deprecated as it now has no effect" + ): + (a + a).simplify() + def test_symbol_repr(self): """ test that __repr___ returns the string diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 84d0a148ad..121c30b9cd 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -246,13 +246,11 @@ def test_new_copy(self): d: {"left": (0, "Dirichlet"), "right": (0, "Dirichlet")}, } model.use_jacobian = False - model.use_simplify = False model.convert_to_format = "python" new_model = model.new_copy() self.assertEqual(new_model.name, model.name) self.assertEqual(new_model.use_jacobian, model.use_jacobian) - self.assertEqual(new_model.use_simplify, model.use_simplify) self.assertEqual(new_model.convert_to_format, model.convert_to_format) self.assertEqual(new_model.timescale.id, model.timescale.id) diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index d837cf09d3..ef27b090da 100644 --- a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -163,7 +163,6 @@ def test_new_model(self): self.assertEqual(new_model_T_eqn.id, model_T_eqn.id) self.assertEqual(new_model.name, model.name) self.assertEqual(new_model.use_jacobian, model.use_jacobian) - self.assertEqual(new_model.use_simplify, model.use_simplify) self.assertEqual(new_model.convert_to_format, model.convert_to_format) self.assertEqual(new_model.timescale.id, model.timescale.id) diff --git a/tests/unit/test_parameters/test_current_functions.py b/tests/unit/test_parameters/test_current_functions.py index 2222d3d3b8..8edc63dd75 100644 --- a/tests/unit/test_parameters/test_current_functions.py +++ b/tests/unit/test_parameters/test_current_functions.py @@ -20,7 +20,7 @@ def test_constant_current(self): } ) processed_current = parameter_values.process_symbol(current) - self.assertIsInstance(processed_current.simplify(), pybamm.Scalar) + self.assertIsInstance(processed_current, pybamm.Scalar) def test_get_current_data(self): # test process parameters