Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 749 cache symbol shape #780

Merged
merged 14 commits into from
Jan 10, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
right = self.right.evaluate(t, y, u)
return self._binary_evaluate(left, right)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape()`. """
left = self.children[0].evaluate_for_shape()
right = self.children[1].evaluate_for_shape()
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return PrimaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down Expand Up @@ -210,7 +210,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return SecondaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down Expand Up @@ -253,7 +253,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return FullBroadcast(child, self.broadcast_domain, self.auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _concatenation_simplify(self, children):
new_symbol.clear_domains()
return new_symbol

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape` """
if len(self.children) == 0:
return np.array([])
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
evaluated_children = [child.evaluate(t, y, u) for child in self.children]
return self._function_evaluate(evaluated_children)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Default behaviour: has same shape as all child
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class IndependentVariable(pybamm.Symbol):
def __init__(self, name, domain=None, auxiliary_domains=None):
super().__init__(name, domain=domain, auxiliary_domains=auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down Expand Up @@ -57,7 +57,7 @@ def _base_evaluate(self, t, y=None):
raise ValueError("t must be provided")
return t

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Return the scalar '0' to represent the shape of the independent variable `Time`.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return InputParameter(self.name)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Parameter(self.name, self.domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down Expand Up @@ -118,7 +118,7 @@ def _function_parameter_new_copy(self, children):
"""
return FunctionParameter(self.name, *children, diff_variable=self.diff_variable)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the sum of the evaluated children
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def new_copy(self):
evaluation_array=self.evaluation_array,
)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a StateVector.
The size of a StateVector is the number of True elements in its evaluation_array
Expand Down
8 changes: 8 additions & 0 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,14 @@ def evaluate_for_shape(self):
shape is returned instead, using the symbol's domain.
See :meth:`pybamm.Symbol.evaluate()`
"""
try:
return self._saved_evaluate_for_shape
except AttributeError:
self._saved_evaluate_for_shape = self._evaluate_for_shape()
return self._saved_evaluate_for_shape

def _evaluate_for_shape(self):
"See :meth:`Symbol.evaluate_for_shape`"
return self.evaluate()

def is_constant(self):
Expand Down
18 changes: 9 additions & 9 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
child = self.child.evaluate(t, y, u)
return self._unary_evaluate(child)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Default behaviour: unary operator has same shape as child
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down Expand Up @@ -228,7 +228,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, self.index, check_size=False)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return self._unary_evaluate(self.children[0].evaluate_for_shape())

def evaluates_on_edges(self):
Expand Down Expand Up @@ -347,7 +347,7 @@ class Mass(SpatialOperator):
def __init__(self, child):
super().__init__("mass", child)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return pybamm.evaluate_for_shape_using_domain(self.domain, typ="matrix")


Expand All @@ -361,7 +361,7 @@ class BoundaryMass(SpatialOperator):
def __init__(self, child):
super().__init__("boundary mass", child)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return pybamm.evaluate_for_shape_using_domain(self.domain, typ="matrix")


Expand Down Expand Up @@ -455,7 +455,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, self.integration_variable)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)

Expand Down Expand Up @@ -501,7 +501,7 @@ def __init__(self, child, integration_variable):
if isinstance(integration_variable, pybamm.SpatialVariable):
self.name += " on {}".format(integration_variable.domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
return self.children[0].evaluate_for_shape()


Expand Down Expand Up @@ -550,7 +550,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, vector_type=self.vector_type)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)

Expand Down Expand Up @@ -611,7 +611,7 @@ def _unary_new_copy(self, child):

return self.__class__(child, region=self.region)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(self.domain)

Expand Down Expand Up @@ -722,7 +722,7 @@ def _unary_new_copy(self, child):
""" See :meth:`UnaryOperator._unary_new_copy()`. """
return self.__class__(child, self.side)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Variable(self.name, self.domain, self.auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down