Skip to content

Commit

Permalink
Merge pull request #780 from pybamm-team/issue-749-cache-symbol-shape
Browse files Browse the repository at this point in the history
Issue 749 cache symbol shape
  • Loading branch information
valentinsulzer authored Jan 10, 2020
2 parents b48074f + 4e93f1b commit 23aea3d
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

## Optimizations

- Added caching for shape evaluation, used during discretisation ([#780](https://github.com/pybamm-team/PyBaMM/pull/780))
- Added an option to skip model checks during discretisation, which could be slow for large models ([#739](https://github.com/pybamm-team/PyBaMM/pull/739))
- Use CasADi's automatic differentation algorithms by default when solving a model ([#714](https://github.com/pybamm-team/PyBaMM/pull/714))
- Avoid re-checking size when making a copy of an `Index` object ([#656](https://github.com/pybamm-team/PyBaMM/pull/656))
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
right = self.right.evaluate(t, y, u)
return self._binary_evaluate(left, right)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape()`. """
left = self.children[0].evaluate_for_shape()
right = self.children[1].evaluate_for_shape()
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return PrimaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down Expand Up @@ -210,7 +210,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return SecondaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down Expand Up @@ -253,7 +253,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return FullBroadcast(child, self.broadcast_domain, self.auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _concatenation_simplify(self, children):
new_symbol.clear_domains()
return new_symbol

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape` """
if len(self.children) == 0:
return np.array([])
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
evaluated_children = [child.evaluate(t, y, u) for child in self.children]
return self._function_evaluate(evaluated_children)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Default behaviour: has same shape as all child
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class IndependentVariable(pybamm.Symbol):
def __init__(self, name, domain=None, auxiliary_domains=None):
super().__init__(name, domain=domain, auxiliary_domains=auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down Expand Up @@ -57,7 +57,7 @@ def _base_evaluate(self, t, y=None):
raise ValueError("t must be provided")
return t

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Return the scalar '0' to represent the shape of the independent variable `Time`.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return InputParameter(self.name)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Parameter(self.name, self.domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down Expand Up @@ -118,7 +118,7 @@ def _function_parameter_new_copy(self, children):
"""
return FunctionParameter(self.name, *children, diff_variable=self.diff_variable)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the sum of the evaluated children
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def new_copy(self):
evaluation_array=self.evaluation_array,
)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a StateVector.
The size of a StateVector is the number of True elements in its evaluation_array
Expand Down
8 changes: 8 additions & 0 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,14 @@ def evaluate_for_shape(self):
shape is returned instead, using the symbol's domain.
See :meth:`pybamm.Symbol.evaluate()`
"""
try:
return self._saved_evaluate_for_shape
except AttributeError:
self._saved_evaluate_for_shape = self._evaluate_for_shape()
return self._saved_evaluate_for_shape

def _evaluate_for_shape(self):
"See :meth:`Symbol.evaluate_for_shape`"
return self.evaluate()

def is_constant(self):
Expand Down
18 changes: 9 additions & 9 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
child = self.child.evaluate(t, y, u)
return self._unary_evaluate(child)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Default behaviour: unary operator has same shape as child
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down Expand Up @@ -228,7 +228,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, self.index, check_size=False)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return self._unary_evaluate(self.children[0].evaluate_for_shape())

def evaluates_on_edges(self):
Expand Down Expand Up @@ -347,7 +347,7 @@ class Mass(SpatialOperator):
def __init__(self, child):
super().__init__("mass", child)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return pybamm.evaluate_for_shape_using_domain(self.domain, typ="matrix")


Expand All @@ -361,7 +361,7 @@ class BoundaryMass(SpatialOperator):
def __init__(self, child):
super().__init__("boundary mass", child)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return pybamm.evaluate_for_shape_using_domain(self.domain, typ="matrix")


Expand Down Expand Up @@ -455,7 +455,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, self.integration_variable)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)

Expand Down Expand Up @@ -501,7 +501,7 @@ def __init__(self, child, integration_variable):
if isinstance(integration_variable, pybamm.SpatialVariable):
self.name += " on {}".format(integration_variable.domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return self.children[0].evaluate_for_shape()


Expand Down Expand Up @@ -550,7 +550,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, vector_type=self.vector_type)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)

Expand Down Expand Up @@ -611,7 +611,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, region=self.region)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)

Expand Down Expand Up @@ -722,7 +722,7 @@ def _unary_new_copy(self, child):
""" See :meth:`UnaryOperator._unary_new_copy()`. """
return self.__class__(child, self.side)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Variable(self.name, self.domain, self.auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down

0 comments on commit 23aea3d

Please sign in to comment.